微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

自定义数据集和数据加载器

如何解决自定义数据集和数据加载器

我是 pytorch 的新手。 我的大数据集由两个 txt 文件组成,一个用于数据,另一个用于目标数据。 在训练文件中每行是长度为 340 的列表,在目标中每行是长度为 136 的列表。

我想问一下如何定义我的数据集,以便我可以使用 DataLoader 加载我的数据来训练 pytorch 模型?

我希望你回答

解决方法

Dataset 中的

torch.utils.data 是表示数据集的抽象类。您的自定义数据集应继承 Dataset 并覆盖以下方法:

__len__() 使 len(dataset) 返回数据集的大小。
__getitem__() 支持索引,使得 dataset[i] 可用于获取第 i 个样本

例如编写自定义数据集
我已经为您编写了一个通用的自定义数据加载器作为您的问题陈述。
这里 data.txt 有数据,label.txt 有标签。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        
       
        with open('data.txt','r') as f:
                self.data_info = f.readlines()
        
        with open('label.txt','r') as f:
                self.label_info = f.readlines()        


    def __getitem__(self,index):
        
        single_data = self.data_info[index].rstrip('\n')
        

        single_label = self.label_info[index].rstrip('\n')

        return ( single_data,single_label)

    def __len__(self):
        return len(self.data_info)
# Testing 
d = CustomDataset()
print(d[1]) # should output data along with label

这将是您案例的基础,但必须进行一些与您的案例相匹配的更改。

注意:您必须根据数据集进行必要的更改

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。