如何解决调用tensorflow2模型未返回在call方法中定义的输出
当我调用tf2模型时,它并没有按照我在tf Model子类中定义call()方法的方式返回给我的值。
相反,调用模型的call()方法会返回我在build()方法中定义的张量
为什么会这样,我该如何解决?
import numpy as np
import tensorflow as tf
num_items = 1000
emb_dim = 32
lstm_dim = 32
class rnn_model(tf.keras.Model):
def __init__(self,num_items,emb_dim):
super(rnn_model,self).__init__()
self.emb = tf.keras.layers.Embedding(num_items,emb_dim,name='embedding_layer')
self.GRU = tf.keras.layers.LSTM(lstm_dim,name='rnn_layer')
self.dense = tf.keras.layers.Dense(num_items,activation = 'softmax',name='final_layer')
def call(self,inp,is_training=True):
emb = self.emb(inp)
gru = self.GRU(emb)
# logits=self.dense(gru)
return gru # (bs,lstm_dim=50)
def build(self,inp_shape):
x = tf.keras.Input(shape=inp_shape,name='input_layer')
# return tf.keras.Model(inputs=[x],outputs=self.call(x))
return tf.keras.Model(inputs=[x],outputs=self.dense(self.call(x)))
maxlen = 10
model = rnn_model(num_items,emb_dim).build((maxlen,))
model.summary()
gru_out = model(inp)
print(gru_out.shape) # should have been (bs=16,lstm_dim=32)
以下是我得到的输出-
Model: "functional_11"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None,10)] 0
_________________________________________________________________
embedding_layer (Embedding) (None,10,32) 32000
_________________________________________________________________
rnn_layer (LSTM) (None,32) 8320
_________________________________________________________________
final_layer (Dense) (None,1000) 33000
=================================================================
Total params: 73,320
Trainable params: 73,320
Non-trainable params: 0
_________________________________________________________________
(16,1000)
我打算仅在模型末尾使用'final_layer'或致密层,将其输入到采样的softmax函数中,在该函数中,它将与gru_out一起使用以计算损失(以训练型号)。
在测试时,我打算将gru_out手动传递到model.get_layer('final_layer')中以获取最终登录信息。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。