如何解决Pytorch 成对连接张量
我想以批处理方式计算特定维度上的成对串联。
例如
x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2,3,1])
我想得到 y
使得 y
是一个维度上所有向量对的串联,即:
y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])
y.shape = torch.Size([2,2])
因此,本质上,对于每个 x[i,:]
,您生成所有向量对,并将它们连接到最后一个维度。
有没有直接的方法来做到这一点?
解决方法
一种可能的方法是:
all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])
对张量进行整形后:
y = y.view(x.shape[0],x.shape[1],-1)
你得到:
y = torch.tensor([[[[0,0],[0,1],2]],[[1,[1,[[2,[2,2]]],[[[3,3],[3,4],5]],[[4,[4,[[5,[5,5]]]])
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。