微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

调用tensorflow2模型未返回在call方法中定义的输出

如何解决调用tensorflow2模型未返回在call方法中定义的输出

当我调用tf2模型时,它并没有按照我在tf Model子类中定义call()方法的方式返回给我的值。

相反,调用模型的call()方法会返回我在build()方法中定义的张量

为什么会这样,我该如何解决

import numpy as np
import tensorflow as tf

num_items = 1000
emb_dim = 32
lstm_dim = 32

class rnn_model(tf.keras.Model):
    def __init__(self,num_items,emb_dim): 
        super(rnn_model,self).__init__()
        
        self.emb   = tf.keras.layers.Embedding(num_items,emb_dim,name='embedding_layer')
        self.GRU   = tf.keras.layers.LSTM(lstm_dim,name='rnn_layer')
        self.dense = tf.keras.layers.Dense(num_items,activation = 'softmax',name='final_layer')
        
    def call(self,inp,is_training=True):  
        emb = self.emb(inp)
        gru = self.GRU(emb)
        # logits=self.dense(gru)
        
        return gru # (bs,lstm_dim=50)
    
    def build(self,inp_shape):
      x = tf.keras.Input(shape=inp_shape,name='input_layer')
      # return tf.keras.Model(inputs=[x],outputs=self.call(x))
      return tf.keras.Model(inputs=[x],outputs=self.dense(self.call(x)))

maxlen = 10
model = rnn_model(num_items,emb_dim).build((maxlen,))
model.summary()

gru_out = model(inp)
print(gru_out.shape) # should have been (bs=16,lstm_dim=32)

以下是我得到的输出-

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_layer (InputLayer)     [(None,10)]              0         
_________________________________________________________________
embedding_layer (Embedding)  (None,10,32)            32000     
_________________________________________________________________
rnn_layer (LSTM)             (None,32)                8320      
_________________________________________________________________
final_layer (Dense)          (None,1000)              33000     
=================================================================
Total params: 73,320
Trainable params: 73,320
Non-trainable params: 0
_________________________________________________________________
(16,1000)

我打算仅在模型末尾使用'final_layer'或致密层,将其输入到采样的softmax函数中,在该函数中,它将与gru_out一起使用以计算损失(以训练型号)。

在测试时,我打算将gru_out手动传递到model.get_layer('final_layer')中以获取最终登录信息。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?