微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在PyTorch的张量的每一行中随机设置固定数量的元素

如何解决如何在PyTorch的张量的每一行中随机设置固定数量的元素

我想知道下面的代码是否还有更有效的替代方法,而无需在第四行使用“ for”循环?

import torch
n,d = 37700,7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n,k)
mask = torch.zeros(n,d,dtype=torch.bool)
mask.scatter_(dim=1,index=sample,value=True)

基本上,我想做的是创建一个nd的掩码张量,以使每一行中正好k随机元素为True。

解决方法

这是没有循环的一种方法。让我们从一个随机矩阵开始,在该矩阵中所有元素均按iid绘制,在这种情况下,均匀地分布在[0,1]上。然后,我们为每一行取第k个分位数,并在每一行上将所有较小或相等的元素设置为True,并将其余元素设置为False:

rand_mat = torch.rand(n,d)
k_th_quant = torch.topk(rand_mat,k,largest = False)[0][:,-1:]
mask = rand_mat <= k_th_quant

不需要循环:) x2.1598比您在CPU上附加的代码快。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。