如何解决在PyTorch中进行非随机替换的均匀采样
给出形状为4x4的布尔张量(掩码):
x = torch.tensor([
[ True,True,False,True ],[ False,False ],])
我想相应地对(7)进行采样,结果为4x7形状:
tensor([[0,1,3,3],[3,[1,2,2]])
我能想到的最接近的是以下实现:
def uniform_sampling(tensor,count = 1):
indices = torch.arange(0,tensor.shape[-1],device = tensor.device).expand(tensor.shape)
samples_count = tensor.long().sum(-1)
output = tensor.long() * (count // samples_count)[:,None]
remainder = count - output.sum(-1)
rem1 = torch.stack((remainder,tensor.sum(-1) - remainder),-1).flatten()
rem2 = torch.stack((torch.ones_like(remainder),torch.zeros_like(remainder)),-1).flatten()
remaining = rem2.repeat_interleave(rem1,0)
output[tensor > 0] += remaining
samples = indices[tensor].repeat_interleave(output[tensor],-1).view(-1,count)
return samples
uniform_sampling(x,count = 7)
是否有任何(也许是本地的)PyTorch功能可以执行相同但更快,更有效的工作?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。