
How to implement Bahdanau attention in Keras?


How do I implement a Bahdanau attention layer in Keras? Does Keras ship with a Bahdanau attention layer (the way it ships with Dense or LSTM)? If not, could you please explain how to implement one in Keras?

Solution

There is an implementation of Bahdanau attention for TensorFlow-Keras, more precisely in tensorflow-addons.

You can take a look at it here: https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BahdanauAttention

To use tensorflow-addons, make sure you run pip install tensorflow-addons first.
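
As a quick illustration of the API, here is a minimal sketch of creating a Bahdanau attention mechanism and wrapping a decoder cell with it. The tensors, sizes and names below (encoder_outputs, source_lengths, the 256-unit LSTM) are made up for this sketch, not part of the original answer:

import tensorflow as tf
import tensorflow_addons as tfa

# Hypothetical shapes: a batch of 64 source sentences of length 16,
# encoded by an LSTM with 256 units.
encoder_outputs = tf.random.normal([64, 16, 256])   # (batch, Tx, rnn_units)
source_lengths = [16] * 64                           # true length of each source sequence

attention_mechanism = tfa.seq2seq.BahdanauAttention(
    units=256,                                 # size of the attention (alignment) layer
    memory=encoder_outputs,                    # the encoder outputs the decoder attends over
    memory_sequence_length=source_lengths)     # masks padded positions

# The mechanism is normally plugged into an AttentionWrapper around the decoder cell:
decoder_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(256),
                                            attention_mechanism,
                                            attention_layer_size=256)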

The tutorial here (https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt) builds a seq2seq model with Luong attention; you can easily modify it to use Bahdanau attention instead.

import tensorflow as tf
import tensorflow_addons as tfa

# input_vocab_size, output_vocab_size, embedding_dims, rnn_units, dense_units,
# BATCH_SIZE and Tx (max source length) are defined earlier in the tutorial.

# ENCODER
class EncoderNetwork(tf.keras.Model):
    def __init__(self, input_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,
                                                           output_dim=embedding_dims)
        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units, return_sequences=True,
                                                     return_state=True)

# DECODER
class DecoderNetwork(tf.keras.Model):
    def __init__(self, output_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,
                                                           output_dim=embedding_dims)
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        # Sampler
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # Create attention mechanism with memory = None
        self.attention_mechanism = self.build_attention_mechanism(dense_units, None, BATCH_SIZE * [Tx])
        self.rnn_cell = self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler,
                                                output_layer=self.dense_layer)

    def build_attention_mechanism(self, units, memory, memory_sequence_length):
        # HERE: swap Luong attention for Bahdanau attention
        return tfa.seq2seq.LuongAttention(units, memory=memory,
                                          memory_sequence_length=memory_sequence_length)
        # return tfa.seq2seq.BahdanauAttention(units, memory=memory,
        #                                      memory_sequence_length=memory_sequence_length)

    # wrap the decoder RNN cell with the attention mechanism
    def build_rnn_cell(self, batch_size):
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell, self.attention_mechanism,
                                                attention_layer_size=dense_units)
        return rnn_cell

    def build_decoder_initial_state(self, batch_size, encoder_state, Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_size, dtype=Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
        return decoder_initial_state

encoderNetwork = EncoderNetwork(input_vocab_size, embedding_dims, rnn_units)
decoderNetwork = DecoderNetwork(output_vocab_size, embedding_dims, rnn_units)
optimizer = tf.keras.optimizers.Adam()
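
Because the attention mechanism is created with memory=None, the encoder outputs have to be plugged in at training time. The following is a rough sketch of how the tutorial wires the two networks together; the names input_batch, target_batch and Ty (max target length) are assumptions following the tutorial's conventions, not part of the answer above:

# Run the encoder over a batch of source sequences.
encoder_emb_inp = encoderNetwork.encoder_embedding(input_batch)
a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(encoder_emb_inp)

# Point the attention mechanism at the encoder outputs (its "memory").
decoderNetwork.attention_mechanism.setup_memory(a)

# Initialise the attention-wrapped decoder cell from the encoder's final state.
decoder_initial_state = decoderNetwork.build_decoder_initial_state(BATCH_SIZE,
                                                                   encoder_state=[a_tx, c_tx],
                                                                   Dtype=tf.float32)

# Teacher forcing: feed the embedded (shifted) target sequence through the BasicDecoder.
decoder_emb_inp = decoderNetwork.decoder_embedding(target_batch)
outputs, _, _ = decoderNetwork.decoder(decoder_emb_inp,
                                       initial_state=decoder_initial_state,
                                       sequence_length=BATCH_SIZE * [Ty - 1])
logits = outputs.rnn_output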
