微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

PyTorch:除顶部 k 之外的所有向量元素都归零?

如何解决PyTorch:除顶部 k 之外的所有向量元素都归零?

我正在尝试创建一个新的激活层,我们称之为 topk,其工作方式如下。它将把一个大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并加上偏置的结果)和一个正整数 k 并输出一个大小为 n 的向量 topk(x),其元素是:

              x_i (if x_i is one of the top k elements of x) 
topk(x)_i = 
              0 (otherwise)

在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他的都是0。

我应该如何实现这一点?任何帮助将不胜感激。

解决方法

您可以使用 torch.topk

k = 2
output = torch.randn(5)
vals,idx = output.topk(k)

topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557,0.0000,1.4562,0.0000])

请注意,虽然 'values'topk() 是可微的,但 'indices' are not(类似于 argmax 不是可微函数的方式)。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?