How to implement Bahdanau attention in Keras?
How do you implement a Bahdanau attention layer in Keras? Does Keras ship a Bahdanau attention layer out of the box (the way it ships Dense or LSTM)? If not, could you explain how to implement one in Keras?
Solution
There is an implementation of Bahdanau attention in the TensorFlow-Keras ecosystem, more precisely in tensorflow-addons.
You can find it here: https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BahdanauAttention
To use tensorflow-addons, make sure you first run pip install tensorflow-addons.
The tutorial at https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt builds a seq2seq model with Luong attention as the example; you can easily modify it to use Bahdanau attention instead, as shown by the commented-out line in the code below.
# ENCODER
class EncoderNetwork(tf.keras.Model):
    def __init__(self, input_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,
                                                           output_dim=embedding_dims)
        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,
                                                     return_sequences=True,
                                                     return_state=True)

# DECODER
class DecoderNetwork(tf.keras.Model):
    def __init__(self, output_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,
                                                           output_dim=embedding_dims)
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        # Sampler used for training (feeds the ground-truth token at each step)
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # Create the attention mechanism with memory = None; the encoder
        # outputs are attached later via setup_memory()
        self.attention_mechanism = self.build_attention_mechanism(dense_units, None, BATCH_SIZE * [Tx])
        self.rnn_cell = self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler,
                                                output_layer=self.dense_layer)

    def build_attention_mechanism(self, units, memory, memory_sequence_length):
        # HERE: swap LuongAttention for BahdanauAttention
        return tfa.seq2seq.LuongAttention(units, memory=memory,
                                          memory_sequence_length=memory_sequence_length)
        # return tfa.seq2seq.BahdanauAttention(units, memory=memory,
        #                                      memory_sequence_length=memory_sequence_length)

    # wrap the decoder RNN cell so it applies attention at every step
    def build_rnn_cell(self, batch_size):
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell,
                                                self.attention_mechanism,
                                                attention_layer_size=dense_units)
        return rnn_cell

    def build_decoder_initial_state(self, batch_size, encoder_state, Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_size,
                                                                dtype=Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
        return decoder_initial_state

encoderNetwork = EncoderNetwork(input_vocab_size, embedding_dims, rnn_units)
decoderNetwork = DecoderNetwork(output_vocab_size, embedding_dims, rnn_units)
optimizer = tf.keras.optimizers.Adam()
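To see what the swap actually changes: Luong attention scores encoder states with a multiplicative form, while Bahdanau attention uses the additive form score = vᵀ tanh(W1·h_enc + W2·h_dec). The scoring step can be sketched in plain NumPy (all weights and sizes below are random placeholders for illustration, not values from the tutorial):

```python
import numpy as np

def bahdanau_scores(dec_state, enc_outputs, W1, W2, v):
    """Additive (Bahdanau) attention: score_t = v^T tanh(W1 h_enc_t + W2 h_dec)."""
    # enc_outputs: (Tx, enc_units), dec_state: (dec_units,)
    features = np.tanh(enc_outputs @ W1 + dec_state @ W2)  # (Tx, units)
    scores = features @ v                                  # (Tx,)
    exp = np.exp(scores - scores.max())                    # stable softmax over time
    weights = exp / exp.sum()                              # attention weights, sum to 1
    context = weights @ enc_outputs                        # weighted sum of encoder states
    return weights, context

rng = np.random.default_rng(0)
Tx, enc_units, dec_units, units = 5, 8, 8, 10
W1 = rng.normal(size=(enc_units, units))
W2 = rng.normal(size=(dec_units, units))
v = rng.normal(size=(units,))
weights, context = bahdanau_scores(rng.normal(size=(dec_units,)),
                                   rng.normal(size=(Tx, enc_units)),
                                   W1, W2, v)
```

In the tfa classes this distinction is internal, which is why switching attention mechanisms in the tutorial only requires changing the one line marked "HERE" above.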