微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

将 tf1 中的代码转换为 tf2 时出错

如何解决将 tf1 中的代码转换为 tf2 时出错

值在哪里

rnn_size: 512
batch_size: 128


rnn_inputs: Tensor("embedding_lookup/Identity_1:0",shape=(?,?,128),dtype=float32)
sequence_length: Tensor("inputs_length:0",),dtype=int32)
cell_fw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb6d0>
cell_bw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb910>

通过获取 enc_state 值

enc_output,enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,rnn_inputs,sequence_length,dtype=tf.float32)

enc_state 值在哪里

enc_state: LSTMStateTuple(c=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?,512) dtype=float32>,h=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(?,512) dtype=float32>)

TF1 代码

initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(enc_state,_zero_state_tensors(rnn_size,batch_size,tf.float32))

通过

转换成TF2
initial_state = tfa.seq2seq.AttentionWrapper(enc_state,tf.float32))

获取错误


TypeError                                 Traceback (most recent call last)
<ipython-input-54-d87646b9df5d> in <module>()
      8                                                     threshold) 
      9             model = build_graph(keep_probability,rnn_size,num_layers,---> 10                                 learning_rate,embedding_size,direction)
     11             train(model,epochs,log_string)

6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname,value,expected_type,memo)
    596                 raise TypeError(
    597                     'type of {} must be {}; got {} instead'.
--> 598                     format(argname,qualified_name(expected_type),qualified_name(value)))
    599     elif isinstance(expected_type,TypeVar):
    600         # Only happens on < 3.6

TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead

还可以解释错误的最后一行,即

    TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。