微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

从没有重复的张量中选择顶部 K 值

如何解决从没有重复的张量中选择顶部 K 值

torch.Tensor.topk 提供了一种有效的方法来沿一个维度提取张量中的前 k 个值。是否可以将前 k 个值限制为不重复

例如

input = torch.tensor([0.2,0.2,0.1])
k = 2
dim = 0


output[0] = torch.tensor([0.2,0.1])
output[1] = torch.longtensor([0,2])

解决方法

您可以在输入张量上应用 torch.unique

>>> input.unique().topk(k=2).values
tensor([0.2000,0.1000])

请注意,此时您将丢失索引。


编辑:实际上 torch.unique 有一个对结果进行排序的选项(默认情况下该选项处于启用状态)。

>>> input
tensor([0.0000,0.3000,0.2000,0.1000])

>>> input.unique(return_inverse=True)[1].unique(sorted=False)
tensor([1,2,3,0])

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。