如何解决火炬以2D张量收集3D张量作为“聚集图”
对于给定尺寸的张量data.shape [B,N,F]
和indices.shape [N,K]
,
其中K
是每个点N的“邻居”的索引,
是否存在一种简单的方法来收集邻居,使得output.shape [B,K,F]
吗?
data = [[[1,2,3],[2,3,4],[3,4,5]],[[3,5],[6,7,8],4]] # Shape 2,3
indices = [[1,2],[1,0],[0,0]] # Shape 3,2
output = [
[[[2,[[2,3]],[[1,3]]],[[[6,4]],[[6,5]]]
]
例如,每批中的第一个点“关联”到点[1,2]
,因此output[0][0] = [[2,5]]
是输入的第一批中的点1,2
。
我的尝试:
batched_indices = indices.unsqueeze(0).unsqueeze(-1).repeat(batch_size,1,feature_size)
data_neighs = data.unsqueeze(2).repeat(1,num_neighs,1)
output = torch.gather(data_neighs,batched_indices,dim=1)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。