PyTorch GAN model doesn't train: matrix multiplication error

How do I fix a matrix multiplication error that keeps my PyTorch GAN model from training?

I'm trying to build a basic GAN to familiarize myself with PyTorch. I have some (limited) experience with Keras, but since I have to do a bigger project in PyTorch, I wanted to explore first using a "basic" network.

I'm using PyTorch Lightning. I think I've added all the necessary components. I tried passing some noise through the generator and the discriminator separately, and I believe the outputs have the expected shapes. Nonetheless, I get a runtime error when I try to train the GAN (full traceback below):

RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)

I noticed that 7 is the size of the batch (by printing out the batch dimensions), even though I specified batch_size to be 64. Other than that, I honestly don't know where to begin: the error traceback doesn't help me.
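For context on reading that message: `nn.Linear(in_features, out_features)` stores a weight of shape `(out_features, in_features)`, and `F.linear` computes `input @ weight.t()`, so the product is only defined when the inner dimensions match. A minimal shape check (the helper below is purely illustrative, not part of the code that follows):

```python
def matmul_ok(mat1, mat2):
    """An (n, k) @ (k, m) product is only defined when the inner dims agree."""
    (n, k1), (k2, m) = mat1, mat2
    return k1 == k2

# The traceback: a (7, 9) batch of flattened features hits nn.Linear(25, 1),
# whose transposed weight is (25, 1) -- inner dims 9 vs 25 don't match.
print(matmul_ok((7, 9), (25, 1)))  # False -> RuntimeError at runtime
print(matmul_ok((7, 9), (9, 1)))   # True -> what nn.Linear(9, 1) would give
```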

Chances are I've made plenty of mistakes, but I'm hoping some of you can spot the current one from the code, since the multiplication error seems to point to a dimension problem somewhere. Here is the code.

import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid
from torchvision.transforms import Resize, ToTensor, ToPILImage, Normalize

class DoppelDataset(Dataset):
    """
    Dataset class for face data
    """

    def __init__(self,face_dir: str,transform=None):

        self.face_dir = face_dir
        self.face_paths = os.listdir(face_dir)
        self.transform = transform

    def __len__(self):

        return len(self.face_paths)

    def __getitem__(self,idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        face_path = os.path.join(self.face_dir,self.face_paths[idx])
        face = io.imread(face_path)

        sample = {'image': face}

        if self.transform:
            sample = self.transform(sample['image'])

        return sample


class DoppelDataModule(pl.LightningDataModule):

    def __init__(self,data_dir='../data/faces',batch_size: int = 64,num_workers: int = 0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transforms = transforms.Compose([
            ToTensor(),
            Resize(100),
            Normalize(mean=(123.26290927634774, 95.90498110733365, 86.03763122875182),
                      std=(63.20679012922922, 54.86211954409834, 52.31266645797249))
        ])

    def setup(self,stage=None):
        # Initialize dataset
        doppel_data = DoppelDataset(face_dir=self.data_dir,transform=self.transforms)

        # Train/val/test split
        n = len(doppel_data)
        train_size = int(.8 * n)
        val_size = int(.1 * n)
        test_size = n - (train_size + val_size)

        self.train_data,self.val_data,self.test_data = random_split(dataset=doppel_data,lengths=[train_size,val_size,test_size])

    def train_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.test_data,batch_size=self.batch_size,num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.val_data,num_workers=self.num_workers)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.test_data,num_workers=self.num_workers)


class DoppelGenerator(nn.Sequential):
    """
    Generator network that produces images based on latent vector
    """

    def __init__(self,latent_dim: int):
        super().__init__()

        def block(in_channels: int, out_channels: int, padding: int = 1, stride: int = 2, bias=False):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4,
                                   stride=stride, padding=padding, bias=bias),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(True)
            )

        self.model = nn.Sequential(
            block(latent_dim, 512, padding=0, stride=1),
            block(512, 256),
            block(256, 128),
            block(128, 64),
            block(64, 32),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self,input):
        return self.model(input)


class DoppelDiscriminator(nn.Sequential):
    """
    Discriminator network that classifies images in two categories
    """

    def __init__(self):
        super().__init__()

        def block(in_channels: int, out_channels: int):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(num_features=out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.model = nn.Sequential(
            block(3, 64),
            block(64, 128),
            block(128, 256),
            block(256, 512),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Linear(25, 1),
            nn.Sigmoid()
        )

    def forward(self,input):
        return self.model(input)


class DoppelGAN(pl.LightningModule):

    def __init__(self,channels: int,width: int,height: int,lr: float = 0.0002,b1: float = 0.5,b2: float = 0.999,**kwargs):

        super().__init__()

        # Save all keyword arguments as hyperparameters (accessible through self.hparams.X)
        self.save_hyperparameters()

        # Initialize networks
        # data_shape = (channels,width,height)
        self.generator = DoppelGenerator(latent_dim=self.hparams.latent_dim)
        self.discriminator = DoppelDiscriminator()

        self.validation_z = torch.randn(8, self.hparams.latent_dim, 1, 1)

    def forward(self,input):
        return self.generator(input)

    def adversarial_loss(self,y_hat,y):
        return F.binary_cross_entropy(y_hat,y)

    def training_step(self,batch,batch_idx,optimizer_idx):
        images = batch

        # Sample noise (batch_size, latent_dim, 1, 1)
        z = torch.randn(images.size(0), self.hparams.latent_dim, 1, 1)

        # Train generator
        if optimizer_idx == 0:

            # Generate images (call generator -- see forward -- on latent vector)
            self.generated_images = self(z)

            # Log sampled images (visualize what the generator comes up with)
            sample_images = self.generated_images[:6]
            grid = make_grid(sample_images)
            self.logger.experiment.add_image('generated_images',grid,0)

            # Ground truth result (ie: all fake)
            valid = torch.ones(images.size(0),1)

            # Adversarial loss is binary cross-entropy
            generator_loss = self.adversarial_loss(self.discriminator(self(z)),valid)
            tqdm_dict = {'gen_loss': generator_loss}

            output = {
                'loss': generator_loss,'progress_bar': tqdm_dict,'log': tqdm_dict
            }
            return output

        # Train discriminator: classify real from generated samples
        if optimizer_idx == 1:

            # How well can it label as real?
            valid = torch.ones(images.size(0),1)
            real_loss = self.adversarial_loss(self.discriminator(images),valid)

            # How well can it label as fake?
            fake = torch.zeros(images.size(0),1)
            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()),fake)

            # Discriminator loss is the average of these
            discriminator_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': discriminator_loss}
            output = {
                'loss': discriminator_loss,'log': tqdm_dict
            }
            return output

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        # Optimizers
        opt_g = torch.optim.Adam(self.generator.parameters(),lr=lr,betas=(b1,b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        # Return optimizers/schedulers (currently no scheduler)
        return [opt_g,opt_d],[]

    def on_epoch_end(self):

        # Log sampled images
        sample_images = self(self.validation_z)
        grid = make_grid(sample_images)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)


if __name__ == '__main__':

    # Global parameter
    image_dim = 128
    latent_dim = 100
    batch_size = 64

    # Initialize dataset
    tfs = transforms.Compose([
        ToPILImage(),Resize(image_dim),ToTensor()
    ])
    doppel_dataset = DoppelDataset(face_dir='../data/faces',transform=tfs)

    # Initialize data module
    doppel_data_module = DoppelDataModule(batch_size=batch_size)

    # Build models
    generator = DoppelGenerator(latent_dim=latent_dim)
    discriminator = DoppelDiscriminator()

    # Test generator
    x = torch.rand(batch_size, latent_dim, 1, 1)
    y = generator(x)
    print(f'Generator: x {x.size()} --> y {y.size()}')

    # Test discriminator
    x = torch.rand(batch_size, 3, 128, 128)
    y = discriminator(x)
    print(f'Discriminator: x {x.size()} --> y {y.size()}')

    # Build GAN
    doppelgan = DoppelGAN(batch_size=batch_size,channels=3,width=image_dim,height=image_dim,latent_dim=latent_dim)

    # Fit GAN
    trainer = pl.Trainer(gpus=0,max_epochs=5,progress_bar_refresh_rate=1)
    trainer.fit(model=doppelgan,datamodule=doppel_data_module)

Full traceback:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py",line 3437,in run_code
    exec(code_obj,self.user_global_ns,self.user_ns)
  File "<ipython-input-2-28805d67d74b>",line 1,in <module>
    runfile('/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py',wdir='/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger')
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py",line 197,in runfile
    pydev_imports.execfile(filename,global_vars,local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py",line 18,in execfile
    exec(compile(contents+"\n",file,'exec'),glob,loc)
  File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py",line 298,in <module>
    trainer.fit(model=doppelgan,datamodule=doppel_data_module)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py",line 510,in fit
    results = self.accelerator_backend.train()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py",line 57,in train
    return self.train_or_test()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py",line 74,in train_or_test
    results = self.trainer.train()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py",line 561,in train
    self.train_loop.run_training_epoch()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 550,in run_training_epoch
    batch_output = self.run_training_batch(batch,dataloader_idx)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 718,in run_training_batch
    self.optimizer_step(optimizer,opt_idx,train_step_and_backward_closure)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 485,in optimizer_step
    model_ref.optimizer_step(
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py",line 1298,in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py",line 286,in step
    self.__optimizer_step(*args,closure=closure,profiler_name=profiler_name,**kwargs)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py",line 144,in __optimizer_step
    optimizer.step(closure=closure,*args,**kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/autograd/grad_mode.py",line 26,in decorate_context
    return func(*args,**kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/optim/adam.py",line 66,in step
    loss = closure()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 708,in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 806,in training_step_and_backward
    result = self.training_step(split_batch,hiddens)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py",line 319,in training_step
    training_step_output = self.trainer.accelerator_backend.training_step(args)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py",line 62,in training_step
    return self._step(self.trainer.model.training_step,args)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py",line 58,in _step
    output = model_step(*args)
  File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py",line 223,in training_step
    real_loss = self.adversarial_loss(self.discriminator(images),valid)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py",line 727,in _call_impl
    result = self.forward(*input,**kwargs)
  File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py",line 154,in forward
    return self.model(input)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py",line 727,in _call_impl
    result = self.forward(*input,**kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/container.py",line 117,in forward
    input = module(input)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py",line 727,in _call_impl
    result = self.forward(*input,**kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/linear.py",line 93,in forward
    return F.linear(input,self.weight,self.bias)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/functional.py",line 1690,in linear
    ret = torch.addmm(bias,input,weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)

Solution

The multiplication problem comes from the DoppelDiscriminator. It has a linear layer

    nn.Linear(25, 1),

which, based on the error message, should be

    nn.Linear(9, 1),

As for the batch size: you most likely see 7 instead of 64 because train_dataloader returns self.test_data (the small ~10% test split) rather than self.train_data, so the loader runs out of images before it can fill a full batch.
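Where do 9 and 25 come from? The data module resizes images to 100x100, while `nn.Linear(25, 1)` matches 128x128 inputs. Conv output sizes can be verified with plain arithmetic. This is a sketch, assuming a discriminator of four kernel-4/stride-2/padding-1 blocks plus a final kernel-4/stride-1/padding-0 conv (the listing above is partly garbled, but this stack is consistent with the 7x9 in the traceback):

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a conv layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def flat_features(size):
    """Flattened feature count after four stride-2 blocks and a final 4x4 stride-1 conv."""
    for _ in range(4):
        size = conv_out(size)                   # roughly halves the spatial size
    size = conv_out(size, stride=1, padding=0)  # final Conv2d before Flatten
    return size * size

print(flat_features(100))  # 9  -> the DataModule's Resize(100) images
print(flat_features(128))  # 25 -> the input size nn.Linear(25, 1) was written for
```

Rather than hardcoding the flattened size, you can also pass a dummy tensor through the conv stack once and read off its shape, or (on recent PyTorch) use `nn.LazyLinear(1)`, which infers `in_features` on the first forward pass.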

