微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

注意Tensorflow的Bert双向LSTM

如何解决注意Tensorflow的Bert双向LSTM

我目前正在尝试从本文(https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9206937&casa_token=UW7CpTT3gc0AAAAA:ErrINfc0Coy4CKJKS2IX9GpnOHiI-k6kAKkHZIrFYwkw3CMkWVHpVE0JQzExpVStErF9_V1GJ3245Q&tag=1)复制以下分类模块。

本文对网络的描述如下:BERT与 双向LSTM,关注层和密集线性层。

当我尝试实现其他转换器时,我已经有了BERT(和其他)序列的表示。这意味着我只需要将序列的矢量表示(由Bert或其他转换器生成)输入Bidir。具有关注层的lstm。但是,我无法使其正常工作。

该论文的作者分享了他的代码,我想知道如何将Pytorch代码转换为TensorFlow模型。如您所见,他使用Bert对我已经拥有的序列进行矢量表示。因此,我仅对带有关注层的bidir.lstm感兴趣。


class Attention(nn.Module):
    def __init__(self,attention_size):
        super(Attention,self).__init__()
        self.attention = new_parameter(attention_size,1)
    def forward(self,x_in):
        # after this,we have (batch,dim1) with a diff weight per each cell
        attention_score = torch.matmul(x_in,self.attention).squeeze()
        attention_score = F.softmax(attention_score).view(x_in.size(0),x_in.size(1),1)
        scored_x = x_in * attention_score
        # Now,sum across dim 1 to get the expected feature vector
        condensed_x = torch.sum(scored_x,dim=1)
        return condensed_x
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.lstm=nn.LSTM(input_size=768,hidden_size=384,num_layers=2,dropout=.5,bidirectional=True)
        self.attention=Attention(768)
        self.classifier=nn.Linear(768,2)
    def forward(self,input,attention_mask):
        _,x=self.bert(input,attention_mask=attention_mask)
        x,(h,c)=self.lstm(x.unsqueeze(0))
        x=self.attention(x.view(x.shape[1],1,768))
        x=self.classifier(x)
        return x*

我的当前模型如下所示,我将768的Bert表示输入到Bidir中。 LSTM。现在,我需要以某种方式添加注意力机制。有人知道如何在Tensorflow中做到这一点吗?

import tensorflow as tf
from tensorflow.keras.layers import Bidirectional,LSTM,Dense,Dropout,Attention
model = tf.keras.Sequential()
model.add(Bidirectional(LSTM(384,return_sequences=True),input_shape=(1,768)))
model.add(Dropout(0.5))
model.add(Dense(1,activation="sigmoid"))

model.summary()

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。