微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

递归网络PyTorch LSTM

如何解决递归网络PyTorch LSTM

我正在尝试在PyTorch中堆叠的LSTM网络的各层之间进行规范化。网络看起来像这样:

class LSTMClassifier(nn.Module):
    def __init__(self,input_dim,hidden_dim,layer_dim,output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm1 = nn.LSTM(input_dim,batch_first=True)
        self.lstm2 = nn.LSTM(input_dim,batch_first=True)
        self.fc1 = nn.Linear(hidden_dim,32)
        self.fc2 = nn.Linear(32,1)
        self.dropout = nn.Dropout(p=0.2)
        self.batch_normalisation1 = nn.Batchnorm1d(hidden_dim)
        self.batch_normalisation2 = nn.Batchnorm1d(hidden_dim)

    def forward(self,x):
        h0,c0 = self.init_hidden(x)
        out,(hn1,cn1) = self.lstm1(x,(h0,c0))
        out = self.dropout(out)                     # error line
        out = self.batch_normalisation1(out)
        
        h1,c1 = self.init_hidden(out)
        out,(hn2,cn2) = self.lstm2(out,(h1,c1))
        out = self.dropout(out)
        out = self.batch_normalisation1(out)
        
        h2,c2 = self.init_hidden(out)
        out,(hn3,cn3) = self.lstm2(out,(h2,c2))
        out = self.dropout(out)
        out = self.batch_normalisation1(out)
        
        out = self.fc1(out[:,-1,:])
        out = self.dropout(out)
        out = self.fc2(out)
        return out
    
    def init_hidden(self,x):
        h0 = torch.zeros(self.layer_dim,x.size(0),self.hidden_dim)
        c0 = torch.zeros(self.layer_dim,self.hidden_dim)
        return [t for t in (h0,c0)]

在上面我提到的地方出现了一个错误,这是由于Batchnorm1d期望二维输入。

特别是,我正在初始化模型并将批处理的数据传递给网络,如下所示:

model = LSTMClassifier(5,128,3,1)
model(X)

错误RuntimeError: running_mean should contain 3 elements not 128

X输入张量的形状为 torch.Size([10,5]),即批量大小为10,每个输入的尺寸为3 x 5,即5个特征和3个时间步长。

由于Batchnorm1d试图在错误的维度上进行标准化而导致错误-在网络中,变量out的形状为torch.Size([1,128]),即5个输入要素被映射为128个超变量。

我可以重塑转发函数中的变量,但这似乎是不必要的。我也尝试过使用Batchnorm2d,但是它需要一个4d张量,而我的变量不是。有什么方法可以解决这个问题吗?

此外,我正在尝试在我的网络中添加规范化以加快培训速度-我不确定PyTorch Batchnorm函数的工作方式,因此不胜感激。具体来说,为什么我们要对时间维度而非特征维度进行归一化?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。