微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在Tensorflow 2.x中使用波束搜索实现seq2seq模型

如何解决在Tensorflow 2.x中使用波束搜索实现seq2seq模型

我想使用波束搜索创建seq2seq模型。我遇到的问题是找不到关于如何使用TensorFlow在我的解码器中实现波束搜索的指南。

背景

通过使用seq2seq模型,可以将手写的数学公式转换为乳胶代码。我遵循了this教程,但是它是使用Tensorflow的旧版本编写的,我并不完全熟悉。

我知道这个问题与图像字幕问题非常相似。因此,我认为一个不错的起点是实现从this Tensorflow tutorial到图像到乳胶问题的相同管道。

当前方法

我目前的做法是

class Decoder(tf.keras.Model):
    def __init__(self,embedding_dim,units,vocab_size):
        super(Decoder,self). __init__()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(vocab_size,embedding_dim)
        self.gru = tf.keras.layers.GRU(self.units,return_sequences=True,return_state=True,recurrent_initializer='glorot_uniform')
        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.units)

    def call(self,x,features,hidden):
        # defining attention as a separate model
        context_vector,attention_weights = self.attention(features,hidden)

        # x shape after passing through embedding == (batch_size,1,embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size,embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector,1),x],axis=-1)

        # passing the concatenated vector to the GRU
        output,state = self.gru(x)

        # shape == (batch_size,max_length,hidden_size)
        x = self.fc1(output)

        # x shape == (batch_size * max_length,hidden_size)
        x = tf.reshape(x,(-1,x.shape[2]))

        # output shape == (batch_size * max_length,vocab)
        x = self.fc2(x)

        return x,state,attention_weights

    def reset_state(self,batch_size):
        return tf.zeros((batch_size,self.units))

感谢您的帮助!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。