微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么Keras不能在lstm层返回单元状态的完整序列?

如何解决为什么Keras不能在lstm层返回单元状态的完整序列?

我正在尝试实现一种关注机制,其中我需要单元格状态的完整序列(就像隐藏状态的完整序列一样)。 Keras LSTM但是仅返回最后一个单元格状态:

output,state_h,state_c = layers.LSTM(units=45,return_state=True,return_sequences=True)

state_c的形状为(batch size,1,45),其中输出(全序列隐藏状态)的形状为(batch size,5,45)。 5是时间窗口长度

为什么Keras不返回全序列细胞状态?与下面的方法相比,有没有更好的方法获取完整的细胞状态序列?

full_hidden,full_cell,outputs = [],[],[]
state = None
input = layers.Input(shape=(time_window,features),dtype='float32')
output = layers.LSTM(units=45,return_state=True)

for i in range(time_window):
    input_t = input[:,i,:]
    input_t = tf.expand_dims(input_t,1)
    out,state_c = lstm(input_t,initial_state=state)
    state = state_h,state_c
    full_hidden.append(state_h)
    full_cell.append(state_c)
    outputs.append(out)

解决方法

您需要将标志return_sequences设置为True才能获取所有时间状态。您使用的标志return_state=True使图层返回最终状态。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。