微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用DataLoaders在PyTorch中验证数据集

如何解决使用DataLoaders在PyTorch中验证数据集

我想在PyTorch和Torchvision中加载MNIST数据集,将其分为训练,验证和测试部分。到目前为止,我有

def load_dataset():
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '/data/',train=True,download=True,transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor()])),batch_size=batch_size_train,shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '/data/',train=False,batch_size=batch_size_test,shuffle=True)

如果训练数据集位于DataLoader中,如何将其分为训练和验证?我想将训练数据集中的最后10000个示例用作验证数据集(我知道我应该做CV以获得更准确的结果,我只想在此处进行快速验证)。

解决方法

在PyTorch中将训练数据集分为训练和验证过程比原来要困难得多。

首先,将训练集分为训练和验证子集(类Subset),它们是数据集(类Dataset):

train_subset,val_subset = torch.utils.data.random_split(
        train,[50000,10000],generator=torch.Generator().manual_seed(1))

然后从这些数据集中获取实际数据:

X_train = train_subset.dataset.data[train_subset.indices]
y_train = train_subset.dataset.targets[train_subset.indices]

X_val = val_subset.dataset.data[val_subset.indices]
y_val = val_subset.dataset.targets[val_subset.indices]

请注意,通过这种方式我们没有没有Dataset对象,因此我们不能使用DataLoader对象进行批量训练。如果要使用DataLoader,它们可以直接与子集一起使用:

train_loader = DataLoader(dataset=train_subset,shuffle=True,batch_size=BATCH_SIZE)
val_loader = DataLoader(dataset=train_subset,shuffle=False,batch_size=BATCH_SIZE)
,

如果您想确保您的分组具有平衡的类别,您可以使用 train_test_split 中的 sklearn

import torchvision
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split

VAL_SIZE = 0.1
BATCH_SIZE = 64

mnist_train = torchvision.datasets.MNIST(
    '/data/',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
)
mnist_test = torchvision.datasets.MNIST(
    '/data/',train=False,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
)

# generate indices: instead of the actual data we pass in integers instead
train_indices,val_indices,_,_ = train_test_split(
    range(len(mnist_train)),mnist_train.targets,stratify=mnist_train.targets,test_size=VAL_SIZE,)

# generate subset based on indices
train_split = Subset(mnist_train,train_indices)
val_split = Subset(mnist_train,val_indices)

# create batches
train_batches = DataLoader(train_split,batch_size=BATCH_SIZE,shuffle=True)
val_batches = DataLoader(val_split,shuffle=True)
test_batches = DataLoader(mnist_test,shuffle=True)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。