微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow LSTM 返回什么?

如何解决Tensorflow LSTM 返回什么?

我正在使用编码器/解码器模式编写德语->英语翻译器,其中编码器通过传递其最后一个 LSTM 层的状态输出连接到解码器 作为解码器的 LSTM 的输入状态。

不过,我被卡住了,因为我不知道如何解释编码器的 LSTM 的输出一个小例子:

tensor = tf.random.normal( shape = [ 2,2,2 ])
lstm = tf.keras.layers.LSTM(units=4,return_sequences=True,return_state=True )
result = lstm( ( tensor )
print( "result:\n",result )

在 Tensorflow 2.0.0 中执行这个会产生:

result:
[
<tf.Tensor: id=6423,shape=(2,3),dtype=float32,numpy=
array([[[ 0.05060377,-0.00500009,-0.10052835],[ 0.01804499,0.0022153,0.01820258]],[[ 0.00813384,-0.08705016,0.06510869],[-0.00241707,-0.05084776,0.08321179]]],dtype=float32)>,<tf.Tensor: id=6410,numpy=
array([[ 0.01804499,0.01820258],0.08321179]],<tf.Tensor: id=6407,numpy=
array([[ 0.04316794,0.00382055,0.04829971],[-0.00499733,-0.10105743,0.1755833 ]],dtype=float32)>
]

结果是三个张量的列表。第一个似乎是所有的输出 隐藏状态,由 return_sequences=True 选择。我的问题是: result 中的第二个和第三个张量的解释是什么?

解决方法

Keras 中的 LSTM 单元为您提供三个输出:

  • 一个输出状态 o_t(第一个输出)
  • 隐藏状态 h_t(第二个输出)
  • 细胞状态 c_t(第三个输出)

你可以在这里看到一个 LSTM 单元: LSTM input/output diagram

输出状态通常传递给任何上层,但不会传递给右侧的任何层。您将在预测最终输出时使用此状态。

单元状态是从之前的 LSTM 单元传输到当前 LSTM 单元的信息。当它到达 LSTM 单元时,单元决定是否应该删除来自单元状态的信息,即我们将“忘记”某些状态。这是由一个遗忘门完成的:这个门将当前特征 x_t 作为输入和来自前一个单元 h_{t-1} 的隐藏状态。它输出一个概率向量,我们将其与最后一个细胞状态 c_{t-1} 相乘。在确定我们想要忘记哪些信息后,我们用输入门更新细胞状态。该门将当前特征 x_t 作为输入和来自前一个单元格 h_{t-1} 的隐藏状态,并产生一个输入,该输入被添加到最后一个单元格状态(我们已经忘记了信息)。这个总和就是新的细胞状态 c_t。 为了获得新的隐藏状态,我们将单元状态与隐藏状态向量结合起来,隐藏状态向量也是一个概率向量,决定了来自单元状态的哪些信息应该保留,哪些应该丢弃。

正如您正确解释的那样,第一个张量是所有隐藏状态的输出。

第二个张量是隐藏输出,即$h_t$,它充当神经网络的短期记忆 第三个张量是细胞输出,即$c_t$,作为神经网络的长时记忆

keras-documentation中写道

whole_seq_output,final_memory_state,final_carry_state = lstm(inputs)

不幸的是,他们不使用术语隐藏和细胞状态。在他们的术语中,记忆状态是短期记忆,即隐藏状态。进位状态通过所有 LSTM 单元进行,即单元状态。

我们也可以使用 source code of the LSTM cell 来验证这一点,其中向前的步骤由

给出
def step(cell_inputs,cell_states):
    """Step function that will be used by Keras RNN backend."""
    h_tm1 = cell_states[0]   #previous memory state
    c_tm1 = cell_states[2]   #previous carry state

    z = backend.dot(cell_inputs,kernel)
    z += backend.dot(h_tm1,recurrent_kernel)
    z = backend.bias_add(z,bias)

    z0,z1,z2,z3 = array_ops.split(z,4,axis=1)

    i = nn.sigmoid(z0)
    f = nn.sigmoid(z1)
    c = f * c_tm1 + i * nn.tanh(z2)
    o = nn.sigmoid(z3)

    h = o * nn.tanh(c)
    return h,[h,c]

从公式中我们可以很容易地看出,第一个和第二个输出是输出/隐藏状态,第三个输出是细胞状态。并且还声明他们将隐藏状态命名为“记忆状态”,将细胞状态命名为“携带状态”

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?