如何解决Python:从张量中删除索引
我有两个张量。
tensor_A 对应一批 8 张图片,有 20 类对象,每张图片大小为 256 x 256
tensor_B 对应 len 20 满 1 或 0 的 8 个数组,对应对象类是否存在
tensor_A.shape = ([8,20,256,256])
tensor_B.shape = ([8,20])
从 tensor_A,我想删除对应于 tensor_B 中的 1 的索引
例如如果 tensor_B[0] = [1,1,0]
我想做 tensor_A[0,:,:].drop 然后 tensor_A[0,2,:].drop 等等,但都是一步
到目前为止,我已经使用以下方法确定了对应于 1 的索引:
for i in range(8):
(tensor_B[i,:] == 0).nonzero())
# code for dropping here
不确定如何进行
解决方法
你想要的不起作用,因为:
# A -> tensor of shape (8,20,256,256)
# B -> tensor of shape (8,20)
# If B[0] = [1,1,0]
dropped_A_0 = A[0,B[0] == 0,:,:]
# dropped_A_0 -> tensor of shape (1,13,256)
# If B[1] = [1,0]
dropped_A_1 = A[1,B[1] == 0,:]
# dropped_A_1 -> tensor of shape (1,16,256)
你看到问题了吗?当您从 A 的行中“删除”值时,它们不再具有相同的形状,因此不能作为单个张量一起存在。您可以拥有的是带有删除值的 A 行列表:
dropped_A = []
for i in range(len(A)):
dropped_A.append(A[i,B[i] == 0,:])
您可以做的另一件事是将 A 中不需要的值设置为 0。
A[B == 1] = 0
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。