微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在模型中使用 tf.gradients 并仍然使用自定义训练循环?

如何解决如何在模型中使用 tf.gradients 并仍然使用自定义训练循环?

我想制作一个 TensorFlow 模型,其中输出遵守数学条件,即输出 0 是标量函数,所有后续输出都是其偏导数 w.r.t.输入。这是因为我的观察是标量函数及其偏函数,不使用偏函数进行训练会浪费信息。

就目前而言,如果我不构建自定义训练循环,即当我不使用 Eager Execution 时,仅使用 tf.gradients 即可。模型是这样构建的,训练按预期进行:

import tensorflow as tf


from tensorflow.keras import losses
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks

# Creating a model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Dense,Dropout,Flatten,Concatenate,Input,Lambda,)

# Custom activation function
from tensorflow.keras.layers import Activation
from tensorflow.keras import backend as K

import numpy
import matplotlib.pyplot as plt

import tensorboard

layer_width = 200
dense_layer_number = 3

def lambda_gradient(args):
    layer = args[0]
    inputs = args[1]
    return tf.gradients(layer,inputs)[0]

# Input is a 2 dimensional vector
inputs = tf.keras.Input(shape=(2,),name="coordinate_input")

# Build `dense_layer_number` times a dense layers of width `layer_width`
stream = inputs
for i in range(dense_layer_number):
    stream = Dense(
        layer_width,activation="relu",name=f"dense_layer_{i}"
    )(stream)

# Build one dense layer that reduces the 200 nodes to a scalar output
scalar = Dense(1,name="network_to_scalar",activation=custom_activation)(stream)

# Take the gradient of the scalar w.r.t. the model input
gradient = Lambda(lambda_gradient,name="gradient_layer")([scalar,inputs])

# Combine them to form the model output
concat = Concatenate(name="concat_scalar_gradient")([scalar,gradient])

# Wrap everything in a model
model = tf.keras.Model(inputs=inputs,outputs=concat)

loss = "MSE"
optimizer = "Adam"

# And compile
model.compile(loss=loss,optimizer=optimizer)

然而,当我想要进行在线培训(即使用增量数据集)时,他们的问题就出现了。在这种情况下,我不会在最后编译我的模型。相反,我写了一个循环(在调用 model.compile 之前):

# ... continue from prevIoUs minus model.compile

loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

# Iterate over the batches of a dataset and train.
for i_batch in range(number_of_batches):

    with tf.GradientTape() as tape:
        # Predict w.r.t. the inputs X
        prediction_Y = model(batches_X[i_batch])
        
        # Compare batch prediction to batch observation
        loss_value = loss_fn(batches_Y[i_batch],prediction_Y)

    gradients = tape.gradient(loss_value,model.trainable_weights)
    optimizer.apply_gradients(zip(gradients,model.trainable_weights))

然而,这在 prediction_Y = model(batches_X[i_batch]) 处给出了以下异常:

RuntimeError: tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.

由于大多数示例、教程和文档仅涉及使用梯度进行训练,而不是在模型中,因此我找不到任何好的资源来解决这个问题。我试图找到如何使用渐变胶带,但在模型设计阶段我无法弄清楚如何使用它。任何指针将不胜感激!

使用的版本:

$ python --version                                         
Python 3.8.5
$ python -c "import tensorflow as tf;print(tf.__version__);print(tf.keras.__version__)"
2.2.0
2.3.0-tf

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?