微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow 在与常数张量相乘后失去对变量/梯度的跟踪

如何解决Tensorflow 在与常数张量相乘后失去对变量/梯度的跟踪

我有一个带有一些自定义 tensorflow 层的 tensorflow 模型。我通过调用 tf.Variablesbuild() 方法中构建我的 self.add_weight(),因为它应该完成。然后我在调用之前将这些权重与其他一些常数张量相乘(考虑它的基础变化)。看来 tensorflow 失去了对我的变量的跟踪。然而,它们并没有消失在我的层的可训练变量中。 这是一个重现我想要做错误的示例:

class ToyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ToyLayer,self).__init__()
        self.basis_vector = tf.constant([1,0.,1])

    def build(self,input_shape):
        self.variable = self.add_weight(shape=(1,))
        self.effective_weight = self.variable*self.basis_vector

    def call(self,inputs,**kwargs):
        return tf.tensordot(inputs,self.effective_weight,axes=1)


layer = ToyLayer()
x = tf.random.normal((3,))
with tf.GradientTape() as tape:
    y = layer(x)
print(layer.trainable_weights)
print(tape.gradient(y,layer.trainable_weights))

可训练的权重仍然是它们所需要的,但对于梯度,我得到了 None。 将常数张量更改为 tf.Variable 无济于事。

如果我尝试用 tf.GradientTape() 做一些类似的事情,如果我将变量与梯度磁带中的向量相乘,我会得到正确的梯度,但如果在磁带之前进行向量变量乘法,则不会得到梯度。因此,在图层中,当将变量与向量相乘时,我的渐变似乎还没有被记录下来。我该如何解决这个问题?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?