如何解决如何通过 Pytorch 中的跟踪值有效地成对池化张量?
我有一个形状为 T
的 pytorch 张量 (batch_size,window_size,filters,3,3)
,我想通过跟踪来汇集张量。具体来说,我想通过比较成对帧的轨迹来获得大小为 T_pooled
的张量 (batch_size,window_size//2,3)
。例如,如果window_size=4
,那么我们将比较T[i,k,3]
和T[i,1,3]
的轨迹,并选择轨迹较小的子张量作为T_pooled[i,3]
。同样,比较T[i,2,3]
得到T_pooled[i,3]
。
这可以通过循环 i
和 k
来完成,但这非常缓慢且效率低下。有没有办法对这个池化操作进行矢量化以加快速度?
编辑: 这是我迄今为止尝试过的。它使用列表理解和 for 循环。在大小为 (128,120,22,3) 的张量上运行大约需要 2.5 秒。
def TPL_Pairwise(x):
x_pooled=torch.zeros(x.shape[0],x.shape[1]//2,x.shape[2],x.shape[3],x.shape[4])
#compute tensorized trace
trace=torch.einsum('ijkll->ijkl',x).sum(-1)
for i in range(x.shape[0]):
for j in range(x.shape[2]):
keep=[ x[i,j] if trace[i,j] <= trace[i,k+1,j] else x[i,j] for k in range(0,x.shape[1],2)]
x_pooled[i,:,j]=torch.stack(keep)
return x_pooled
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。