
How to fix weights not updating when using Gradient Tape and apply_gradients()

I am building a DNN with a custom loss function, and I am training it using Gradient Tape in TensorFlow.keras. The code runs without any errors, but as far as I can tell by inspecting the DNN's weights, the weights are not being updated at all. I followed the guidance on the TensorFlow website exactly and searched for answers, but I still don't understand the cause. Here is my code:

import numpy as np

import tensorflow as tf
from tensorflow.keras.layers import Input,Dense,LeakyReLU,Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers

# Generate random training data
c0_train = np.array([30 * np.random.uniform() for i in range(10000)])

# Build a simple DNN
c0_input = Input(shape=(1,), name='c0')
hidden_1 = Dense(100)(c0_input)
activation_1 = LeakyReLU(alpha=0.1)(hidden_1)
hidden_2 = Dense(100)(activation_1)
activation_2 = LeakyReLU(alpha=0.1)(hidden_2)
hidden_3 = Dense(100)(activation_2)
activation_3 = LeakyReLU(alpha=0.1)(hidden_3)
x0_output = Dense(1, name='x0')(activation_3)

model = Model(inputs=c0_input, outputs=x0_output)

# Custom loss function
def cal_loss(c0_input):
  x0_output = model(c0_input)
  loss = tf.reduce_mean(
      tf.multiply(c0_input, tf.square(tf.subtract(x0_output, c0_input))))
  return loss

# Compute the loss and its gradients
@tf.function
def compute_loss_grads(c0_input):
  with tf.GradientTape() as tape:
    loss = cal_loss(c0_input)
  grads = tape.gradient(loss, model.trainable_variables)
  return loss, grads

# Optimizer
opt = optimizers.Adam(learning_rate=0.01)

# Start looping
for epoch in range(50):
  print('Epoch = ', epoch)
  # Compute the loss and gradients
  loss, grads = compute_loss_grads(tf.cast(c0_train, tf.float32))
  # Adjust the weights of the model
  opt.apply_gradients(zip(grads, model.trainable_variables))

I have checked the model's weights with model.get_weights(), and they look exactly the same before and after running the loop. So what is the problem here? One more question: how can I print out the loss at each epoch?

Solution

The weights do change. You can verify this as follows: after building the model, save its weights (these are the initial weights).

model = Model(inputs=c0_input, outputs=x0_output)
a_weg = model.get_weights()

Now run your training loop. Once training is done, get the new weights as follows and compare the two snapshots.

b_weg = model.get_weights()

a_weg[:1]
[array([[ 0.03541631, -0.02134866,  0.17080751,  0.10538128,  0.1361396 ,
          0.08645812,  0.114059  ,  0.216836  , -0.22464292, -0.21979895,
         -0.23927784, -0.00685263,  0.2167016 ,  0.09989142, -0.17772573,
          0.16095945, -0.10120587, -0.22456157, -0.22947621,  0.04009536,
          0.01029667, -0.18134505, -0.11318983,  0.10220072,  0.10100928, ...

b_weg[:1]
[array([[ 0.05140253,  0.00969543,  0.15155758,  0.07171137,  0.15917814,
          0.10883425,  0.11428417,  0.17012525, -0.25049415, -0.20693016,
         -0.20231842,  0.005939  ,  0.19197173,  0.07405043, -0.14260964,
          0.12490476, -0.11532102, -0.24605738, -0.25135723,  0.01863468,
          0.0311144 , -0.20050383, -0.11864465,  0.07961675,  0.11557189, ...
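If you would rather check programmatically than compare the printouts by eye, a minimal sketch (assuming the a_weg and b_weg snapshots above) is to test every weight array with np.allclose:

import numpy as np

# Returns True if any weight array differs between two
# get_weights() snapshots (an illustrative helper, not part of
# the original answer)
def weights_changed(before, after):
    return any(not np.allclose(a, b) for a, b in zip(before, after))

print(weights_changed(a_weg, b_weg))  # expected: True after training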

And this is how you can print the loss score at each epoch:

# Start looping
for epoch in range(5):
  # Compute the loss and gradients
  loss, grads = compute_loss_grads(tf.cast(c0_train, tf.float32))
  # Adjust the weights of the model
  opt.apply_gradients(zip(grads, model.trainable_variables))
  print('Epoch = ', epoch, ' - loss = ', loss.numpy())
Epoch =  0  - loss =  5962.977
Epoch =  1  - loss =  3042.2874
Epoch =  2  - loss =  2877.9978
Epoch =  3  - loss =  2607.5347
Epoch =  4  - loss =  2173.3213
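
As a side note, the same loop also works over mini-batches for larger datasets. Below is a minimal sketch, reusing model, opt, and compute_loss_grads from above; the batch size of 256 is an arbitrary choice, and the reshape just gives the input the (batch, 1) shape the Input layer declares:

import tensorflow as tf

# Mini-batch variant of the training loop (a sketch, not from the
# original answer)
dataset = (tf.data.Dataset
           .from_tensor_slices(c0_train.reshape(-1, 1).astype('float32'))
           .shuffle(10000)
           .batch(256))

for epoch in range(5):
  epoch_loss = tf.keras.metrics.Mean()   # running average of batch losses
  for batch in dataset:
    loss, grads = compute_loss_grads(batch)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    epoch_loss.update_state(loss)
  print('Epoch = ', epoch, ' - loss = ', epoch_loss.result().numpy())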
