如何解决Pytorch 闪电数据模块覆盖警告:方法“.setup()”的签名与“LightningDataModule”类中的基本方法的签名不匹配
以下是一个有效的 Pytorch Lightning DataModule。
import os
from pytorch_lightning import LightningDataModule
import torchvision.datasets as datasets
from torchvision.transforms import transforms
import torch
from torch.utils.data import DataLoader
from Testing.Research.config.paths import mnist_data_download_folder
class PressureDataModule(LightningDataModule):
def __init__(self,config):
super().__init__()
self._config = config
def prepare_data(self):
pass
def setup(self,stage):
# transform
transform = transforms.Compose([transforms.ToTensor()])
mnist_train_full = datasets.MNIST(mnist_data_download_folder,train=True,download=False,transform=self._transforms)
mnist_test = datasets.MNIST(mnist_data_download_folder,train=False,transform=self._transforms)
# train/val split
train_size = int(self._config.train_size /
(self._config.train_size + self._config.val_size) * len(mnist_train_full))
val_size = len(mnist_train_full) - train_size
mnist_train,mnist_val = torch.utils.data.random_split(mnist_train_full,[train_size,val_size])
# assign to use in dataloaders
self._train_dataset = mnist_train
self._val_dataset = mnist_val
self._test_dataset = mnist_test
def train_dataloader(self):
return DataLoader(self._train_dataset,batch_size=self._config.batch_size,num_workers=self._config.num_workers)
def val_dataloader(self):
return DataLoader(self._val_dataset,num_workers=self._config.num_workers)
def test_dataloader(self):
return DataLoader(self._test_dataset,num_workers=self._config.num_workers)
Pycharm 不喜欢 setup
与
方法'PressureDataModule.setup()'的签名不匹配 'LightningDataModule' 类中基方法的签名
- 如果没有匹配项,为什么 Pycharm 会哭?
- 是不是因为参数不同?正确的参数数量是多少?
解决这个问题的正确方法是什么?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。