如何解决分类器对象没有属性 train
我在 datatset.py 模块中收到此错误,该模块用于为 siamese 网络创建正负对。知道为什么给我这个错误吗? 我将这个 Siamese MNIST 称为我在训练数据集上的主要模块“itr1.py”,用于我自己的自定义数据集。
siamese_train_dataset = SiameseMNIST(train_dataset)
siamese_test_dataset = SiameseMNIST(test_dataset)
class SiameseMNIST(Dataset):
def __init__(self,dataset,train=True):
self.dataset = dataset
self.train = self.dataset.train
self.transform = self.dataset.transform
if self.train:
self.train_labels = self.dataset.train_labels
self.train_data = self.dataset.train_data
self.labels_set = set(self.train_labels.numpy())
self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
for label in self.labels_set}
else:
# generate fixed pairs for testing
self.test_labels = self.dataset.test_labels
self.test_data = self.dataset.test_data
self.labels_set = set(self.test_labels.numpy())
self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
for label in self.labels_set}
random_state = np.random.RandomState(29)
positive_pairs = [[i,random_state.choice(self.label_to_indices[self.test_labels[i].item()]),1]
for i in range(0,len(self.test_data),2)]
negative_pairs = [[i,random_state.choice(self.label_to_indices[
np.random.choice(
list(self.labels_set - set([self.test_labels[i].item()]))
)
]),0]
for i in range(1,2)]
self.test_pairs = positive_pairs + negative_pairs
def __getitem__(self,index):
if self.train:
target = np.random.randint(0,2)
img1,label1 = self.train_data[index],self.train_labels[index].item()
if target == 1:
siamese_index = index
while siamese_index == index:
siamese_index = np.random.choice(self.label_to_indices[label1])
else:
siamese_label = np.random.choice(list(self.labels_set - set([label1])))
siamese_index = np.random.choice(self.label_to_indices[siamese_label])
img2 = self.train_data[siamese_index]
else:
img1 = self.test_data[self.test_pairs[index][0]]
img2 = self.test_data[self.test_pairs[index][1]]
target = self.test_pairs[index][2]
img1 = Image.fromarray(img1.numpy(),mode='L')
img2 = Image.fromarray(img2.numpy(),mode='L')
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2)
return (img1,img2),target
def __len__(self):
return len(self.dataset)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。