How to check whether a Keras model's trainable weights changed before/after a custom training loop

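I am trying to verify that a custom training loop changes a Keras model's weights. My current approach is to deepcopy the model.trainable_weights list before training and then compare it against model.trainable_weights after training. Is this a valid way to make the comparison? The results of my approach indicate that the weights do change (which is the expected outcome anyway, since the loss clearly decreases with each epoch), but I just want to verify that what I am doing is valid. Below is slightly modified code from the Keras custom training loop tutorial, plus the code I use to compare the model's weights before and after training: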

# Imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from copy import deepcopy

# The model
inputs = keras.Input(shape=(784,),name="digits")
x1 = layers.Dense(64,activation="relu")(inputs)
x2 = layers.Dense(64,activation="relu")(x1)
outputs = layers.Dense(10,name="predictions")(x2)
model = keras.Model(inputs=inputs,outputs=outputs)

##########################
# WEIGHTS BEFORE TRAINING
##########################
# I use deepcopy here to keep an independent snapshot, because training updates the weight variables in place
weights_before_training = deepcopy(model.trainable_weights)

##########################
# Keras Tutorial
##########################

# Load data
(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train,(-1,784))
x_test = np.reshape(x_test,(-1,784))

# Reduce the size of the data to speed up training
x_train = x_train[:128] 
x_test = x_test[:128]
y_train = y_train[:128]
y_test = y_test[:128]

# Make tf dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_dataset = train_dataset.shuffle(buffer_size=64).batch(16)

# The training loop
print('Begin Training')
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
epochs = 2
for epoch in range(epochs):
    # Logging start of epoch
    print("\nStart of epoch %d" % (epoch,))

    # Save loss values for logging
    loss_values = []

    # Iterate over the batches of the dataset.
    for step,(x_batch_train,y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train,training=True)  # Logits for this minibatch
            loss_value = loss_fn(y_batch_train,logits)

        # Append to list for logging
        loss_values.append(loss_value)

        grads = tape.gradient(loss_value,model.trainable_weights)

        optimizer.apply_gradients(zip(grads,model.trainable_weights))

    print('Epoch Loss:',np.mean(loss_values))

print('End Training')
##########################
# WEIGHTS AFTER TRAINING
##########################

weights_after_training = model.trainable_weights

# Note: `trainable_weights` is a list of kernel and bias tensors.
print()
print('Begin Trainable Weights Comparison')
for i in range(len(weights_before_training)):
    print(f'Trainable Tensors for Element {i + 1} of List Are Equal:',tf.reduce_all(tf.equal(weights_before_training[i],weights_after_training[i])).numpy())
print('End Trainable Weights Comparison')

>>> Begin Training
>>> Start of epoch 0
>>> Epoch Loss: 44.66055
>>> 
>>> Start of epoch 1
>>> Epoch Loss: 5.306543
>>> End Training
>>>
>>> Begin Trainable Weights Comparison
>>> Trainable Tensors for Element 1 of List Are Equal: False
>>> Trainable Tensors for Element 2 of List Are Equal: False
>>> Trainable Tensors for Element 3 of List Are Equal: False
>>> Trainable Tensors for Element 4 of List Are Equal: False
>>> Trainable Tensors for Element 5 of List Are Equal: False
>>> Trainable Tensors for Element 6 of List Are Equal: False
>>> End Trainable Weights Comparison

Solution

Summarizing from the comments and adding more information, for the benefit of the community:

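The approach followed in the code above, that is, comparing deepcopy(model.trainable_weights) taken before training with model.trainable_weights after the model has been trained with the custom training loop, is the correct approach.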
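As a complementary check, one could snapshot the weights as plain NumPy arrays and report how far each tensor moved, rather than only whether it changed. The sketch below is an illustration under stated assumptions, not part of the original answer: `snapshot_weights` and `weight_deltas` are hypothetical helper names, and the small NumPy arrays stand in for what `model.get_weights()` or `[w.numpy() for w in model.trainable_weights]` would return.

```python
import numpy as np

def snapshot_weights(weights):
    """Return independent NumPy copies of a list of array-like weights."""
    return [np.array(w, copy=True) for w in weights]

def weight_deltas(before, after):
    """Return the largest absolute change for each pair of weight arrays."""
    return [float(np.max(np.abs(np.asarray(a) - b))) for b, a in zip(before, after)]

# Stand-in "weights" (assumption): with Keras one would pass model.get_weights() instead.
weights = [np.zeros((3, 2)), np.ones(2)]
before = snapshot_weights(weights)

# Simulate an in-place update to the first array, similar to what an optimizer step does.
weights[0] -= 0.5

deltas = weight_deltas(before, weights)
changed = [d > 0.0 for d in deltas]
print(deltas)
print(changed)
```

Reporting a magnitude instead of a boolean can help spot layers whose weights change only negligibly, which an exact equality test would still flag as "changed".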

In addition, if we do not want the model to be trained, we can freeze all of its layers.
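One way to freeze layers in Keras is to set each layer's `trainable` attribute to `False`. The snippet below is a minimal sketch rather than the original answer's code: the tiny model and its layer sizes are made up for illustration, and it only counts variables before and after freezing.

```python
from tensorflow import keras
from tensorflow.keras import layers

# A tiny stand-in model (hypothetical sizes, not the MNIST model from the question).
inputs = keras.Input(shape=(4,))
hidden = layers.Dense(3, activation="relu")(inputs)
outputs = layers.Dense(2)(hidden)
model = keras.Model(inputs=inputs, outputs=outputs)

# Two Dense layers, each with a kernel and a bias.
n_trainable_before = len(model.trainable_weights)

# Freeze every layer; frozen variables are reported under `non_trainable_weights`.
for layer in model.layers:
    layer.trainable = False

n_trainable_after = len(model.trainable_weights)
n_frozen = len(model.non_trainable_weights)
print(n_trainable_before, n_trainable_after, n_frozen)
```

After freezing, `model.trainable_weights` is empty, so a loop that takes gradients with respect to `model.trainable_weights` has no variables left to update.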
