如何解决Pytorch GAN 模型未训练:矩阵乘法错误
我正在尝试构建一个基本的 GAN 以熟悉 Pytorch。我对 Keras 有一些(有限的)经验,但由于我必须在 Pytorch 中做一个更大的项目,我想首先使用“基本”网络进行探索。
我正在使用 Pytorch Lightning。我想我已经添加了所有必要的组件。我尝试分别通过生成器和鉴别器传递一些噪声,我认为输出具有预期的形状。尽管如此,当我尝试训练 GAN 时出现运行时错误(完整回溯如下):
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)
我注意到 7 是批次的大小(通过打印出批次尺寸),即使我将 batch_size 指定为 64。除此之外,老实说,我不知道从哪里开始:错误回溯对我没有帮助。
很有可能,我犯了很多错误。但是,我希望你们中的一些人能够从代码中发现当前的错误,因为乘法错误似乎指向某个地方的维度问题。这是代码。
import os
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import Dataset,DataLoader,random_split
from torchvision.utils import make_grid
from torchvision.transforms import Resize,ToTensor,ToPILImage,Normalize
class DoppelDataset(Dataset):
"""
Dataset class for face data
"""
def __init__(self,face_dir: str,transform=None):
self.face_dir = face_dir
self.face_paths = os.listdir(face_dir)
self.transform = transform
def __len__(self):
return len(self.face_paths)
def __getitem__(self,idx):
if torch.is_tensor(idx):
idx = idx.tolist()
face_path = os.path.join(self.face_dir,self.face_paths[idx])
face = io.imread(face_path)
sample = {'image': face}
if self.transform:
sample = self.transform(sample['image'])
return sample
class DoppelDataModule(pl.LightningDataModule):
def __init__(self,data_dir='../data/faces',batch_size: int = 64,num_workers: int = 0):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.transforms = transforms.Compose([
ToTensor(),Resize(100),Normalize(mean=(123.26290927634774,95.90498110733365,86.03763122875182),std=(63.20679012922922,54.86211954409834,52.31266645797249))
])
def setup(self,stage=None):
# Initialize dataset
doppel_data = DoppelDataset(face_dir=self.data_dir,transform=self.transforms)
# Train/val/test split
n = len(doppel_data)
train_size = int(.8 * n)
val_size = int(.1 * n)
test_size = n - (train_size + val_size)
self.train_data,self.val_data,self.test_data = random_split(dataset=doppel_data,lengths=[train_size,val_size,test_size])
def train_dataloader(self) -> DataLoader:
return DataLoader(dataset=self.test_data,batch_size=self.batch_size,num_workers=self.num_workers)
def val_dataloader(self) -> DataLoader:
return DataLoader(dataset=self.val_data,num_workers=self.num_workers)
def test_dataloader(self) -> DataLoader:
return DataLoader(dataset=self.test_data,num_workers=self.num_workers)
class DoppelGenerator(nn.Sequential):
"""
Generator network that produces images based on latent vector
"""
def __init__(self,latent_dim: int):
super().__init__()
def block(in_channels: int,out_channels: int,padding: int = 1,stride: int = 2,bias=False):
return nn.Sequential(
nn.ConvTranspose2d(in_channels=in_channels,out_channels=out_channels,kernel_size=4,stride=stride,padding=padding,bias=bias),nn.BatchNorm2d(num_features=out_channels),nn.ReLU(True)
)
self.model = nn.Sequential(
block(latent_dim,512,padding=0,stride=1),block(512,256),block(256,128),block(128,64),block(64,32),nn.ConvTranspose2d(32,3,stride=2,padding=1,bias=False),nn.Tanh()
)
def forward(self,input):
return self.model(input)
class DoppelDiscriminator(nn.Sequential):
"""
Discriminator network that classifies images in two categories
"""
def __init__(self):
super().__init__()
def block(in_channels: int,out_channels: int):
return nn.Sequential(
nn.Conv2d(in_channels=in_channels,nn.LeakyReLU(0.2,inplace=True),)
self.model = nn.Sequential(
block(3,512),nn.Conv2d(512,1,stride=1,nn.Flatten(),nn.Linear(25,1),nn.Sigmoid()
)
def forward(self,input):
return self.model(input)
class DoppelGAN(pl.LightningModule):
def __init__(self,channels: int,width: int,height: int,lr: float = 0.0002,b1: float = 0.5,b2: float = 0.999,**kwargs):
super().__init__()
# Save all keyword arguments as hyperparameters,accessible through self.hparams.X)
self.save_hyperparameters()
# Initialize networks
# data_shape = (channels,width,height)
self.generator = DoppelGenerator(latent_dim=self.hparams.latent_dim,)
self.discriminator = DoppelDiscriminator()
self.validation_z = torch.randn(8,self.hparams.latent_dim,1)
def forward(self,input):
return self.generator(input)
def adversarial_loss(self,y_hat,y):
return F.binary_cross_entropy(y_hat,y)
def training_step(self,batch,batch_idx,optimizer_idx):
images = batch
# Sample noise (batch_size,latent_dim,1)
z = torch.randn(images.size(0),1)
# Train generator
if optimizer_idx == 0:
# Generate images (call generator -- see forward -- on latent vector)
self.generated_images = self(z)
# Log sampled images (visualize what the generator comes up with)
sample_images = self.generated_images[:6]
grid = make_grid(sample_images)
self.logger.experiment.add_image('generated_images',grid,0)
# Ground truth result (ie: all fake)
valid = torch.ones(images.size(0),1)
# Adversarial loss is binary cross-entropy
generator_loss = self.adversarial_loss(self.discriminator(self(z)),valid)
tqdm_dict = {'gen_loss': generator_loss}
output = {
'loss': generator_loss,'progress_bar': tqdm_dict,'log': tqdm_dict
}
return output
# Train discriminator: classify real from generated samples
if optimizer_idx == 1:
# How well can it label as real?
valid = torch.ones(images.size(0),1)
real_loss = self.adversarial_loss(self.discriminator(images),valid)
# How well can it label as fake?
fake = torch.zeros(images.size(0),1)
fake_loss = self.adversarial_loss(
self.discriminator(self(z).detach()),fake)
# Discriminator loss is the average of these
discriminator_loss = (real_loss + fake_loss) / 2
tqdm_dict = {'d_loss': discriminator_loss}
output = {
'loss': discriminator_loss,'log': tqdm_dict
}
return output
def configure_optimizers(self):
lr = self.hparams.lr
b1 = self.hparams.b1
b2 = self.hparams.b2
# Optimizers
opt_g = torch.optim.Adam(self.generator.parameters(),lr=lr,betas=(b1,b2))
opt_d = torch.optim.Adam(self.discriminator.parameters(),b2))
# Return optimizers/schedulers (currently no scheduler)
return [opt_g,opt_d],[]
def on_epoch_end(self):
# Log sampled images
sample_images = self(self.validation_z)
grid = make_grid(sample_images)
self.logger.experiment.add_image('generated_images',self.current_epoch)
if __name__ == '__main__':
# Global parameter
image_dim = 128
latent_dim = 100
batch_size = 64
# Initialize dataset
tfs = transforms.Compose([
ToPILImage(),Resize(image_dim),ToTensor()
])
doppel_dataset = DoppelDataset(face_dir='../data/faces',transform=tfs)
# Initialize data module
doppel_data_module = DoppelDataModule(batch_size=batch_size)
# Build models
generator = DoppelGenerator(latent_dim=latent_dim)
discriminator = DoppelDiscriminator()
# Test generator
x = torch.rand(batch_size,1)
y = generator(x)
print(f'Generator: x {x.size()} --> y {y.size()}')
# Test discriminator
x = torch.rand(batch_size,128,128)
y = discriminator(x)
print(f'Discriminator: x {x.size()} --> y {y.size()}')
# Build GAN
doppelgan = DoppelGAN(batch_size=batch_size,channels=3,width=image_dim,height=image_dim,latent_dim=latent_dim)
# Fit GAN
trainer = pl.Trainer(gpus=0,max_epochs=5,progress_bar_refresh_rate=1)
trainer.fit(model=doppelgan,datamodule=doppel_data_module)
完整回溯:
Traceback (most recent call last):
File "/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py",line 3437,in run_code
exec(code_obj,self.user_global_ns,self.user_ns)
File "<ipython-input-2-28805d67d74b>",line 1,in <module>
runfile('/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py',wdir='/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger')
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py",line 197,in runfile
pydev_imports.execfile(filename,global_vars,local_vars) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py",line 18,in execfile
exec(compile(contents+"\n",file,'exec'),glob,loc)
File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py",line 298,in <module>
trainer.fit(model=doppelgan,datamodule=doppel_data_module)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py",line 510,in fit
results = self.accelerator_backend.train()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py",line 57,in train
return self.train_or_test()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py",line 74,in train_or_test
results = self.trainer.train()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py",line 561,in train
self.train_loop.run_training_epoch()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 550,in run_training_epoch
batch_output = self.run_training_batch(batch,dataloader_idx)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 718,in run_training_batch
self.optimizer_step(optimizer,opt_idx,train_step_and_backward_closure)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 485,in optimizer_step
model_ref.optimizer_step(
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py",line 1298,in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py",line 286,in step
self.__optimizer_step(*args,closure=closure,profiler_name=profiler_name,**kwargs)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py",line 144,in __optimizer_step
optimizer.step(closure=closure,*args,**kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/autograd/grad_mode.py",line 26,in decorate_context
return func(*args,**kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/optim/adam.py",line 66,in step
loss = closure()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 708,in train_step_and_backward_closure
result = self.training_step_and_backward(
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 806,in training_step_and_backward
result = self.training_step(split_batch,hiddens)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 319,in training_step
training_step_output = self.trainer.accelerator_backend.training_step(args)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py",line 62,in training_step
return self._step(self.trainer.model.training_step,args)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py",line 58,in _step
output = model_step(*args)
File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py",line 223,in training_step
real_loss = self.adversarial_loss(self.discriminator(images),valid)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py",line 727,in _call_impl
result = self.forward(*input,**kwargs)
File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py",line 154,in forward
return self.model(input)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py",**kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/container.py",line 117,in forward
input = module(input)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py",**kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/linear.py",line 93,in forward
return F.linear(input,self.weight,self.bias)
File "/usr/local/lib/python3.9/site-packages/torch/nn/functional.py",line 1690,in linear
ret = torch.addmm(bias,input,weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)
解决方法
这个乘法问题来自DoppelDiscriminator
。有一个线性层
nn.Linear(25,1),
应该是
nn.Linear(9,
基于错误消息。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。