微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何仅从PyTorch的FashionMNIST数据集中获取特定的类?

如何解决如何仅从PyTorch的FashionMNIST数据集中获取特定的类?

FashionMNIST数据集具有10个不同的输出类别。如何仅使用特定类来获取此数据集的子集?就我而言,我只想要运动鞋,套头衫,凉鞋和衬衫类的图片(它们的类别分别为7,2,5和6)。

这就是我加载数据集的方式。

train_dataset_full = torchvision.datasets.FashionMNIST(data_folder,train = True,download = True,transform = transforms.ToTensor())

我遵循的方法如下。 依次遍历数据集,然后将返回的元组中的第一个元素(即类)与我所需的类进行比较。我被困在这里。如果返回的值为true,如何将这个观察值追加/添加到空数据集?

sneaker = 0
pullover = 0
sandal = 0
shirt = 0
for i in range(60000):
    if train_dataset_full[i][1] == 7:
        sneaker += 1
    elif train_dataset_full[i][1] == 2:
        pullover += 1
    elif train_dataset_full[i][1] == 5:
        sandal += 1
    elif train_dataset_full[i][1] == 6:
        shirt += 1

现在,我想代替sneaker += 1pullover += 1sandal += 1shirt += 1做类似empty_dataset.append(train_dataset_full[i])或类似的事情。

如果上述方法不正确,请提出另一种方法

解决方法

最后找到了答案。

dataset_full = torchvision.datasets.FashionMNIST(data_folder,train = True,download = True,transform = transforms.ToTensor())
# Selecting classes 7,2,5 and 6
idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]
,

您可以使用列表理解来匹配标签。例如

idx = dataset.train_labels == 1
dataset.train_labels = dataset.train_labels[idx]

那只会选择您想要的标签。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。