微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Pytorch 成对连接张量

如何解决Pytorch 成对连接张量

我想以批处理方式计算特定维度上的成对串联。

例如

x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2,3,1])

我想得到 y 使得 y一个维度上所有向量对的串联,即:

y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])

y.shape = torch.Size([2,2])

因此,本质上,对于每个 x[i,:],您生成所有向量对,并将它们连接到最后一个维度。 有没有直接的方法来做到这一点?

解决方法

一种可能的方法是:

    all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
    y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])

对张量进行整形后:

y = y.view(x.shape[0],x.shape[1],-1)

你得到:

y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。