如何解决自定义堆叠 LSTM 的输出与 nn.LSTM 不同
我实现了多层 LSTM,但是如果 init_state 不是 None,结果与 nn.LSTM 不同。我将 LSTM 模型中的权重加载到我的自定义模型和 pytorch 的 nn.LSTM 模型中。我怀疑我可能在前向功能上做错了什么。任何帮助将不胜感激。非常感谢!
class StackedLSTMs(nn.Module):
def __init__(self,input_sz:int,hidden_sz: int,num_layers: int):
super().__init__()
self.num_layers = num_layers
self.hidden_sz = hidden_sz
self.LSTMs = nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.LSTMs.append(nn.LSTMCell(input_sz,hidden_sz))
#self.LSTMs.append(NaiveCustomLSTMCell(input_sz,hidden_sz))
else:
self.LSTMs.append(nn.LSTMCell(hidden_sz,hidden_sz))
#self.LSTMs.append(NaiveCustomLSTMCell(hidden_sz,hidden_sz))
def forward(self,x,h: Optional[Tuple[torch.Tensor,torch.Tensor]] = None):
print('hidden',h)
seq_size,bs,_ = x.size()
outputs = []
if h is None:
hn = torch.zeros(self.num_layers,self.hidden_sz)
cn = torch.zeros(self.num_layers,self.hidden_sz)
else:
(hn,cn) = h
for t in range(seq_size):
for layer,lstm in enumerate(self.LSTMs):
if layer == 0:
hn[layer,:,:],cn[layer,:] = lstm(x[t,(hn[layer,:]))
else:
hn[layer,:] = lstm(hn[layer-1,:]))
temp = hn[self.num_layers - 1,:].detach().clone()
outputs.append(temp)
outputs = torch.stack(outputs,dim=0)
h = (hn,cn)
#outputs = outputs.transpose(0,1).contiguous()
return outputs,h
torch.manual_seed(999)
lstms = nn.LSTM(320,320,2)
stackedlstms = StackedLSTMs(320,2)
stackedlstms.LSTMs[0].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
stackedlstms.LSTMs[0].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
stackedlstms.LSTMs[0].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
stackedlstms.LSTMs[0].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_hh_l0
stackedlstms.LSTMs[1].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
stackedlstms.LSTMs[1].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
stackedlstms.LSTMs[1].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
stackedlstms.LSTMs[1].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
lstms.weight_ih_l0 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
lstms.weight_hh_l0 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
lstms.bias_ih_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
lstms.bias_hh_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
lstms.weight_ih_l1 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
lstms.weight_hh_l1 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
lstms.bias_ih_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
lstms.bias_hh_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
hidden = torch.load('hidden.pt')
newembedt = torch.load('newembed_t.pt')
lstms_res = lstms(newembedt,hidden)
stackedlstms_res = stackedlstms(newembedt,hidden)
print(torch.sum(abs(lstms_res[0]-stackedlstms_res[0])))
print(torch.sum(abs(lstms_res[1][0]-stackedlstms_res[1][0])))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。