微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 Keras 中获得可重现的权重初始化?

如何解决如何在 Keras 中获得可重现的权重初始化?

  1. 我将 numpy 和 tensorflow 随机种子都设置为 suggested
  2. 生成一些数据 - 这部分是可重复的,总是给出相同的结果
  3. 创建一个简单的网络并进行预测(无需训练,仅使用随机权重) - 每次预测都不同
import numpy as np
from tensorflow.keras.layers import Dense,Dropout
from tensorflow.keras import Sequential,optimizers
import tensorflow as tf

np.random.seed(32)
tf.set_random_seed(33)

random_data = np.random.rand(10,2048)
print(random_data[:,0])

def make_classifier():
    model = Sequential()
    model.add(Dense(1024,activation='relu',input_dim=2048))
    model.add(Dropout(0.5))
    model.add(Dense(1,activation='sigmoid'))
    model.compile(optimizer=optimizers.RMSprop(lr=2e-4),loss='binary_crossentropy')
    return model
model = make_classifier()
# model.summary()
model.predict(random_data)

当我重新运行整个单元格时,打印语句总是输出 [0.85888927 0.23846818 0.17757634 0.07244977 0.71119893 0.09223853 0.86074647 0.31838194 0.7568638 0.38197083]。但是每次的预测都不一样:

array([[0.5825965 ],[0.8677979 ],[0.70151913],[0.64572096],[0.78101623],[0.76483005],[0.7946336 ],[0.6281612 ],[0.8208673 ],[0.8273002 ]],dtype=float32)
array([[0.51012236],[0.6562015 ],[0.5593666 ],[0.686155  ],[0.6488372 ],[0.5966359 ],[0.6236731 ],[0.58099884],[0.68447435],[0.58886844]],dtype=float32)

等等。

  1. 如何为刚刚初始化的网络获得可重现的预测?
  2. 权重初始化究竟何时发生?当我编译模型或?..

解决方法

tf.keras.initializers 对象具有用于可重现初始化的 seed 参数。

import tensorflow as tf
import numpy as np

initializer = tf.keras.initializers.GlorotUniform(seed=42)

for _ in range(10):
    print(np.round(initializer((4,)),3))
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]
[-0.377 -0.003  0.373 -0.831]

在 Keras 层中,您可以像这样使用它:

tf.keras.layers.Dense(1024,activation='relu',input_dim=2048,kernel_initializer=tf.keras.initializers.GlorotUniform(seed=42))
,

您使用的是哪个版本的 tensorflow (python3 -c 'import tensorflow; print(tensorflow.__version__)')

我使用 2.3.1,用 tf.random.set_seed() 设置我的种子,我确实得到了可重复的结果。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。