
LSTM-based EEG signal classification: implementing the "Channel LSTM" architecture

I have a multi-class classification problem, and I am using Keras and TensorFlow with Python 3.6. I have a working classification implementation based on the "stacked LSTM layer (a)" mentioned in this paper: Deep Learning Human Mind for Automated Visual Classification

Something like this:

model.add(LSTM(256,input_shape=(32,15360),return_sequences=True))
model.add(LSTM(128,return_sequences=True))
model.add(LSTM(64,return_sequences=False))

model.add(Dense(6,activation='softmax'))

Here 32 is the number of EEG channels and 15360 is the signal length of a 96-second recording sampled at 160 Hz.
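The signal length is just the sampling rate times the recording duration; a quick check of the numbers stated above:

```python
sampling_rate_hz = 160   # EEG sampling rate
duration_s = 96          # length of each recording in seconds
channels = 32            # number of EEG electrodes

timesteps = sampling_rate_hz * duration_s
print(timesteps)  # 15360
```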

I would like to implement the "Channel LSTM and Common LSTM (b)" strategy mentioned in the paper above, but I don't know how I should build my model with this new strategy.

Please help me. Thanks.


Solution

First of all, there is a problem in your implementation of the encoder using Common LSTM: by default, the LSTM layer of Keras accepts input of shape (batch, timesteps, channels). So if you set input_shape=(32,15360), the model will read it as timesteps=32 and channels=15360, which is the opposite of your intention.
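To see that convention concretely, here is a minimal sketch (toy sizes, not the EEG dimensions) showing that Keras reads the second axis as time steps and the last axis as channels:

```python
import numpy as np
from tensorflow.keras import layers

# Keras LSTM convention: input is (batch, timesteps, channels).
x = np.random.rand(2, 10, 3).astype("float32")  # 2 samples, 10 time steps, 3 channels
y = layers.LSTM(4, return_sequences=True)(x)

print(y.shape)  # (2, 10, 4): one 4-dimensional output per time step
```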

This is because the first encoding layer using Common LSTM is described as:

At each time step t, the first layer takes the input s(·, t) (in this sense, "common" means that all EEG channels are initially fed into the same LSTM layer).

So the correct implementation of the encoder using Common LSTM is:

import tensorflow as tf
from tensorflow.keras import layers,models

timesteps = 15360
channels_num = 32

model = models.Sequential()
model.add(layers.LSTM(256,input_shape=(timesteps,channels_num),return_sequences=True))
model.add(layers.LSTM(128,return_sequences=True))
model.add(layers.LSTM(64,return_sequences=False))
model.add(layers.Dense(6,activation='softmax'))

model.summary()

Which outputs (PS: if you sum up the parameters of your original implementation, you will see that its total is much larger):

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm (LSTM)                  (None,15360,256)          295936
_________________________________________________________________
lstm_1 (LSTM)                (None,15360,128)          197120
_________________________________________________________________
lstm_2 (LSTM)                (None,64)                 49408
_________________________________________________________________
dense (Dense)                (None,6)                  390
=================================================================
Total params: 542,854
Trainable params: 542,854
Non-trainable params: 0
_________________________________________________________________

Second, since the encoder using Channel LSTM and Common LSTM is described as:

The first encoding layer consists of several LSTMs, each connected to only one input channel: for example, the first LSTM processes input data s(1, ·), the second LSTM processes s(2, ·), and so on. This way, the output of each "channel LSTM" is a summary of a single channel's data. The second encoding layer then performs inter-channel analysis, by receiving as input the concatenated output vectors of all channel LSTMs. As above, the output of the deepest LSTM at the last time step is used as the encoder's output vector.

Since each LSTM in the first layer processes only one channel, we need as many LSTMs in the first layer as there are channels. The following code shows how to build an encoder using Channel LSTM and Common LSTM:

import tensorflow as tf
from tensorflow.keras import layers,models

timesteps = 15360
channels_num = 32

first_layer_inputs = []
second_layer_inputs = []
for i in range(channels_num):
    l_input = layers.Input(shape=(timesteps,1))
    first_layer_inputs.append(l_input)
    l_output = layers.LSTM(1,return_sequences=True)(l_input)
    second_layer_inputs.append(l_output)

x = layers.Concatenate()(second_layer_inputs)
x = layers.LSTM(128,return_sequences=True)(x)
x = layers.LSTM(64,return_sequences=False)(x)
outputs = layers.Dense(6,activation='softmax')(x)

model = models.Model(inputs=first_layer_inputs,outputs=outputs)

model.summary()

Which outputs:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None,15360,1)]      0
__________________________________________________________________________________________________
... (input_2 through input_32 are identical InputLayers, one per channel)
__________________________________________________________________________________________________
lstm (LSTM)                     (None,15360,1)        12          input_1[0][0]
__________________________________________________________________________________________________
... (lstm_1 through lstm_31 are identical channel LSTMs, one per input)
__________________________________________________________________________________________________
concatenate (Concatenate)       (None,15360,32)       0           lstm[0][0] ... lstm_31[0][0]
__________________________________________________________________________________________________
lstm_32 (LSTM)                  (None,15360,128)      82432       concatenate[0][0]
__________________________________________________________________________________________________
lstm_33 (LSTM)                  (None,64)             49408       lstm_32[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None,6)              390         lstm_33[0][0]
==================================================================================================
Total params: 132,614
Trainable params: 132,614
Non-trainable params: 0
__________________________________________________________________________________________________

Now, because the model expects its input as 32 per-channel arrays, i.e. shape (channels, batch, timesteps, 1), we have to reorder the axes of the dataset before feeding it to the model. The following example code shows how to reorder the axes from (batch, timesteps, channels) to (channels, batch, timesteps, 1):

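A minimal NumPy sketch of this reordering (the batch size of 4 here is an arbitrary placeholder):

```python
import numpy as np

batch_size = 4        # arbitrary placeholder batch size
timesteps = 15360
channels_num = 32

# Dataset as recorded: (batch, timesteps, channels)
dataset = np.random.rand(batch_size, timesteps, channels_num).astype("float32")

# Move the channel axis to the front, then add a trailing feature axis of size 1.
dataset = np.moveaxis(dataset, -1, 0)       # -> (channels, batch, timesteps)
dataset = np.expand_dims(dataset, axis=-1)  # -> (channels, batch, timesteps, 1)

print(dataset.shape)  # (32, 4, 15360, 1)
```

The Channel LSTM model above then takes the 32 per-channel arrays as a list, e.g. model.fit(list(dataset), labels).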
