Variational autoencoder loss decreases but the input is not reconstructed. Out of debugging ideas. Works on MNIST but not on other data

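My variational autoencoder seems to work on MNIST, but it fails on slightly "harder" data.
By "fails" I mean there are at least two obvious problems:

  1. Very poor reconstructions, for example these sample reconstructions from the last epoch on the validation set:

    [3 images: sample reconstructions]

    This is with no regularization at all.
    The last losses reported in the console are val_loss=9.57e-5 and train_loss=9.83e-5, which I assumed would mean an exact reconstruction.
  2. The validation loss is very low (which does not seem to reflect the reconstructions) and is always lower than the training loss, which is suspicious.

    [2 images: loss curves]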
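
One way to check whether a loss near 1e-4 really implies a near-exact reconstruction is to compare it with the MSE of a trivial predictor that always outputs the mean of the data. The sketch below uses plain Python and a synthetic, hypothetical signal (small fluctuations around a constant); the helper names are made up for illustration and the numbers are not taken from the real data.

```python
import random
import statistics


def mse(pred, target):
    """Mean squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def constant_baseline_mse(values):
    """MSE of always predicting the mean of `values` (equals the population variance)."""
    mean = statistics.fmean(values)
    return mse([mean] * len(values), values)


def unexplained_variance(model_mse, values):
    """Model MSE relative to the constant baseline; close to 1.0 means no better than the mean."""
    return model_mse / constant_baseline_mse(values)


# Hypothetical signal: tiny fluctuations (std 0.01) around -0.9
random.seed(0)
signal = [-0.9 + random.gauss(0.0, 0.01) for _ in range(10_000)]

baseline = constant_baseline_mse(signal)
print(f"constant-baseline MSE: {baseline:.2e}")
print(f"ratio for a model MSE of 9.57e-5: {unexplained_variance(9.57e-5, signal):.2f}")
```

If the ratio on the real data turns out to be close to 1, the low loss would be consistent with the network outputting roughly the mean rather than reconstructing each sample.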

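For MNIST everything looks fine (with fewer layers!).

[image: MNIST reconstructions]

I will provide as much information as I can, since I am not sure what is needed for anyone to help me.


First, here is the full code.
You will notice that the loss computation and the logging are very simple and straightforward, yet I cannot find what is wrong.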

import torch
from torch import nn
import torch.nn.functional as F
from typing import List,Optional,Any
from pytorch_lightning.core.lightning import LightningModule
from Testing.Research.config.ConfigProvider import ConfigProvider
from pytorch_lightning import Trainer,seed_everything
from torch import optim
import os
from pytorch_lightning.loggers import TensorBoardLogger
# import tfmpl
import matplotlib.pyplot as plt
import matplotlib
from Testing.Research.data_modules.MyDataModule import MyDataModule
from Testing.Research.data_modules.MNISTDataModule import MNISTDataModule
from Testing.Research.data_modules.CaseDataModule import CaseDataModule
import torchvision
from Testing.Research.config.paths import tb_logs_folder
from Testing.Research.config.paths import vae_checkpoints_path
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint


class VAEFC(LightningModule):
    # see https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
    # for possible upgrades,see https://arxiv.org/pdf/1602.02282.pdf
    # https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational
    # -auto-encoder
    def __init__(self,encoder_layer_sizes: List,decoder_layer_sizes: List,config):
        super(VAEFC,self).__init__()
        self._config = config
        self.logger: Optional[TensorBoardLogger] = None
        self.save_hyperparameters()

        assert len(encoder_layer_sizes) >= 3,"must have at least 3 layers (2 hidden)"
        # encoder layers
        self._encoder_layers = nn.ModuleList()
        for i in range(1,len(encoder_layer_sizes) - 1):
            enc_layer = nn.Linear(encoder_layer_sizes[i - 1],encoder_layer_sizes[i])
            self._encoder_layers.append(enc_layer)

        # predict mean and covariance vectors
        self._mean_layer = nn.Linear(encoder_layer_sizes[
                                         len(encoder_layer_sizes) - 2],encoder_layer_sizes[len(encoder_layer_sizes) - 1])
        self._logvar_layer = nn.Linear(encoder_layer_sizes[
                                           len(encoder_layer_sizes) - 2],encoder_layer_sizes[len(encoder_layer_sizes) - 1])

        # decoder layers
        self._decoder_layers = nn.ModuleList()
        for i in range(1,len(decoder_layer_sizes)):
            dec_layer = nn.Linear(decoder_layer_sizes[i - 1],decoder_layer_sizes[i])
            self._decoder_layers.append(dec_layer)

        self._recon_function = nn.MSELoss(reduction='mean')
        self._last_val_batch = {}

    def _encode(self,x):
        for i in range(len(self._encoder_layers)):
            layer = self._encoder_layers[i]
            x = F.relu(layer(x))

        mean_output = self._mean_layer(x)
        logvar_output = self._logvar_layer(x)
        return mean_output,logvar_output

    def _reparametrize(self,mu,logvar):
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like draws the noise on the same device and dtype as std
        eps = torch.randn_like(std)
        return eps * std + mu

    def _decode(self,z):
        for i in range(len(self._decoder_layers) - 1):
            layer = self._decoder_layers[i]
            z = F.relu((layer(z)))

        decoded = self._decoder_layers[len(self._decoder_layers) - 1](z)
        # decoded = F.sigmoid(self._decoder_layers[len(self._decoder_layers)-1](z))
        return decoded

    def _loss_function(self, recon_x, x, mu, logvar, reconstruction_function):
        """
        recon_x: generating images
        x: origin images
        mu: latent mean
        logvar: latent log variance
        """
        binary_cross_entropy = reconstruction_function(recon_x,x)  # mse loss TODO see if mse or cross entropy
        # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        kld = torch.sum(kld_element).mul_(-0.5)
        # KL divergence Kullback–Leibler divergence,regularization term for VAE
        # It measures how much two probability distributions differ from each other.
        # We are trying to force the distributions closer while keeping the reconstruction loss low.
        # see https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

        # read on weighting the regularization term here:
        # https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational
        # -auto-encoder
        return binary_cross_entropy + kld * self._config.regularization_factor

    def _parse_batch_by_dataset(self,batch,batch_index):
        if self._config.dataset == "toy":
            (orig_batch,noisy_batch),label_batch = batch
            # TODO put in the noise here and not in the dataset?
        elif self._config.dataset == "mnist":
            orig_batch,label_batch = batch
            orig_batch = orig_batch.reshape(-1,28 * 28)
            noisy_batch = orig_batch
        elif self._config.dataset == "case":
            orig_batch,label_batch = batch

            orig_batch = orig_batch.float().reshape(
                    -1,len(self._config.case.feature_list) * self._config.case.frames_per_pd_sample
            )
            noisy_batch = orig_batch
        else:
            raise ValueError("invalid dataset")
        noisy_batch = noisy_batch.view(noisy_batch.size(0),-1)

        return orig_batch,noisy_batch,label_batch

    def training_step(self, batch, batch_idx):
        orig_batch, noisy_batch, label_batch = self._parse_batch_by_dataset(batch, batch_idx)

        recon_batch, mu, logvar = self.forward(noisy_batch)

        loss = self._loss_function(
                recon_batch, orig_batch, mu, logvar, reconstruction_function=self._recon_function
        )
        # self.logger.experiment.add_scalars("losses",{"train_loss": loss})
        tb = self.logger.experiment
        tb.add_scalars("losses",{"train_loss": loss},global_step=self.current_epoch)
        # self.logger.experiment.add_scalar("train_loss",loss,self.current_epoch)
        if batch_idx == len(self.train_dataloader()) - 2:
            # https://pytorch.org/docs/stable/_modules/torch/utils/tensorboard/writer.html#SummaryWriter.add_embedding
            # noisy_batch = noisy_batch.detach()
            # recon_batch = recon_batch.detach()
            # last_batch_plt = matplotlib.figure.Figure()  # read https://github.com/wookayin/tensorflow-plot
            # ax = last_batch_plt.add_subplot(1,1,1)
            # ax.scatter(orig_batch[:,0],orig_batch[:,1],label="original")
            # ax.scatter(noisy_batch[:, 0], noisy_batch[:, 1], label="noisy")
            # ax.scatter(recon_batch[:, 0], recon_batch[:, 1], label="reconstructed")
            # ax.legend(loc="upper left")
            # self.logger.experiment.add_figure(f"original last batch,epoch {self.current_epoch}",last_batch_plt)
            # tb.add_embedding(orig_batch,global_step=self.current_epoch,metadata=label_batch)
            pass
        self.logger.experiment.flush()
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def _plot_batches(self, orig_batch, noisy_batch, label_batch, batch_idx, recon_batch, mu, logvar):
        # orig_batch_view = orig_batch.reshape(-1, self._config.case.frames_per_pd_sample,
        #                                      len(self._config.case.feature_list))
        #
        # plt.figure()
        # plt.plot(orig_batch_view[11,:,0].detach().cpu().numpy(),label="feature 0")
        # plt.legend(loc="upper left")
        # plt.show()

        tb = self.logger.experiment
        if self._config.dataset == "mnist":
            orig_batch -= orig_batch.min()
            orig_batch /= orig_batch.max()
            recon_batch -= recon_batch.min()
            recon_batch /= recon_batch.max()

            orig_grid = torchvision.utils.make_grid(orig_batch.view(-1, 1, 28, 28))
            val_recon_grid = torchvision.utils.make_grid(recon_batch.view(-1, 1, 28, 28))

            tb.add_image("original_val",orig_grid,global_step=self.current_epoch)
            tb.add_image("reconstruction_val",val_recon_grid,global_step=self.current_epoch)

            label_img = orig_batch.view(-1, 1, 28, 28)
            pass
        elif self._config.dataset == "case":
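            orig_batch_view = orig_batch.reshape(
                    -1, self._config.case.frames_per_pd_sample, len(self._config.case.feature_list)
            ).transpose(1, 2)
            recon_batch_view = recon_batch.reshape(
                    -1, self._config.case.frames_per_pd_sample, len(self._config.case.feature_list)
            ).transpose(1, 2)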

            # plt.figure()
            # plt.plot(orig_batch_view[11,:].detach().cpu().numpy())
            # plt.show()
            # pass

            n_samples = orig_batch_view.shape[0]
            n_plots = min(n_samples,4)
            first_sample_idx = 0

            # TODO either plotting or data problem
            fig,axs = plt.subplots(n_plots,1)
            for sample_idx in range(n_plots):
                for feature_idx,(orig_feature,recon_feature) in enumerate(
                        zip(orig_batch_view[sample_idx + first_sample_idx,:],recon_batch_view[sample_idx + first_sample_idx,:])):
                    i = feature_idx
                    if i > 0: continue  # or scale issues don't allow informative plotting

                    # plt.figure()
                    # plt.plot(orig_feature.detach().cpu().numpy(),label=f'orig{i},sample{sample_idx}')
                    # plt.legend(loc='upper left')
                    # pass

                    axs[sample_idx].plot(orig_feature.detach().cpu().numpy(), label=f'orig{i},sample{sample_idx}')
                    axs[sample_idx].plot(recon_feature.detach().cpu().numpy(),label=f'recon{i},sample{sample_idx}')
                    # sample{sample_idx}')
                    axs[sample_idx].legend(loc='upper left')
                pass
            # plt.show()

            tb.add_figure("recon_vs_orig",fig,close=True)

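    def validation_step(self, batch, batch_idx):
        orig_batch, noisy_batch, label_batch = self._parse_batch_by_dataset(batch, batch_idx)

        recon_batch, mu, logvar = self.forward(noisy_batch)

        loss = self._loss_function(
                recon_batch, orig_batch, mu, logvar, reconstruction_function=self._recon_function
        )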

        tb = self.logger.experiment
        # can probably speed up training by waiting for epoch end for data copy from gpu
        # see https://sagivtech.com/2017/09/19/optimizing-pytorch-training-code/
        tb.add_scalars("losses",{"val_loss": loss},global_step=self.current_epoch)

        label_img = None
        if len(orig_batch) > 2:
            self._last_val_batch = {
                "orig_batch": orig_batch,"noisy_batch": noisy_batch,"label_batch": label_batch,"batch_idx": batch_idx,"recon_batch": recon_batch,"mu": mu,"logvar": logvar
            }
        # self._plot_batches(orig_batch, noisy_batch, label_batch, batch_idx, recon_batch, mu, logvar)

        outputs = {"val_loss":  loss,"label_img": label_img}
        self.log("val_loss", loss, on_epoch=True)
        return outputs

    def validation_epoch_end(self,outputs: List[Any]) -> None:
        first_batch_dict = outputs[-1]

        self._plot_batches(
                self._last_val_batch["orig_batch"],self._last_val_batch["noisy_batch"],self._last_val_batch["label_batch"],self._last_val_batch["batch_idx"],self._last_val_batch["recon_batch"],self._last_val_batch["mu"],self._last_val_batch["logvar"]
        )
        self.log(name="VAEFC_val_loss_epoch_end",value={"val_loss": first_batch_dict["val_loss"]})

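    def test_step(self, batch, batch_idx):
        orig_batch, noisy_batch, label_batch = self._parse_batch_by_dataset(batch, batch_idx)

        recon_batch, mu, logvar = self.forward(noisy_batch)

        loss = self._loss_function(
                recon_batch, orig_batch, mu, logvar, reconstruction_function=self._recon_function
        )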

        tb = self.logger.experiment
        tb.add_scalars("losses",{"test_loss": loss},global_step=self.global_step)

        return {"test_loss": loss,"mus": mu,"labels": label_batch,"images": orig_batch}

    def test_epoch_end(self,outputs: List):
        tb = self.logger.experiment

        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        self.log(name="test_epoch_end",value={"test_loss_avg": avg_loss})

        if self._config.dataset == "mnist":
            tb.add_embedding(
                    mat=torch.cat([o["mus"] for o in outputs]),
                    metadata=torch.cat([o["labels"] for o in outputs]).detach().cpu().numpy(),
                    label_img=torch.cat([o["images"] for o in outputs]).view(-1, 1, 28, 28),
                    global_step=self.global_step,
            )

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(),lr=self._config.learning_rate)
        return optimizer

    def forward(self,x):
        mu,logvar = self._encode(x)
        z = self._reparametrize(mu,logvar)
        decoded = self._decode(z)
        return decoded, mu, logvar


def train_vae(config,datamodule,latent_dim,dec_layer_sizes,enc_layer_sizes):
    model = VAEFC(config=config,encoder_layer_sizes=enc_layer_sizes,decoder_layer_sizes=dec_layer_sizes)

    logger = TensorBoardLogger(save_dir=tb_logs_folder,name='VAEFC',default_hp_metric=False)
    logger.hparams = config

    checkpoint_callback = ModelCheckpoint(dirpath=vae_checkpoints_path)
    trainer = Trainer(
            deterministic=config.is_deterministic,
            # auto_lr_find=config.auto_lr_find,
            # log_gpu_memory='all',
            # min_epochs=99999,
            max_epochs=config.num_epochs,
            default_root_dir=vae_checkpoints_path,
            logger=logger,
            callbacks=[checkpoint_callback],
            gpus=1
    )
    # trainer.tune(model)
    trainer.fit(model,datamodule=datamodule)
    best_model_path = checkpoint_callback.best_model_path
    print("done training vae with lightning")
    print(f"best model path = {best_model_path}")
    return trainer


def run_trained_vae(trainer):
    # https://pytorch-lightning.readthedocs.io/en/latest/test_set.html
    # (1) load the best checkpoint automatically (lightning tracks this for you)
    trainer.test()

    # (2) don't load a checkpoint,instead use the model with the latest weights
    # trainer.test(ckpt_path=None)

    # (3) test using a specific checkpoint
    # trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')

    # (4) test with an explicit model (will use this model and not load a checkpoint)
    # trainer.test(model)


Parameters

For every combination of parameters I have tried (manually), I get very similar results. Maybe there is something I have not tried.

num_epochs: 40
batch_size: 32
learning_rate: 0.0001
auto_lr_find: False

noise_factor: 0.1
regularization_factor: 0.0

train_size: 0.8
val_size: 0.1
num_workers: 1

dataset: "case" # toy, mnist, case
mnist:
  enc_layer_sizes: [784,512,]
  dec_layer_sizes: [512,784]
  latent_dim: 25
  n_classes: 10
  classifier_layers: [20,10]
toy:
  enc_layer_sizes: [2,200,200]
  dec_layer_sizes: [200,2]
  latent_dim: 8
  centers_radius: 4.0
  n_clusters: 10
  cluster_size: 5000
case:
  #enc_layer_sizes: [ 1800,600,300,100 ]
  #dec_layer_sizes: [ 100,1800 ]
  #frames_per_pd_sample: 600

  enc_layer_sizes: [ 10,300 ]
  dec_layer_sizes: [ 600,10 ]
  frames_per_pd_sample: 10

  latent_dim: 300
  n_classes: 10
  classifier_layers: [ 20,10 ] # unused right now.

  feature_list:
    #- V_0_0 # 0,X
    #- V_0_1 # 0,Y
    #- V_0_2 # 0,Z
    - pads_0
  enc_kernel_sizes: [] # for conv
  end_strides: []
  dec_kernel_sizes: []
  dec_strides: []

is_deterministic: False

real_data_pd_dir: "D:/pressure_pd"
case_dir: "real_case_20_min"
case_file: "pressure_data_0.pkl"
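
With `regularization_factor: 0.0` the KL term does not contribute to the loss at all. If the factor is raised later, note that the model code averages the reconstruction error over batch and features (`reduction='mean'`) while the KL term is summed over batch and latent dimensions, so the two live on very different scales. The sketch below illustrates this in plain Python with made-up per-element values; the shapes follow the `case` config (batch of 32, 10 input features, 300 latent dimensions), and the function name is hypothetical.

```python
def loss_terms(sq_errors, kl_elements):
    """Compare reductions for one batch.

    sq_errors:   list (batch) of lists (features) of squared errors
    kl_elements: list (batch) of lists (latent dims) of per-dimension KL values
    Returns (mean-reduced MSE, sum-reduced KL, per-sample recon, per-sample KL).
    """
    batch = len(sq_errors)
    n_features = len(sq_errors[0])
    mse_mean = sum(map(sum, sq_errors)) / (batch * n_features)
    kl_sum = sum(map(sum, kl_elements))
    recon_per_sample = mse_mean * n_features  # sum over features, mean over batch
    kl_per_sample = kl_sum / batch            # sum over latent dims, mean over batch
    return mse_mean, kl_sum, recon_per_sample, kl_per_sample


# Made-up values: squared error 1e-4 per feature, KL 0.01 per latent dimension
sq = [[1e-4] * 10 for _ in range(32)]
kl = [[0.01] * 300 for _ in range(32)]

mse_mean, kl_sum, recon_ps, kl_ps = loss_terms(sq, kl)
print(f"mean MSE={mse_mean:.1e}  summed KL={kl_sum:.1f}")
print(f"per-sample recon={recon_ps:.1e}  per-sample KL={kl_ps:.1f}")
```

With these illustrative numbers the summed KL is roughly a million times larger than the mean-reduced MSE, so any nonzero `regularization_factor` would have to be chosen with that gap in mind, or both terms reduced per sample.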


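Data

For MNIST everything works fine.

When I switch to my own data, the results are as shown above.

The data is a time series with multiple features. To keep things simple, I feed in only one feature, cut into chunks of equal length, and each chunk is passed to the input layer as a vector.
The fact that the data is a time series may help with modeling later on, but for now I just want to treat it as chunks of data, which I believe is what I am doing.

Code: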

from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torch
from Testing.Research.config.ConfigProvider import ConfigProvider
import os
import pickle
import pandas as pd
from typing import Tuple
import numpy as np


class CaseDataset(Dataset):
    def __init__(self,path):
        super(CaseDataset,self).__init__()
        self._path = path

        self._config = ConfigProvider.get_config()
        self.frames_per_pd_sample = self._config.case.frames_per_pd_sample
        self._load_case_from_pkl()
        self.__len = len(self._full) // self.frames_per_pd_sample  # discard last non full batch

    def _load_case_from_pkl(self):
        assert os.path.isfile(self._path)
        with open(self._path,"rb") as f:
            p = pickle.load(f)

        self._full: pd.DataFrame = p["full"]
        self._subsampled: pd.DataFrame = p["subsampled"]
        self._misc: pd.DataFrame = p["misc"]

        feature_list = self._config.case.feature_list
        self._features_df = self._full[feature_list].copy()

        # normalize from -1 to 1
        features_to_normalize = self._features_df.columns
        self._features_df[features_to_normalize] = \
            self._features_df[features_to_normalize].apply(lambda x: (((x - x.min()) / (x.max() - x.min())) * 2) - 1)

        pass

    def __len__(self):
        # number of samples in the dataset
        return self.__len

    def __getitem__(self,index: int) -> Tuple[np.array,np.array]:
        data_item = self._features_df.iloc[index * self.frames_per_pd_sample: (index + 1) * self.frames_per_pd_sample,:].values
        label = 0.0
        # plt.figure()
        # plt.plot(data_item[:, 0], label="feature 0")
        # plt.legend(loc="upper left")
        # plt.show()
        return data_item,label

The number of time steps per batch does not seem to affect convergence.
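
The dataset scales each feature to [-1, 1] using its global minimum and maximum. A property of this scaling worth checking on the real data is its sensitivity to extreme values: a single spike can squeeze the rest of the series into a very narrow band, which in turn makes any MSE on it tiny. The sketch below shows the effect in plain Python on a synthetic, hypothetical series; whether the real pressure data contains such spikes is an assumption to verify, not a known fact.

```python
import statistics


def minmax_to_unit_range(values):
    """Scale linearly so that min -> -1 and max -> +1 (same formula as the dataset's lambda)."""
    lo, hi = min(values), max(values)
    return [((v - lo) / (hi - lo)) * 2 - 1 for v in values]


# Hypothetical series with values in [100, 109], with and without one extreme spike
clean = [100.0 + (i % 10) for i in range(1000)]
with_spike = clean + [10_000.0]

for name, series in (("clean", clean), ("with spike", with_spike)):
    scaled_bulk = minmax_to_unit_range(series)[:1000]  # leave out the spike itself
    print(f"{name}: variance of scaled bulk = {statistics.pvariance(scaled_bulk):.2e}")
```

Plotting a histogram of the normalized feature, or comparing its variance with the reported loss, would show whether the real data is affected in this way.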

Train / validation / test split

It is done like this:

import os
from pytorch_lightning import LightningDataModule
import torchvision.datasets as datasets
from torchvision.transforms import transforms
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from Testing.Research.config.paths import mnist_data_download_folder
from Testing.Research.datasets.real_cases.CaseDataset import CaseDataset
from typing import Optional


class CaseDataModule(LightningDataModule):
    def __init__(self,config,path):
        super().__init__()
        self._config = config
        self._path = path

        self._train_dataset: Optional[Subset] = None
        self._val_dataset: Optional[Subset] = None
        self._test_dataset: Optional[Subset] = None

    def prepare_data(self):
        pass

    def setup(self,stage):
        # transform
        transform = transforms.Compose([transforms.ToTensor()])
        full_dataset = CaseDataset(self._path)

        train_size = int(self._config.train_size * len(full_dataset))
        val_size = int(self._config.val_size * len(full_dataset))
        test_size = len(full_dataset) - train_size - val_size
        train,val,test = torch.utils.data.random_split(full_dataset,[train_size,val_size,test_size])

        # assign to use in dataloaders
        self._full_dataset = full_dataset
        self._train_dataset = train
        self._val_dataset = val
        self._test_dataset = test

    def train_dataloader(self):
        return DataLoader(self._train_dataset,batch_size=self._config.batch_size,num_workers=self._config.num_workers)

    def val_dataloader(self):
        return DataLoader(self._val_dataset,num_workers=self._config.num_workers)

    def test_dataloader(self):
        return DataLoader(self._test_dataset,num_workers=self._config.num_workers)
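
Because `random_split` assigns chunks at random, validation chunks can sit right next to training chunks in time. For a smooth time series that may make validation data look very much like training data. One alternative to compare against is a chronological split that keeps the chunks in order. The sketch below uses the same size arithmetic as `setup`; the function name `chronological_split` is hypothetical.

```python
def chronological_split(n_samples, train_frac, val_frac):
    """Return index ranges (train, val, test) that preserve time order."""
    train_size = int(train_frac * n_samples)
    val_size = int(val_frac * n_samples)
    train = range(0, train_size)
    val = range(train_size, train_size + val_size)
    test = range(train_size + val_size, n_samples)
    return train, val, test


train_idx, val_idx, test_idx = chronological_split(1005, train_frac=0.8, val_frac=0.1)
print(len(train_idx), len(val_idx), len(test_idx))  # 804 100 101
```

The resulting ranges could be passed to `torch.utils.data.Subset`, which the data module already imports, in place of the subsets returned by `random_split`.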

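Questions

  1. I believe the validation loss being consistently lower than the training loss indicates that something is badly wrong here, but I cannot pin down what it is or work out how to verify it.
  2. How can I get the model to autoencode the data properly? Basically, I want it to learn the identity function and have the loss reflect that.
  3. The loss does not seem to reflect the reconstruction. I think this may be the most fundamental problem.

My ideas

  1. Try a convolutional network instead of FC? Maybe it would learn the features better?
  2. Out of ideas :(

I will provide any missing information.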
