如何解决在 TensorFlow 中训练子类模型期间无法设置检查点
我一直在尝试实现子类化注意力模型,但无法在训练期间设置检查点。我已经查看了 save_and_serialize 和 save and load Keras Models,但对如何使用我的许多自定义层有效地实现它感到困惑。以下是完整代码片段。
class Embedding(tf.keras.layers.Layer):
"""docstring for Embedding."""
def __init__(self,vocab,emd_d,name):
super(Embedding,self).__init__()
cprint(f"Initializing {name}'s Embedding layer",'magenta')
self.E = self.add_weight('Embedding Layer',shape=(vocab,emd_d),dtype=tf.float32,initializer='random_uniform',trainable=True)
if name == 'DECODER':
self.Eo = self.add_weight('Embedding Layer Output',shape=(emd_d,vocab),trainable=True)
cprint(f"\t\tinitialization of {name}'s Embedding layer COMPLETE",'green')
def call(self,input):
return tf.gather(self.E,input,axis=0)
class Model(tf.keras.layers.Layer):
"""docstring for Model."""
def __init__(self,enc_T,dec_T,enc_emb_d,dec_emb_d,enc_voc,dec_voc,enc_units,dec_units,inp_token,targ_token,enc_return_state=False,dec_return_state=False,enc_f=tf.nn.tanh,dec_f=tf.nn.tanh):
super(Model,self).__init__()
cprint('PREPARING THE MODEL','yellow')
self.enc_T = enc_T
self.dec_T = dec_T
self.enc_emb_d = enc_emb_d
self.dec_emb_d = dec_emb_d
self.enc_voc = enc_voc
self.dec_voc = dec_voc
self.enc_units = enc_units
self.dec_units = dec_units
self.inp_token = inp_token
self.targ_token = targ_token
self.enc_emb = Embedding(enc_voc,'ENCODER')
self.enc = Encoder(enc_T,f=enc_f,return_state=enc_return_state)
self.dec_emb = Embedding(dec_voc,'DECODER')
self.dec = Decoder(dec_T,f=dec_f,return_state=dec_return_state)
self._trainable_weights = self.enc_emb.trainable_weights + \
self.enc.trainable_weights + \
self.dec_emb.trainable_weights + \
self.dec.trainable_weights
cprint('MODEL IS READY','yellow')
def forward(self,x):
i = self.targ_token.word_index['<start>']
y = tf.ones((x.shape[0],1),dtype=tf.int32)*i
emd_x = self.enc_emb(x)
state = self.enc(emd_x)
emd_y = self.dec_emb(y)
out = self.dec(emd_y,state)
return out
def Loss(self,y_true,logits):
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true,logits=logits)
return tf.reduce_sum(loss)
def gradient(self,x,y,opt):
with tf.GradientTape() as tape:
out = self.forward(x)
logits = tf.matmul(out,self.dec_emb.Eo)
loss = self.Loss(y,logits)
grads = tape.gradient(loss,self._trainable_weights)
opt.apply_gradients(zip(grads,self._trainable_weights))
config = self.get_config()
custom_obj = {'Encoder Emb Layer': self.enc_emb,'Encoder Layer': self.enc,'Decoder Emb Layer': self.dec_emb,'Decoder Layer': self.dec}
with tf.keras.utils.custom_object_scope(custom_obj):
new_model = tf.keras.Model.from_config(config)
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory,"ckpt")
checkpoint = tf.train.Checkpoint(optimizer=opt,model=new_model)
checkpoint.save(file_prefix=checkpoint_prefix)
return loss
def get_config(self):
config = super(Model,self).get_config()
config.update({'Encoder timesteps': self.enc_T,'Decoder timesteps': self.dec_T,'Encoder Emb Dim': self.enc_emb_d,'Decoder Emb Dim': self.dec_emb_d,'Encoder Vocab': self.enc_voc,'Decoder Vocab': self.dec_voc,'Encoder Units': self.enc_units,'Decoder Units': self.dec_units,'Input Token': self.inp_token,'Target Token': self.targ_token,'layers': [self.enc_emb,self.enc,self.dec_emb,self.dec]})
return config
@classmethod
def from_config(cls,config):
return cls(**config)
以上述逻辑运行我收到以下错误:
Traceback (most recent call last):
File "E:\Project\Tensorflow 2.0\Attention\attention.py",line 412,in <module>
learning_rate,epochs)
File "E:\Project\Tensorflow 2.0\Attention\attention.py",line 244,in fit
L = self.gradient(input_batch,target_batch,optim)
File "E:\Project\Tensorflow 2.0\Attention\attention.py",line 220,in gradient
new_model = tf.keras.Model.from_config(config)
File "C:\Users\prana\AppData\Local\Programs\Python\python37\lib\site-packages\tensorflow\python\keras\engine\network.py",line 987,in from_config
config,custom_objects)
File "C:\Users\prana\AppData\Local\Programs\Python\python37\lib\site-packages\tensorflow\python\keras\engine\network.py",line 2019,in reconstruct_from_config
process_layer(layer_data)
File "C:\Users\prana\AppData\Local\Programs\Python\python37\lib\site-packages\tensorflow\python\keras\engine\network.py",line 1993,in process_layer
layer_name = layer_data['name']
TypeError: 'Embedding' object is not subscriptable
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。