如何解决使用 PyTorch-ligtning 训练 RNN 时出现 AssertionError
我是 PyTorch 的新手,所以我使用 PyTorch-Lightning 来训练简单的(Vanilla)RNN:
1.数据准备
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import numpy as np
import pandas as pd
#...
#X_train,Y_train are np arrs with shape (n,t,d)
X_train_tensors = torch.Tensor(X_train).to(device)
Y_train_tensors = torch.Tensor(Y_train).to(device)
#create train dataset
train = TensorDataset(X_train_tensors,Y_train_tensors)
trainloader = DataLoader(train,batch_size=32,shuffle=False)
2.创建Learner类
#use pl to create learner
class Learner(pl.LightningModule):
def __init__(self,model:nn.Module):
super().__init__()
self.model = model
def training_step(self,batch,batch_idx):
x,y = batch
y_hat = self.model(x)
loss = nn.MSELoss()(y_hat,y)
logs = {'train_loss': loss}
return {'loss': loss,'log': logs}
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(),lr=0.001)
3.创建模型并使用训练器
NN = nn.Sequential(
nn.RNN(1,3,nonlinearity='tanh',batch_first=True),nn.RNN(3,5,nn.Linear(5,1)
)
model = Learner(NN)
trainer = pl.Trainer(max_epochs=100,weights_summary='full')
trainer.fit(model,train_dataloader=trainloader)
我有这个断言错误:
AssertionError Traceback (most recent call last)
<ipython-input-29-781e293c05ed> in <module>()
10 model = Learner(NN)
11 trainer = pl.Trainer(max_epochs=100,weights_summary='full')
---> 12 trainer.fit(model,train_dataloader=trainloader)
13 #only when called it uses the test loop
14 trainer.test(model,testloader)
16 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in forward(self,input,hx)
242 max_batch_size = int(batch_sizes[0])
243 else:
--> 244 assert isinstance(input,Tensor)
245 batch_sizes = None
246 max_batch_size = input.size(0) if self.batch_first else input.size(1)
当我检查 github https://github.com/pytorch/pytorch/blob/d09abf004cc16f8fd5f320e3d5d07c383c174ea7/torch/nn/modules/rnn.py#L247 中的 baseRNN 时,我没有发现 assert!
你能帮忙吗?
解决方法
昨天为这个错误添加了一个修复程序(这就是为什么你没有在 github 代码中看到它)。在此处查看相关问题的 PR:
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。