微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用特定值创建pytorch张量二进制掩码

如何解决使用特定值创建pytorch张量二进制掩码

为我提供了一个带有整数的pytorch二维张量,以及总是出现在张量的每一行中的2个整数。 我想创建一个二进制掩码,在这2个整数的两个外观之间将包含1,否则为0。例如,如果整数是4和2,而一维数组是{{1 }},返回的掩码将是:[1,1,9,4,6,5,2,11,3,4] 是否有任何高效且快速方法来计算此蒙版而无需迭代?

解决方法

也许有点混乱,但是它不需要迭代就可以工作。在下面的示例中,我假设使用了一个示例张量m,并将其应用于解决方案,用它来解释起来比使用一般符号更容易。

import torch

vals=[2,8]#let's assume those are the constant values that appear in each row

#target tensor
m=torch.tensor([[1.,2.,7.,8.,5.],[4.,1.,8.]])

#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]

v=(k.int()+p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)

#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)

#you only need two masks,no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1))>=q

final_mask=msk_0.int() * msk_1.int()

print(final_mask)

我们得到

tensor([[0,1,0],[0,1]],dtype=torch.int32)

关于两个掩码mask_0mask_1(如果不清楚它们是什么),请注意nz_indexes[:,0]包含m每一行的列索引,找到vals[0],并且nz_indexes[:,1]同样包含m的每一行,其中找到vals[1]的列索引。

,

完全基于先前的解决方案,以下是修改后的解决方案:

import torch

vals=[2,5.,6.,4.],4)

#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1))>=q
msk_2=(nz_indexes[:,2].repeat(m.shape[1],1))<=q
msk_3=(nz_indexes[:,3].repeat(m.shape[1],1))>=q

final_mask=msk_0.int() * msk_1.int() + msk_2.int() * msk_3.int()

print(final_mask)

我们终于得到

tensor([[0,dtype=torch.int32)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。