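How to fix tensorflow_probability's DenseVariational layer not working with the ModelCheckpoint callback
I am having trouble saving a model that contains a tensorflow_probability DenseVariational layer.
Trying to save the full model raises this error: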
Epoch 00001: saving model to files/model.h5
Traceback (most recent call last):
File "train_lstm.py",line 60,in <module>
main()
File "train_lstm.py",line 43,in main
history,model = model_handler.fit_model(
File "/home/nilsflaschel/Projects/Autopilot/Forecastings/lstm/handlers/model_handler.py",line 114,in fit_model
history = model.fit(
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py",line 1145,in fit
callbacks.on_epoch_end(epoch,epoch_logs)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py",line 428,in on_epoch_end
callback.on_epoch_end(epoch,logs)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py",line 1344,in on_epoch_end
self._save_model(epoch=epoch,logs=logs)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py",line 1408,in _save_model
self.model.save(filepath,overwrite=True,options=self._options)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py",line 2001,in save
save.save_model(self,filepath,overwrite,include_optimizer,save_format,
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py",line 153,in save_model
hdf5_format.save_model_to_hdf5(
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py",line 115,in save_model_to_hdf5
model_metadata = saving_utils.model_metadata(model,include_optimizer)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/saving/saving_utils.py",line 158,in model_metadata
raise e
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/saving/saving_utils.py",line 155,in model_metadata
model_config['config'] = model.get_config()
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py",line 650,in get_config
return copy.deepcopy(get_network_config(self))
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py",line 1349,in get_network_config
layer_config = serialize_layer_fn(layer)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py",line 250,in serialize_keras_object
raise e
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py",line 245,in serialize_keras_object
config = instance.get_config()
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py",line 699,in get_config
raise NotImplementedError('Layer %s has arguments in `__init__` and '
NotImplementedError: Layer DenseVariational has arguments in `__init__` and therefore must override `get_config`.
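The error message refers to the Keras convention that a layer with constructor arguments returns them from `get_config` so the layer can be rebuilt later. A minimal pure-Python sketch of that contract follows. `ScaledLayer` is a hypothetical class made up for illustration and is not part of Keras or tensorflow_probability.

```python
class ScaledLayer:
    """Hypothetical stand-in for a layer that takes constructor arguments."""

    def __init__(self, units, scale=1.0):
        self.units = units
        self.scale = scale

    def get_config(self):
        # Return every constructor argument as plain, serializable data.
        return {"units": self.units, "scale": self.scale}

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent object from the saved configuration.
        return cls(**config)


layer = ScaledLayer(units=2, scale=0.5)
restored = ScaledLayer.from_config(layer.get_config())
```

DenseVariational takes functions (make_posterior_fn, make_prior_fn) as arguments, and those cannot be returned as plain data in the same way, so a subclass that overrides `get_config` would need its own strategy for them.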
If I save only the weights instead, I get this:
Epoch 00001: saving model to files/model.h5
Traceback (most recent call last):
File "train_lstm.py",line 1405,in _save_model
self.model.save_weights(
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py",line 2108,in save_weights
hdf5_format.save_weights_to_hdf5_group(f,self.layers)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py",line 642,in save_weights_to_hdf5_group
param_dset = g.create_dataset(name,val.shape,dtype=val.dtype)
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/h5py/_hl/group.py",line 139,in create_dataset
self[name] = dset
File "/home/nilsflaschel/miniconda3/envs/autopilot/lib/python3.8/site-packages/h5py/_hl/group.py",line 373,in __setitem__
h5o.link(obj.id,self.id,name,lcpl=lcpl,lapl=self._lapl)
File "h5py/_objects.pyx",line 54,in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx",line 55,in h5py._objects.with_phil.wrapper
File "h5py/h5o.pyx",line 202,in h5py.h5o.link
RuntimeError: Unable to create link (name already exists)
My model architecture:
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None,7,73)] 0
_________________________________________________________________
lstm (LSTM) (None,64) 35328
_________________________________________________________________
dense_variational (DenseVari (None,2) 390
_________________________________________________________________
distribution_lambda (Distrib multiple 0
=================================================================
Total params: 35,718
Trainable params: 35,718
Non-trainable params: 0
_________________________________________________________________
After compiling the model, I printed each weight.name:
Model Weights:
0 lstm/lstm_cell/kernel:0
1 lstm/lstm_cell/recurrent_kernel:0
2 lstm/lstm_cell/bias:0
3 dense_variational/constant:0
4 dense_variational/constant:0
I suspect the problem comes from the duplicate names. This is odd and I have no explanation for it, since there is only one DenseVariational layer.
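The suspicion about duplicate names can be checked directly on the printed list. The sketch below is a small helper, `find_duplicate_names`, which is a hypothetical name introduced here and not a TensorFlow function. It reports every name that occurs more than once.

```python
from collections import Counter


def find_duplicate_names(names):
    """Return the names that occur more than once, in first-seen order."""
    counts = Counter(names)
    return [name for name in counts if counts[name] > 1]


# Weight names copied from the listing above.
weight_names = [
    "lstm/lstm_cell/kernel:0",
    "lstm/lstm_cell/recurrent_kernel:0",
    "lstm/lstm_cell/bias:0",
    "dense_variational/constant:0",
    "dense_variational/constant:0",
]

print(find_duplicate_names(weight_names))
# → ['dense_variational/constant:0']
```

On a real model, the same helper could be called as `find_duplicate_names([w.name for w in model.weights])`.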
I essentially just used the code from here: https://www.tensorflow.org/probability/examples/Probabilistic_Layers_Regression
My checkpoint callback:
tf.keras.callbacks.ModelCheckpoint(
    filepath="files/model.h5", mode="min", monitor="val_loss",
    save_best_only=False, verbose=1, save_freq="epoch",
    save_weights_only=True,
)
And the full model:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp
tfd = tfp.distributions

# Negative log-likelihood of y under the predicted distribution rv_y.
negloglik = lambda y, rv_y: -rv_y.log_prob(y)
input1 = tf.keras.layers.Input(shape=X_train.shape[1:])
net = tf.keras.layers.LSTM(64, dropout=0.2)(input1)
# posterior_mean_field and prior_trainable are taken from the linked tutorial.
net = tfp.layers.DenseVariational(
    units=1 + 1, make_posterior_fn=posterior_mean_field,
    make_prior_fn=prior_trainable, kl_weight=1 / X_train.shape[0])(net)
net = tfp.layers.DistributionLambda(
    lambda t: tfd.Normal(
        loc=t[..., :1], scale=1e-3 + tf.math.softplus(0.01 * t[..., 1:])))(net)
model = tf.keras.Model([input1], net)
model.compile(
    loss=negloglik,
    optimizer=tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=5e-5))