微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何使用PyTorch从本地目录导入MNIST数据集

如何解决如何使用PyTorch从本地目录导入MNIST数据集

我正在PyTorch中编写一个众所周知的问题MNIST database of handwritten digits代码。我从主要网站下载了训练和测试数据集,包括标记的数据集。数据集格式为t10k-images-idx3-ubyte.gz提取后为t10k-images-idx3-ubyte。我的数据集文件夹看起来像

MINST
 Data
  train-images-idx3-ubyte.gz
  train-labels-idx1-ubyte.gz
  t10k-images-idx3-ubyte.gz
  t10k-labels-idx1-ubyte.gz

现在,我写了一个代码来加载像波纹管这样的数据

def load_dataset():
    data_path = "/home/MNIST/Data/"
    xy_trainPT = torchvision.datasets.ImageFolder(
        root=data_path,transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        xy_trainPT,batch_size=64,num_workers=0,shuffle=True
    )
    return train_loader

我的代码显示Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp

如何解决此问题,我还想检查是否从数据集中加载了我的图像(只是一个数字包含前5个图像)?

解决方法

欢迎来到stackoverflow!

MNIST数据集不存储为图像,而是以二进制格式(如ubyte扩展名所示)存储。因此,ImageFolder不是您想要的类型数据集。相反,您将需要使用MNIST dataset class。如果您还没有下载数据,它甚至可以下载:)

这是一个数据集类,因此只需使用正确的root路径实例化,然后将其作为数据加载器的参数,一切就可以正常工作。

如果要检查图像,只需使用数据加载器的get方法,然后将结果保存为png文件(您可能需要先将张量转换为numpy数组)

,

阅读此Extract images from .idx3-ubyte file or GZIP via Python

更新

您可以使用此格式导入数据

xy_trainPT = torchvision.datasets.MNIST(
    root="~/Handwritten_Deep_L/",train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),)

现在,download=True发生的事情是,您的代码将首先在根目录(给定路径)中检查是否包含任何数据集。

如果no,则将从网络上下载数据集。

如果yes此路径已经包含一个数据集,那么您的代码将使用现有的数据集运行,而不会从互联网上下载。

您可以检查,首先给出一个路径without any dataset(将从Internet下载数据),然后给出另一个路径which already contains dataset则不会下载数据。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。