How to fix the error raised by the num_workers argument in PyTorch's DataLoader
I want to train my model on the GPU, but when I try to specify the number of worker processes in PyTorch's DataLoader I get an error. Here is what I do:
import torch
from sklearn import model_selection

X_train, X_test, y_train, y_test = model_selection.train_test_split(X_reduced, y, test_size=0.4)
X_test, X_val, y_test, y_val = model_selection.train_test_split(X_test, y_test, test_size=0.5)
dtype = config['dtype']
train_dataset = torch.utils.data.TensorDataset(
    torch.tensor(X_train, dtype=dtype, requires_grad=True),
    torch.tensor(y_train.values, requires_grad=True))
valid_dataset = torch.utils.data.TensorDataset(
    torch.tensor(X_val, dtype=dtype, requires_grad=True),
    torch.tensor(y_val.values, requires_grad=True))
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=config['batch_size'], num_workers=8)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, num_workers=8)
I am also using PyTorch Lightning:
import pytorch_lightning as pl
import torch

class PostflopNet(pl.LightningModule):
    def __init__(self, n_inputs):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(n_inputs, 512), torch.nn.ReLU(),
            torch.nn.Linear(512, 128), torch.nn.Linear(128, 32),
            torch.nn.Linear(32, 8), torch.nn.Linear(8, 1))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1.0e-2)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.6)
        return [optimizer], [scheduler]

    def training_step(self, train_batch, batch_idx):
        inputs, labels = train_batch
        predicted = self(inputs)
        # MSE between the predictions and the labels
        loss = torch.nn.MSELoss()(predicted, labels)
        print('train_loss: ', loss)
        print(inputs.shape)
        return loss

model = PostflopNet(X_train.shape[1])
trainer = pl.Trainer(max_epochs=config['n_train_epoch'], gpus=1)
trainer.fit(model, train_dataloader, valid_dataloader)
But I get this error:

File "/home/neighbourhood@netsrv.pw/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, out=out)
RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
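Digging into torch/utils/data/_utils/collate.py, the torch.stack(batch, out=out) call seems to run only inside worker processes, where default_collate stacks the samples into a pre-allocated shared-memory tensor via the out= argument, and that code path rejects tensors that require grad. A minimal sketch of what I think is happening, outside any DataLoader (the shapes are made up):

import torch

# samples that carry requires_grad=True, like the ones my TensorDataset produces
batch = [torch.randn(4, requires_grad=True) for _ in range(8)]

torch.stack(batch, 0)  # fine: this is the path taken with num_workers=0

out = torch.empty(8, 4)
torch.stack(batch, 0, out=out)  # RuntimeError: stack(): functions with out=... arguments ...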
If I remove the num_workers argument from the DataLoader, it trains successfully.
What is going wrong?
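In case it matters: building the dataset tensors without requires_grad=True also avoids the error, but I am not sure whether that is the right fix. A sketch of that workaround, assuming the inputs and labels do not actually need gradients of their own (the model parameters are what the optimizer updates):

# workaround sketch: leave requires_grad off the dataset tensors
train_dataset = torch.utils.data.TensorDataset(
    torch.tensor(X_train, dtype=dtype),        # no requires_grad=True
    torch.tensor(y_train.values, dtype=dtype))
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True,
                                               batch_size=config['batch_size'],
                                               num_workers=8)  # should no longer trip the out= check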