微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

是否可以获得中间梯度? (张量流)

如何解决是否可以获得中间梯度? (张量流)

使用梯度胶带时可以计算使用后的梯度:

with tf.GradientTape() as tape:
        out = model(x,training=True)
        out = tf.reshape(out,(num_img,1,10)) # Resizing 
        loss = tf.keras.losses.categorical_crossentropy(y,out) 
        gradient = tape.gradient(loss,model.trainable_variables)

然而,在 cifar10 输入的情况下,这将返回输入图像的梯度。 有没有办法访问中间步骤的梯度,使它们经过“一些”训练?

解决方法

编辑:感谢您的评论,我对您的问题有了更好的了解。 下面的代码远非理想,没有考虑批量训练等,但它可能会给你一个很好的起点。 我编写了一个自定义训练步骤,它基本上替代了 model.fit 方法。可能有更好的方法可以做到这一点,但它应该可以让您快速比较梯度。

def custom_training(model,data):
    x,y = data
    # Training 
    with tf.GradientTape() as tape:
        y_pred = model(x,training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = tf.keras.losses.mse(y,y_pred)
        
    trainable_vars = model.trainable_variables
    gradients = tape.gradient(loss,trainable_vars)
    tf.keras.optimizers.Adam().apply_gradients(zip(gradients,trainable_vars))
    # computing the gradient without optimizing it!
    with tf.GradientTape() as tape:
        y_pred = model(x,training=False)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = tf.keras.losses.mse(y,y_pred)
    trainable_vars = model.trainable_variables
    gradients_plus = tape.gradient(loss,trainable_vars)
    
    return gradients,gradients_plus

让我们假设一个非常简单的模型:

import tensorflow as tf

train_data = tf.random.normal((1000,32))
train_features = tf.random.normal((1000,))

inputs = tf.keras.layers.Input(shape=(32))
hidden_1 = tf.keras.layers.Dense(32)(inputs)
hidden_2 = tf.keras.layers.Dense(32)(hidden_1)
outputs = tf.keras.layers.Dense(1)(hidden_2)

model = tf.keras.Model(inputs,outputs)

并且您想计算所有层相对于输入的梯度。 您可以使用以下内容:

with tf.GradientTape(persistent=True) as tape:
    tape.watch(inputs)
    out_intermediate = []
    inputs = train_data
    cargo = model.layers[0](inputs)
    for layer in model.layers[1:]:
        cargo = layer(cargo)
        out_intermediate.append(cargo)
        
for x in out_intermediate:
    print(tape.gradient(x,inputs))

如果您想计算自定义损失,我建议Customize what happens in Model.fit

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?