Albert layer from TF-Hub in Keras

I have implemented BERT as a layer for use in a Keras model with tf-hub, in the following way:

import os

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer  # use tf.keras throughout; don't mix with standalone keras

os.environ["TFHUB_CACHE_DIR"] = "./data/bert_models"


class BertLayer(Layer):
    def __init__(
        self,
        n_fine_tune_layers=10,
        pooling="first",
        bert_path="https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
        **kwargs,
    ):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        self.output_size = 768
        self.pooling = pooling
        self.bert_path = bert_path
        if self.pooling not in ["first", "mean"]:
            raise NameError(
                f"Undefined pooling type (must be either first or mean, but is {self.pooling})"
            )
        super(BertLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # hub.Module is the TF1-style TF-Hub API
        self.bert = hub.Module(
            self.bert_path, trainable=self.trainable, name=f"{self.name}_module"
        )

        # Remove unused layers
        trainable_vars = self.bert.variables
        if self.pooling == "first":
            trainable_vars = [var for var in trainable_vars if "/cls/" not in var.name]
            trainable_layers = ["pooler/dense"]
        elif self.pooling == "mean":
            trainable_vars = [
                var
                for var in trainable_vars
                if "/cls/" not in var.name and "/pooler/" not in var.name
            ]
            trainable_layers = []
        else:
            raise NameError(
                f"Undefined pooling type (must be either first or mean, but is {self.pooling})"
            )

        # Select how many layers to fine-tune (assumes 12 encoder layers, layer_0 ... layer_11)
        for i in range(self.n_fine_tune_layers):
            trainable_layers.append(f"encoder/layer_{11 - i}")

        # Update trainable vars to contain only the specified layers
        trainable_vars = [
            var for var in trainable_vars if any(l in var.name for l in trainable_layers)
        ]

        # Register trainable and non-trainable weights with Keras
        for var in trainable_vars:
            self._trainable_weights.append(var)
        for var in self.bert.variables:
            if var not in self._trainable_weights:
                self._non_trainable_weights.append(var)

        super(BertLayer, self).build(input_shape)

    def call(self, inputs):
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        if self.pooling == "first":
            pooled = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
                "pooled_output"
            ]
        elif self.pooling == "mean":
            # The "tokens" signature is needed here too, to accept the dict of token inputs
            result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
                "sequence_output"
            ]
            # Average the token vectors, ignoring padding positions (input_mask == 0)
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10
            )
            input_mask = tf.cast(input_mask, tf.float32)
            pooled = masked_reduce_mean(result, input_mask)
        else:
            raise NameError(
                f"Undefined pooling type (must be either first or mean, but is {self.pooling})"
            )
        return pooled

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], self.output_size)

If I want to implement an Albert layer the same way with tf-hub, is it enough to change the model URL in bert_path to an Albert model URL, or do I need to make other changes to the code? I just want to make sure the Albert layer works correctly. Thanks.
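One thing worth checking before swapping only the URL: BertLayer.build chooses which variables to fine-tune by substring-matching their names against "pooler/dense" and "encoder/layer_{11 - i}", which are BERT's per-layer scope names. ALBERT shares one set of transformer weights across all of its layers, so its TF-Hub module may not expose per-layer encoder/layer_N scopes at all. The sketch below reproduces the same filtering on plain strings so the rules can be tried without TensorFlow. The helper select_trainable, its num_layers and extra_patterns parameters, and every variable name in it are hypothetical; in particular the ALBERT-style names (encoder/transformer/group_0/inner_group_0/...) are an assumption about the module's layout, so print [v.name for v in self.bert.variables] for the real module and compare.

```python
from typing import Iterable, List, Sequence


def select_trainable(
    var_names: Iterable[str],
    pooling: str = "first",
    n_fine_tune_layers: int = 10,
    num_layers: int = 12,
    extra_patterns: Sequence[str] = (),
) -> List[str]:
    """Apply the same name filters as BertLayer.build to plain strings.

    Hypothetical helper for checking the selection rules offline; it does
    not load any model.
    """
    names = list(var_names)
    if pooling == "first":
        # Drop the pre-training head ("/cls/") and keep the pooler trainable.
        names = [n for n in names if "/cls/" not in n]
        patterns = ["pooler/dense"]
    elif pooling == "mean":
        # Drop the pre-training head and the pooler.
        names = [n for n in names if "/cls/" not in n and "/pooler/" not in n]
        patterns = []
    else:
        raise NameError(
            f"Undefined pooling type (must be either first or mean, but is {pooling})"
        )
    # BERT-style per-layer scopes, counted down from the top layer
    # (BertLayer hard-codes 11, i.e. num_layers=12).
    for i in range(n_fine_tune_layers):
        patterns.append(f"encoder/layer_{num_layers - 1 - i}")
    # Hypothetical hook for scope names that are not BERT-style.
    patterns.extend(extra_patterns)
    # Substring match, as in the original list comprehension.
    return [n for n in names if any(p in n for p in patterns)]


# BERT-style names use the encoder/layer_N scopes that BertLayer relies on.
bert_names = [
    "bert_module/bert/embeddings/word_embeddings:0",
    "bert_module/bert/encoder/layer_0/output/dense/kernel:0",
    "bert_module/bert/encoder/layer_11/attention/self/query/kernel:0",
    "bert_module/bert/pooler/dense/kernel:0",
    "bert_module/cls/predictions/output_bias:0",
]

# ASSUMPTION: ALBERT-style names with one shared transformer block instead of
# per-layer scopes. Verify against [v.name for v in self.bert.variables].
albert_names = [
    "albert_module/bert/embeddings/word_embeddings:0",
    "albert_module/bert/encoder/embedding_hidden_mapping_in/kernel:0",
    "albert_module/bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel:0",
    "albert_module/bert/pooler/dense/kernel:0",
]

if __name__ == "__main__":
    # Layer 11 and the pooler are selected; layer 0 is below the 10-layer cutoff.
    print(select_trainable(bert_names))
    # Only the pooler matches: no ALBERT-style name contains "encoder/layer_".
    print(select_trainable(albert_names))
    # Mean pooling also drops the pooler, so nothing is selected.
    print(select_trainable(albert_names, pooling="mean"))
    # An extra pattern for the shared block selects it as well.
    print(select_trainable(albert_names, extra_patterns=["encoder/transformer/"]))
```

If the real ALBERT names do resemble the assumed ones, the layer would load and run after a URL swap but would fine-tune only the pooler with pooling="first", and nothing at all with pooling="mean". With shared weights, "how many layers to fine-tune" also stops being meaningful: the shared block is either trainable or not, so matching it with one pattern (passed here through the hypothetical extra_patterns) is one possible adjustment. The hard-coded 11 - i and output_size = 768 likewise assume a 12-layer, 768-wide base model.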