微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在PyTorch中进行非随机替换的均匀采样

如何解决在PyTorch中进行非随机替换的均匀采样

给出形状为4x4的布尔张量(掩码):

x = torch.tensor([ 
    [ True,True,False,True ],[ False,False ],])

我想相应地对(7)进行采样,结果为4x7形状:

tensor([[0,1,3,3],[3,[1,2,2]])

我能想到的最接近的是以下实现:

def uniform_sampling(tensor,count = 1):
    indices = torch.arange(0,tensor.shape[-1],device = tensor.device).expand(tensor.shape)
    samples_count = tensor.long().sum(-1)
    output = tensor.long() * (count // samples_count)[:,None]
    remainder = count - output.sum(-1)
    
    rem1 = torch.stack((remainder,tensor.sum(-1) - remainder),-1).flatten()
    rem2 = torch.stack((torch.ones_like(remainder),torch.zeros_like(remainder)),-1).flatten()
    remaining = rem2.repeat_interleave(rem1,0)
    
    output[tensor > 0] += remaining
    samples = indices[tensor].repeat_interleave(output[tensor],-1).view(-1,count)
    
    return samples

uniform_sampling(x,count = 7)

是否有任何(也许是本地的)PyTorch功能可以执行相同但更快,更有效的工作?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。