如何解决PyTorch:除顶部 k 之外的所有向量元素都归零?
我正在尝试创建一个新的激活层,我们称之为 topk,其工作方式如下。它将把一个大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并加上偏置的结果)和一个正整数 k 并输出一个大小为 n 的向量 topk(x),其元素是:
x_i (if x_i is one of the top k elements of x)
topk(x)_i =
0 (otherwise)
在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他的都是0。
我应该如何实现这一点?任何帮助将不胜感激。
解决方法
您可以使用 torch.topk
:
k = 2
output = torch.randn(5)
vals,idx = output.topk(k)
topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557,0.0000,1.4562,0.0000])
请注意,虽然 'values'
的 topk()
是可微的,但 'indices'
are not(类似于 argmax 不是可微函数的方式)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。