微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

pytorch 数据加载器:沿数据加载器输出的一个维度连接批处理

如何解决pytorch 数据加载器:沿数据加载器输出的一个维度连接批处理

我的数据集的 __getitem__ 函数返回一个 torch.stft() M x N x D 张量,其中 N 是具有可变长度的音频输入系列。每一项都在 __getitem__ 函数内读取。我希望将批次沿第二维 (N) 连接起来。因此,通过迭代数据加载器,我将获得形状为:M x (N x batch_size) x D 的数据。 这个问题有没有可能的解决方案?

解决方法

您可以使用传递给 DataLoader 的自定义整理函数来执行此操作:

import torch
from torch.utils.data import DataLoader

M = 20
D = 12
N = 30
a = torch.rand((M,N,D))
b = torch.rand((M,D))

def my_collate(batch):
    c = torch.stack(batch,dim=1)
    return c.permute(0,2,1,3)

c = my_collate([a,b]) # output shape  MxNxBxD-> torch.Size([20,30,12])

然后传递给DataLoader:

loader = DataLoader(dataset=datasetObject,batch_size=1,collate_fn=my_collate)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。