How to feed four tensors (head, arch, pos, embedding) into a custom Keras layer tanh_layer


For several days I have been stuck passing the following hidden states (word_head, word_arch, word_pos, and the embedding hidden state, all tensors of shape (?, 256)) into the custom layer tanh_layer below. It raises an error saying the declared inputs do not match the layer. How do I declare four tensors as inputs in tanh_layer's __init__, build, or call so that the model recognizes all four input tensors? I expected that the `inputs` argument of call would carry them and that `head, arch, pos, text = inputs` would unpack the four tensors, but the layer does not recognize them.

This is a biLSTM + attention-decoder model for neural machine translation; I am adding dependency-parse context (head, arch, POS) through a tanh layer rather than by concatenation (which I already got working without problems). Any help is appreciated. Below are: 1. the source code of my model, 2. the source code of tanh_layer.

ValueError: Layer tanh_layer expects 1 inputs, but it received 4 input tensors. Inputs received: []
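For context, a Keras custom layer can accept a *list* of tensors directly: `build()` then receives a list of shape tuples and `call()` receives a list of tensors. A minimal sketch of that pattern (the layer name `MergeFour` and the sizes 256/64 are illustrative placeholders, not the author's actual layer):

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.models import Model

class MergeFour(Layer):
    """Merge four (batch, dim) tensors into one (batch, units) tensor."""
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # For list inputs, input_shape is a list of four shape tuples.
        assert isinstance(input_shape, (list, tuple)) and len(input_shape) == 4
        dim = int(input_shape[0][-1])
        # One projection matrix per input stream, plus a shared bias.
        self.kernels = [
            self.add_weight(shape=(dim, self.units), name=f"W{i}",
                            initializer="glorot_uniform")
            for i in range(4)
        ]
        self.bias = self.add_weight(shape=(self.units,), name="b",
                                    initializer="zeros")
        super().build(input_shape)

    def call(self, inputs):
        head, arch, pos, text = inputs  # unpack the list of four tensors
        z = (tf.matmul(head, self.kernels[0])
             + tf.matmul(arch, self.kernels[1])
             + tf.matmul(pos, self.kernels[2])
             + tf.matmul(text, self.kernels[3])
             + self.bias)
        return tf.tanh(z)

ins = [Input(shape=(256,)) for _ in range(4)]
out = MergeFour(64)(ins)  # the four tensors are passed as one list
model = Model(ins, out)
print(model.output_shape)  # (None, 64)
```

The key point is that the layer is called once on the whole list, not four times on four tensors; Keras then reports a single multi-tensor input rather than a mismatch.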


def define_model(src_vocab,tar_vocab,src_timesteps,tar_timesteps,word_pos_vocab_size,n_units):
    word_head_input = Input(shape=(src_timesteps,),name = 'word_head')
    word_head_reshape = Reshape((src_timesteps,1),name='word_head_reshape') (word_head_input)
    word_head_hidden_state = Bidirectional(LSTM(n_units,dropout =0.3,input_shape=(src_timesteps,1)))(word_head_reshape)

    word_arch_input = Input(shape=(src_timesteps,),name = 'word_arch')
    word_arch_reshape = Reshape((src_timesteps,1),name='word_arch_reshape') (word_arch_input)
    word_arch_hidden_state = Bidirectional(LSTM(n_units,dropout = 0.3,input_shape=(src_timesteps,1)))(word_arch_reshape)

    word_pos_input = Input(shape=(src_timesteps,),name="word_pos")
    word_pos_one_hot = Embedding(word_pos_vocab_size,n_units,input_length=src_timesteps,mask_zero=True)(word_pos_input)
    word_pos_hidden_state = Bidirectional(LSTM(n_units,dropout = 0.3))(word_pos_one_hot)

    encoder_input = Input(shape=(src_timesteps,),name="word_text")
    one_hot = Embedding(src_vocab,n_units,input_length=src_timesteps,mask_zero=True)(encoder_input)
    embedding_hidden_state = Bidirectional(LSTM(n_units,dropout=0.3))(one_hot)

    hidden_state = [word_head_hidden_state,word_arch_hidden_state,word_pos_hidden_state,embedding_hidden_state]

    tanh_hidden_state = tanh_layer(n_units,tar_timesteps)(hidden_state)

    decoder_output = AttentionDecoder(n_units,src_timesteps)(tanh_hidden_state)
    model = Model(inputs = [encoder_input,word_head_input,word_arch_input,word_pos_input],outputs = decoder_output,name="Autoencoder")
    return model


import tensorflow as tf
from keras import backend as K
from keras.layers import Layer
from keras import regularizers,constraints,initializers,activations
from keras.engine import InputSpec


class tanh_layer(Layer):
    def __init__(self,units,tar_timesteps,activation ='tanh',name ='tanh_layer',kernel_initializer='glorot_uniform',bias_initializer='zeros',kernel_regularizer=None,bias_regularizer=None,kernel_constraint=None,bias_constraint=None,**kwargs):
        """
        Implements a tanh layer to integrate the head word,word arch,POS tags and text in a FFN
        """
        self.units = units
        self.tar_timesteps = tar_timesteps
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        print('start _init_')
        print('self.units',self.units)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        super(tanh_layer,self).__init__(name=name,**kwargs)
        print('self.name',self.name)
        print('end_initi_')


    def build(self,input_shape,*arg,**kwargs):
        """
        Initialize Um,Vm,Wm,Xm and bm for the merging layer
        Matrices for creating the context vector
        """
        
        print('start build')
        print('input_shape_dim',input_shape)
        print(input_shape[0],input_shape[1],input_shape[2],input_shape[3])
        # input_shape is a list of four (batch, units) shape tuples,
        # one per input tensor; all four hidden states share the same size
        self.batch_size,self.units = input_shape[0]
        print('input_shape',self.units)

        self.U_m = self.add_weight(shape=(self.units,self.units),name='U_m',initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,trainable = True)
        self.V_m = self.add_weight(shape=(self.units,self.units),name='V_m',initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,trainable = True)
        self.W_m = self.add_weight(shape=(self.units,self.units),name='W_m',initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,trainable = True)
        self.X_m = self.add_weight(shape=(self.units,self.units),name='X_m',initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,trainable = True)
        self.b_m = self.add_weight(shape=(self.units,),name='b_m',initializer=self.bias_initializer,regularizer=self.bias_regularizer,constraint=self.bias_constraint,trainable = True)

        # one InputSpec per input tensor, since the layer takes a list
        self.input_spec = [InputSpec(shape=s) for s in input_shape]
        print(self.input_spec)
        super(tanh_layer,self).build(input_shape)
        
        print('U_m',self.U_m)
        print('V_m',self.V_m)
        print('W_m',self.W_m)
        print('X_m',self.X_m)
        print('b_m',self.b_m)
        print('U_m',type(self.U_m))
        print('V_m',type(self.V_m))
        print('W_m',type(self.W_m))
        print('X_m',type(self.X_m))
        print('b_m',type(self.b_m))
        
        print('end build')
        self.built = True

    def call(self,inputs,**kwargs):
        """
        Call the merging all four components in here
        Merging all four components in here
        """
        print('start call')
        # `inputs` arrives as the list [head, arch, pos, text] built in the model
        self.head,self.arch,self.pos,self.text = inputs

        print('self.head',type(self.head))
        print('self.arch',type(self.arch))
        print('self.pos',type(self.pos))
        print('self.text',type(self.text))
        print('self.units',self.units)
        print('self.head',self.head)
        print('self.arch',self.arch)
        print('self.pos',self.pos)
        print('self.text',self.text)
#        h_state = head + arch + pos + text
        # calculate the sum of the hidden state:
        self.tanh_hidden_state = self.activation(
            K.dot(self.head,self.U_m)
            + K.dot(self.arch,self.V_m)
            + K.dot(self.pos,self.W_m)
            + K.dot(self.text,self.X_m)
            + self.b_m)
        
        self.tanh_hidden_state = tf.reshape(self.tanh_hidden_state,[-1,self.tar_timesteps,self.units])
        
        print('tanh_hidden_state',K.int_shape(self.tanh_hidden_state))
        print('tanh_hidden_state',self.tanh_hidden_state)
        print('type tanh_hidden_state',type(self.tanh_hidden_state))
        print('endcall')
        return self.tanh_hidden_state
    
    def compute_output_shape(self,input_shape):
        # shape after the reshape in call(): (batch, tar_timesteps, units)
        return (input_shape[0][0],self.tar_timesteps,self.units)

