如何解决在Encoder-Decoder架构中添加自定义注意力层以使用Keras进行神经机器翻译
我正在尝试使用TensorFlow 2 Keras完成英语到印地语的神经机器翻译任务。我已经创建了自己的自定义注意力层,但是我无法弄清楚如何将它插入到编码器和解码器层之间。而且我还需要一些帮助来准备用于训练和推理的数据集,因为网上有很多不同的实现方式,让我感到困惑,所以如果有人可以帮助我,那将真的很有帮助! 这是代码:
# Convert each sentence into a list of integer token ids using the
# previously fitted tokenizers (one integer per word, per the tokenizer's
# word index).  NOTE(review): assumes english_tokenizer/hindi_tokenizer
# were already fit on the corpora — not shown here.
english_encoded_text = english_tokenizer.texts_to_sequences(english_data)
hindi_encoded_text = hindi_tokenizer.texts_to_sequences(hindi_data)
在这里,我已将数据转换为整数序列,现在我很困惑:接下来应该对这些序列进行填充(padding),还是进行独热编码(one-hot encoding)! 下面是模型:
#attention layer
class attention(Layer):
    """Additive (Bahdanau-style) attention over an RNN's time axis.

    Input:  a 3-D tensor (batch, time_steps, features), e.g. the
            return_sequences=True output of an (Bi)LSTM.
    Output: a 2-D context vector (batch, features) — the attention-weighted
            sum of the input over the time axis.
    """

    def __init__(self, **kwargs):
        super(attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Projects each time step's feature vector to a scalar energy.
        self.W = self.add_weight(name="attention_weight",
                                 shape=(input_shape[-1], 1),
                                 initializer="normal")
        # BUGFIX: the original line had unbalanced parentheses and a bias
        # of shape (time_steps,); adding that to the (batch, time_steps, 1)
        # projection broadcasts to (batch, time_steps, time_steps), which is
        # wrong.  A (time_steps, 1) bias broadcasts correctly.
        self.b = self.add_weight(name="attention_bias",
                                 shape=(input_shape[1], 1),
                                 initializer="zeros")
        super(attention, self).build(input_shape)

    def call(self, x):
        # energies: (batch, time_steps) — one scalar score per time step.
        energies = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
        # alphas: attention weights, softmax-normalized over time.
        alphas = K.softmax(energies)
        alphas = K.expand_dims(alphas, axis=-1)
        # Weighted sum over the time axis -> (batch, features).
        context_vector = x * alphas
        return K.sum(context_vector, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        # No extra constructor arguments, so the base config suffices.
        return super(attention, self).get_config()
#encoder-decoder with attention
from tensorflow.keras.layers import RepeatVector  # needed for the 2-D -> 3-D fix below

latent_dim = 100
embedding_features = 256

# Initial hidden and cell state fed to the decoder LSTM.
s0 = Input(shape=(latent_dim,), name='s0')
# BUGFIX: the original line had unbalanced parentheses.
c0 = Input(shape=(latent_dim,), name='c0')

# BUGFIX: the encoder input must be a 2-D batch of integer token ids,
# shape (batch, max_len) — the Embedding layer performs the lookup itself.
# The original (max_len, vocab) one-hot input made the embedded tensor
# 4-D, which is exactly the reported "expected ndim=3, found ndim=4,
# [None, 50, 28103, 256]" error.  So: pad the integer sequences, do NOT
# one-hot encode them.
encoder_inputs = Input(shape=(max_eng_sentence_length,))
encoder_embedding = Embedding(english_vocab, embedding_features,
                              input_length=max_eng_sentence_length)(encoder_inputs)
encoder_lstm = Bidirectional(LSTM(latent_dim, return_sequences=True,
                                  dropout=0.3, recurrent_dropout=0.2))(encoder_embedding)

# Context vector from the custom attention layer: (batch, 2*latent_dim).
attention_context_vector = attention()(encoder_lstm)

# BUGFIX: an LSTM requires a 3-D (batch, time, features) input, but the
# context vector is 2-D — give it a time axis of length 1.
decoder_inputs = RepeatVector(1)(attention_context_vector)
decoder_layer = LSTM(latent_dim, return_state=True)
s, _, c = decoder_layer(decoder_inputs, initial_state=[s0, c0])

# Per-token distribution over the Hindi vocabulary.
output_layer = Dense(hindi_vocab, activation="softmax")
out = output_layer(s)

# BUGFIX: `X` was undefined; the model's inputs are the encoder input
# tensor plus the two initial-state tensors.
model = Model(inputs=[encoder_inputs, s0, c0], outputs=out)
我遇到以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-21-1ea2aa5b4365> in <module>()
8 encoder_inputs = Input(shape=(max_eng_sentence_length,english_vocab))
9 encoder_embedding = Embedding(english_vocab,input_length=max_eng_sentence_length)(encoder_inputs)
---> 10 encoder_lstm = Bidirectional(LSTM(latent_dim,recurrent_dropout=0.2))(encoder_embedding)
11 attention_context_vector = attention()(encoder_lstm)
12 decoder_layer = LSTM(latent_dim,return_state = True)
3 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/wrappers.py in __call__(self,inputs,initial_state,constants,**kwargs)
528
529 if initial_state is None and constants is None:
--> 530 return super(Bidirectional,self).__call__(inputs,**kwargs)
531
532 # Applies the same workaround as in `RNN.__call__`
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self,*args,**kwargs)
924 if _in_functional_construction_mode(self,args,kwargs,input_list):
925 return self._functional_construction_call(inputs,--> 926 input_list)
927
928 # Maintains info about the `Layer.call` stack.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self,input_list)
1090 # Todo(reedwm): We should assert input compatibility after the inputs
1091 # are casted,not before.
-> 1092 input_spec.assert_input_compatibility(self.input_spec,self.name)
1093 graph = backend.get_graph()
1094 # Use `self._name_scope()` to avoid auto-incrementing the name.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/input_spec.py in assert_input_compatibility(input_spec,layer_name)
178 'expected ndim=' + str(spec.ndim) + ',found ndim=' +
179 str(ndim) + '. Full shape received: ' +
--> 180 str(x.shape.as_list()))
181 if spec.max_ndim is not None:
182 ndim = x.shape.ndims
ValueError: Input 0 of layer bidirectional is incompatible with the layer: expected ndim=3,found ndim=4. Full shape received: [None,50,28103,256]
我需要帮助才能完成这项工作! 预先感谢!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。