微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

分类器对象没有属性 train

如何解决分类器对象没有属性 train

我在 datatset.py 模块中收到此错误,该模块用于为 siamese 网络创建正负对。知道为什么给我这个错误吗? 我将这个 Siamese MNIST 称为我在训练数据集上的主要模块“itr1.py”,用于我自己的自定义数据集。

Error log

siamese_train_dataset = SiameseMNIST(train_dataset)
siamese_test_dataset = SiameseMNIST(test_dataset)


class SiameseMNIST(Dataset):
    def __init__(self,dataset,train=True):
        self.dataset = dataset
        self.train = self.dataset.train
        self.transform = self.dataset.transform

        if self.train:
            self.train_labels = self.dataset.train_labels
            self.train_data = self.dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}
        else:
            # generate fixed pairs for testing
            self.test_labels = self.dataset.test_labels
            self.test_data = self.dataset.test_data
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            positive_pairs = [[i,random_state.choice(self.label_to_indices[self.test_labels[i].item()]),1]
                              for i in range(0,len(self.test_data),2)]

            negative_pairs = [[i,random_state.choice(self.label_to_indices[
                                                       np.random.choice(
                                                           list(self.labels_set - set([self.test_labels[i].item()]))
                                                       )
                                                   ]),0]
                              for i in range(1,2)]
            self.test_pairs = positive_pairs + negative_pairs

    def __getitem__(self,index):
        if self.train:
            target = np.random.randint(0,2)
            img1,label1 = self.train_data[index],self.train_labels[index].item()
            if target == 1:
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(),mode='L')
        img2 = Image.fromarray(img2.numpy(),mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1,img2),target

    def __len__(self):
        return len(self.dataset)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。