
Custom stacked LSTM output differs from nn.LSTM

I implemented a multi-layer LSTM, but when init_state is not None the results differ from nn.LSTM's. I loaded the same weights into both my custom model and PyTorch's nn.LSTM, so I suspect I am doing something wrong in the forward function. Any help would be greatly appreciated. Thanks a lot!


import torch
import torch.nn as nn
from typing import Optional, Tuple


class StackedLSTMs(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_sz = hidden_sz
        self.LSTMs = nn.ModuleList()
        for layer in range(num_layers):
            # the first layer consumes the input; later layers consume
            # the hidden state of the layer below
            in_sz = input_sz if layer == 0 else hidden_sz
            self.LSTMs.append(nn.LSTMCell(in_sz, hidden_sz))


    def forward(self, x, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        seq_size, bs, _ = x.size()
        outputs = []
        if h is None:
            # one (h, c) pair per layer; each slice is (batch, hidden)
            hn = torch.zeros(self.num_layers, bs, self.hidden_sz)
            cn = torch.zeros(self.num_layers, bs, self.hidden_sz)
        else:
            hn, cn = h

        for t in range(seq_size):
            for layer, lstm in enumerate(self.LSTMs):
                # layer 0 reads the input at time t; deeper layers read the
                # freshly updated hidden state of the layer below
                inp = x[t] if layer == 0 else hn[layer - 1]
                hn[layer], cn[layer] = lstm(inp, (hn[layer], cn[layer]))
            # clone, because hn is overwritten in place on later timesteps
            # (detach() here would silently break backprop)
            outputs.append(hn[self.num_layers - 1].clone())
        outputs = torch.stack(outputs, dim=0)
        return outputs, (hn, cn)
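For the single-layer case it is easy to check the per-timestep update pattern in isolation. The sketch below (using only documented `nn.LSTM` / `nn.LSTMCell` behavior; the sizes are made up for the example) shares weights between one `nn.LSTMCell` and a one-layer `nn.LSTM` and confirms the unrolled loop reproduces the reference output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq, bs, hid = 3, 2, 4

lstm = nn.LSTM(hid, hid, num_layers=1)
cell = nn.LSTMCell(hid, hid)

# share the four parameter tensors between the two modules
cell.weight_ih = lstm.weight_ih_l0
cell.weight_hh = lstm.weight_hh_l0
cell.bias_ih = lstm.bias_ih_l0
cell.bias_hh = lstm.bias_hh_l0

x = torch.randn(seq, bs, hid)
h = torch.zeros(bs, hid)
c = torch.zeros(bs, hid)
outs = []
for t in range(seq):
    h, c = cell(x[t], (h, c))   # one timestep per call
    outs.append(h)
outs = torch.stack(outs, dim=0)

ref, _ = lstm(x)                # zero-initialized state, like the loop above
print(torch.allclose(outs, ref, atol=1e-6))  # True
```

Once this passes, any remaining mismatch in the stacked version must come from how the layers are wired or how the initial state is threaded through.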


torch.manual_seed(999)
lstms = nn.LSTM(320,320,2)
stackedlstms = StackedLSTMs(320, 320, 2)

stackedlstms.LSTMs[0].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
stackedlstms.LSTMs[0].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
stackedlstms.LSTMs[0].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
stackedlstms.LSTMs[0].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_hh_l0

stackedlstms.LSTMs[1].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
stackedlstms.LSTMs[1].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
stackedlstms.LSTMs[1].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
stackedlstms.LSTMs[1].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_hh_l1

lstms.weight_ih_l0 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
lstms.weight_hh_l0 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
lstms.bias_ih_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
lstms.bias_hh_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0

lstms.weight_ih_l1 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
lstms.weight_hh_l1 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
lstms.bias_ih_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
lstms.bias_hh_l1 = oldmodel.prediction.dec_rnn.lstm.bias_hh_l1
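Assigning each of the eight tensors by hand invites exactly the kind of copy-paste slip where a `bias_ih_l1` lands on a `bias_hh` attribute. A less error-prone alternative (a sketch, assuming the standard `nn.LSTM` / `nn.LSTMCell` parameter names) copies the per-layer parameters in a loop:

```python
import torch
import torch.nn as nn

num_layers, hid = 2, 320
lstm = nn.LSTM(hid, hid, num_layers)
cells = nn.ModuleList(nn.LSTMCell(hid, hid) for _ in range(num_layers))

# nn.LSTM names its parameters weight_ih_l{k}, weight_hh_l{k},
# bias_ih_l{k}, bias_hh_l{k}; nn.LSTMCell uses the same names
# without the layer suffix, so the copy can be a simple loop.
with torch.no_grad():
    for layer, cell in enumerate(cells):
        for name in ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh'):
            getattr(cell, name).copy_(getattr(lstm, f'{name}_l{layer}'))

# spot-check one tensor per layer
print(torch.equal(cells[0].weight_ih, lstm.weight_ih_l0))  # True
print(torch.equal(cells[1].bias_hh, lstm.bias_hh_l1))      # True
```

`copy_` under `no_grad` keeps each module's own `Parameter` objects, whereas attribute assignment (as above) makes the two modules share the same tensors; either works for a forward-pass comparison.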

hidden = torch.load('hidden.pt')
newembedt = torch.load('newembed_t.pt')

lstms_res = lstms(newembedt,hidden)
stackedlstms_res = stackedlstms(newembedt,hidden)

print(torch.sum(torch.abs(lstms_res[0] - stackedlstms_res[0])))
print(torch.sum(torch.abs(lstms_res[1][0] - stackedlstms_res[1][0])))
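As a side note, for a pass/fail equivalence check `torch.allclose` is usually easier to read than a summed absolute difference, because it builds in explicit tolerances for floating-point noise. A minimal illustration with made-up tensors:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7  # tiny numerical noise, e.g. from a different op ordering

# boolean verdict with explicit tolerances
print(torch.allclose(a, b, rtol=1e-5, atol=1e-6))  # True

# raw magnitude: nonzero, and you must judge whether it is "small enough"
print(torch.sum(torch.abs(a - b)))
```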

