
Unable to set checkpoints while training a subclassed model in TensorFlow

I have been trying to implement a subclassed attention model, but I am unable to set checkpoints during training. I have gone through the save_and_serialize and Save and load Keras models guides, but I am confused about how to apply them effectively with my many custom layers. The full code snippet is below.

import os
import tensorflow as tf
from termcolor import cprint  # assuming cprint comes from termcolor, used below for coloured logging


class Embedding(tf.keras.layers.Layer):
    """docstring for Embedding."""

    def __init__(self,vocab,emd_d,name):
        super(Embedding,self).__init__()
        cprint(f"Initializing {name}'s Embedding layer",'magenta')
        self.E = self.add_weight('Embedding Layer',shape=(vocab,emd_d),dtype=tf.float32,initializer='random_uniform',trainable=True)
        if name == 'DECODER':
            self.Eo = self.add_weight('Embedding Layer Output',shape=(emd_d,vocab),trainable=True)
        cprint(f"\t\tinitialization of {name}'s Embedding layer COMPLETE",'green')

    def call(self,input):
        return tf.gather(self.E,input,axis=0)


class Model(tf.keras.layers.Layer):
    """docstring for Model."""

    def __init__(self,enc_T,dec_T,enc_emb_d,dec_emb_d,enc_voc,dec_voc,enc_units,dec_units,inp_token,targ_token,enc_return_state=False,dec_return_state=False,enc_f=tf.nn.tanh,dec_f=tf.nn.tanh):
        super(Model,self).__init__()
        cprint('PREPARING THE MODEL','yellow')
        self.enc_T = enc_T
        self.dec_T = dec_T
        self.enc_emb_d = enc_emb_d
        self.dec_emb_d = dec_emb_d
        self.enc_voc = enc_voc
        self.dec_voc = dec_voc
        self.enc_units = enc_units
        self.dec_units = dec_units
        self.inp_token = inp_token
        self.targ_token = targ_token
        self.enc_emb = Embedding(enc_voc,enc_emb_d,'ENCODER')
        self.enc = Encoder(enc_T,f=enc_f,return_state=enc_return_state)
        self.dec_emb = Embedding(dec_voc,dec_emb_d,'DECODER')
        self.dec = Decoder(dec_T,f=dec_f,return_state=dec_return_state)
        self._trainable_weights = self.enc_emb.trainable_weights + \
                                  self.enc.trainable_weights + \
                                  self.dec_emb.trainable_weights + \
                                  self.dec.trainable_weights
        cprint('MODEL IS READY','yellow')

    def forward(self,x):
        i = self.targ_token.word_index['<start>']
        y = tf.ones((x.shape[0],1),dtype=tf.int32)*i
        emd_x = self.enc_emb(x)
        state = self.enc(emd_x)
        emd_y = self.dec_emb(y)
        out = self.dec(emd_y,state)
        return out

    def Loss(self,y_true,logits):
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true,logits=logits)
        return tf.reduce_sum(loss)

    def gradient(self,x,y,opt):
        with tf.GradientTape() as tape:
            out = self.forward(x)
            logits = tf.matmul(out,self.dec_emb.Eo)
            loss = self.Loss(y,logits)
        grads = tape.gradient(loss,self._trainable_weights)
        opt.apply_gradients(zip(grads,self._trainable_weights))

        # The from_config() call below (attention.py line 220 in the traceback)
        # is where the TypeError is raised: the 'layers' entry returned by
        # get_config() holds live layer objects, which Keras cannot index like
        # serialized layer dicts.
        config = self.get_config()
        custom_obj = {'Encoder Emb Layer': self.enc_emb,'Encoder Layer': self.enc,'Decoder Emb Layer': self.dec_emb,'Decoder Layer': self.dec}
        with tf.keras.utils.custom_object_scope(custom_obj):
            new_model = tf.keras.Model.from_config(config)
        checkpoint_directory = "/tmp/training_checkpoints"
        checkpoint_prefix = os.path.join(checkpoint_directory,"ckpt")
        checkpoint = tf.train.Checkpoint(optimizer=opt,model=new_model)
        checkpoint.save(file_prefix=checkpoint_prefix)
        return loss

    def get_config(self):
        config = super(Model,self).get_config()
        config.update({'Encoder timesteps': self.enc_T,'Decoder timesteps': self.dec_T,'Encoder Emb Dim': self.enc_emb_d,'Decoder Emb Dim': self.dec_emb_d,'Encoder Vocab': self.enc_voc,'Decoder Vocab': self.dec_voc,'Encoder Units': self.enc_units,'Decoder Units': self.dec_units,'Input Token': self.inp_token,'Target Token': self.targ_token,'layers': [self.enc_emb,self.enc,self.dec_emb,self.dec]})
        return config

    @classmethod
    def from_config(cls,config):
        return cls(**config)

Running the above logic, I receive the following error:

Traceback (most recent call last):
  File "E:\Project\Tensorflow 2.0\Attention\attention.py",line 412,in <module>
    learning_rate,epochs)
  File "E:\Project\Tensorflow 2.0\Attention\attention.py",line 244,in fit
    L = self.gradient(input_batch,target_batch,optim)
  File "E:\Project\Tensorflow 2.0\Attention\attention.py",line 220,in gradient
    new_model = tf.keras.Model.from_config(config)
  File "C:\Users\prana\AppData\Local\Programs\Python\python37\lib\site-packages\tensorflow\python\keras\engine\network.py",line 987,in from_config
    config,custom_objects)
  File "C:\Users\prana\AppData\Local\Programs\Python\python37\lib\site-packages\tensorflow\python\keras\engine\network.py",line 2019,in reconstruct_from_config
    process_layer(layer_data)
  File "C:\Users\prana\AppData\Local\Programs\Python\python37\lib\site-packages\tensorflow\python\keras\engine\network.py",line 1993,in process_layer
    layer_name = layer_data['name']
TypeError: 'Embedding' object is not subscriptable

I would like to know what is going wrong under the hood to cause this error.
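For comparison (this is not part of the original post, and TinyModel, model, and optimizer are placeholder names), the object-based tf.train.Checkpoint API can be pointed directly at a live subclassed model and its optimizer, so no get_config()/from_config() round-trip is needed just to save training state. A minimal, self-contained sketch:

import os
import tensorflow as tf

class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for the subclassed model in the question."""

    def __init__(self):
        super(TinyModel, self).__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x):
        return self.dense(x)

model = TinyModel()
optimizer = tf.keras.optimizers.Adam()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# tf.train.Checkpoint tracks every variable reachable from the objects it is
# given (the optimizer slots and the model's layers/weights), so the model
# never has to be rebuilt from its config just to save it.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

# One illustrative training step.
x = tf.random.normal((2, 3))
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(x))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# Save periodically during training; restore the latest checkpoint later.
checkpoint.save(file_prefix=checkpoint_prefix)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))

If the architecture itself also needs to be serialized, each custom layer's get_config() would have to return plain, JSON-serializable values (with sub-layers rebuilt in __init__), rather than storing live layer objects under a 'layers' key.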
