如何解决变分自编码器(VAE)损失持续下降却无法重建输入的问题?调试已经没有头绪了——在 MNIST 上有效,换成其他数据就不行
我的变分自动编码器似乎适用于 MNIST,但在稍微“更难”的数据上失败。
“失败”是指至少有两个明显的问题:
- 重建效果非常差,例如验证集上最后一个 epoch 的样本重建结果,而且这是在完全没有使用任何正则化的情况下。控制台最后报告的损失是 val_loss=9.57e-5、train_loss=9.83e-5,我原以为这意味着近乎完美的重建。
- 验证损失很低(似乎与实际的重建效果不符),而且始终低于训练损失,这很可疑。
对于 MNIST,一切看起来都很好(层数较少!)。
我会尽量多提供一些信息,因为我不确定提供哪些内容才能帮助大家帮我找出问题。
首先,这是完整代码
您会注意到损失计算和日志记录非常简单直接,我似乎找不到问题所在。
import torch
from torch import nn
import torch.nn.functional as F
from typing import List,Optional,Any
from pytorch_lightning.core.lightning import LightningModule
from Testing.Research.config.ConfigProvider import ConfigProvider
from pytorch_lightning import Trainer,seed_everything
from torch import optim
import os
from pytorch_lightning.loggers import TensorBoardLogger
# import tfmpl
import matplotlib.pyplot as plt
import matplotlib
from Testing.Research.data_modules.MyDataModule import MyDataModule
from Testing.Research.data_modules.MNISTDataModule import MNISTDataModule
from Testing.Research.data_modules.CaseDataModule import CaseDataModule
import torchvision
from Testing.Research.config.paths import tb_logs_folder
from Testing.Research.config.paths import vae_checkpoints_path
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
class VAEFC(LightningModule):
    """Fully connected variational auto-encoder (VAE) implemented as a LightningModule.

    Architecture:
        * Encoder: ``Linear -> ReLU`` hidden layers built from
          ``encoder_layer_sizes[:-1]``, followed by two parallel linear heads
          (latent mean and latent log-variance) of width ``encoder_layer_sizes[-1]``.
        * Decoder: ``Linear -> ReLU`` layers built from ``decoder_layer_sizes``;
          the last layer is left linear (no output activation).

    Reading the logged losses:
        * ``_reparametrize`` returns the latent mean unchanged when the module is
          not in training mode. Validation/test reconstructions are therefore
          noise-free while training reconstructions use sampled latents, so a
          validation loss below the training loss is expected, not a bug.
        * The reconstruction term is ``nn.MSELoss(reduction='mean')``, i.e. the
          squared error averaged over every element of the batch. A tiny value
          means the *average* per-element error is tiny; if most of a signal is
          nearly constant, short large-amplitude events can be missed while the
          mean stays small. Inspect the plotted reconstructions, not only the number.

    References:
        https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
        https://arxiv.org/pdf/1602.02282.pdf (possible upgrades)
        https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational-auto-encoder
    """

    def __init__(self, encoder_layer_sizes: List, decoder_layer_sizes: List, config):
        """Build encoder, latent heads and decoder.

        Args:
            encoder_layer_sizes: ``[input_dim, hidden_1, ..., hidden_k, latent_dim]``;
                needs at least 3 entries.
            decoder_layer_sizes: ``[latent_dim, hidden_1, ..., output_dim]``; its first
                entry must equal ``encoder_layer_sizes[-1]``.
            config: experiment configuration. Attributes read by this class:
                ``regularization_factor``, ``dataset``, ``learning_rate``,
                ``case.feature_list`` and ``case.frames_per_pd_sample``.

        Raises:
            ValueError: if the encoder has fewer than 3 sizes, or the decoder does
                not start at the latent size (which would otherwise only fail later,
                inside ``forward``, with a shape error).
        """
        super(VAEFC, self).__init__()
        self._config = config
        self.logger: Optional[TensorBoardLogger] = None
        self.save_hyperparameters()

        enc_sizes = list(encoder_layer_sizes)
        dec_sizes = list(decoder_layer_sizes)
        if len(enc_sizes) < 3:
            raise ValueError("must have at least 3 layers (2 hidden)")
        if len(dec_sizes) < 2 or dec_sizes[0] != enc_sizes[-1]:
            raise ValueError(
                f"decoder_layer_sizes must have at least 2 entries and start with the "
                f"latent size {enc_sizes[-1]}, got {dec_sizes}"
            )

        # Hidden encoder layers: enc_sizes[0] -> enc_sizes[1] -> ... -> enc_sizes[-2].
        self._encoder_layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(enc_sizes[:-2], enc_sizes[1:-1])
        )
        # Two heads on the last hidden layer predict the posterior mean and log-variance.
        self._mean_layer = nn.Linear(enc_sizes[-2], enc_sizes[-1])
        self._logvar_layer = nn.Linear(enc_sizes[-2], enc_sizes[-1])
        # Decoder: dec_sizes[0] (latent) -> ... -> dec_sizes[-1] (output).
        self._decoder_layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(dec_sizes[:-1], dec_sizes[1:])
        )
        self._recon_function = nn.MSELoss(reduction='mean')
        # Validation batch kept for plotting in validation_epoch_end; cleared after plotting.
        self._last_val_batch = {}

    def _encode(self, x):
        """Return ``(mu, logvar)`` for the flattened input ``x``."""
        for layer in self._encoder_layers:
            x = F.relu(layer(x))
        return self._mean_layer(x), self._logvar_layer(x)

    def _reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like matches std's dtype and device (works on any device, not only CUDA).
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z):
        """Map latent codes back to input space; the output layer has no activation."""
        for layer in self._decoder_layers[:-1]:
            z = F.relu(layer(z))
        # Linear output keeps the range unbounded; a sigmoid would restrict it to (0, 1).
        return self._decoder_layers[-1](z)

    def _loss_function(self, recon_x, x, mu, logvar, reconstruction_function):
        """Return ``reconstruction + regularization_factor * KL``.

        Args:
            recon_x: decoder output.
            x: reconstruction target.
            mu: latent mean predicted by the encoder.
            logvar: latent log-variance predicted by the encoder.
            reconstruction_function: callable ``(recon_x, x) -> scalar tensor``.
        """
        recon_loss = reconstruction_function(recon_x, x)
        # Closed-form KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2),
        # summed over the batch and latent dimensions.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # NOTE(review): recon_loss is a per-element mean while kld is a batch sum, so
        # regularization_factor also absorbs that scale difference.
        return recon_loss + kld * self._config.regularization_factor

    def _parse_batch_by_dataset(self, batch, batch_index):
        """Unpack a dataloader batch into ``(orig_batch, noisy_batch, label_batch)``.

        ``orig_batch`` is the reconstruction target and ``noisy_batch`` the encoder
        input. Only the "toy" dataset supplies a separate noisy copy; "mnist" and
        "case" use the (flattened) original batch as input. ``noisy_batch`` is
        always flattened to ``(batch, -1)``.

        Raises:
            ValueError: for an unknown ``config.dataset``.
        """
        dataset = self._config.dataset
        if dataset == "toy":
            (orig_batch, noisy_batch), label_batch = batch
        elif dataset == "mnist":
            orig_batch, label_batch = batch
            orig_batch = orig_batch.reshape(-1, 28 * 28)
            noisy_batch = orig_batch
        elif dataset == "case":
            orig_batch, label_batch = batch
            n_inputs = len(self._config.case.feature_list) * self._config.case.frames_per_pd_sample
            orig_batch = orig_batch.float().reshape(-1, n_inputs)
            noisy_batch = orig_batch
        else:
            raise ValueError("invalid dataset")
        noisy_batch = noisy_batch.view(noisy_batch.size(0), -1)
        return orig_batch, noisy_batch, label_batch

    def _shared_step(self, batch, batch_idx):
        """Parse ``batch``, run the model and compute the loss; shared by train/val/test steps."""
        orig_batch, noisy_batch, label_batch = self._parse_batch_by_dataset(batch, batch_idx)
        recon_batch, mu, logvar = self.forward(noisy_batch)
        loss = self._loss_function(
            recon_batch, orig_batch, mu, logvar, reconstruction_function=self._recon_function
        )
        return {
            "orig_batch": orig_batch,
            "noisy_batch": noisy_batch,
            "label_batch": label_batch,
            "recon_batch": recon_batch,
            "mu": mu,
            "logvar": logvar,
            "loss": loss,
        }

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss of one batch."""
        loss = self._shared_step(batch, batch_idx)["loss"]
        tb = self.logger.experiment
        # The x-axis is the epoch, so every batch of an epoch is written at the same step.
        tb.add_scalars("losses", {"train_loss": loss}, global_step=self.current_epoch)
        tb.flush()
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    @staticmethod
    def _rescale_to_unit(t):
        """Return a detached copy of ``t`` shifted/scaled into [0, 1] (all zeros if ``t`` is constant)."""
        shifted = t.detach() - t.min()
        peak = shifted.max()
        return shifted / peak if peak > 0 else shifted

    def _plot_batches(self, orig_batch, noisy_batch, label_batch, batch_idx, recon_batch, mu, logvar):
        """Log original vs. reconstructed samples to TensorBoard.

        Only ``orig_batch`` and ``recon_batch`` are used; the remaining arguments
        are accepted so the call signature stays stable.
        """
        tb = self.logger.experiment
        if self._config.dataset == "mnist":
            # Out-of-place rescaling: the inputs are cached tensors and must not be mutated.
            orig_img = self._rescale_to_unit(orig_batch)
            recon_img = self._rescale_to_unit(recon_batch)
            # make_grid expects (B, C, H, W); a 3-D tensor would be read as one B-channel image.
            orig_grid = torchvision.utils.make_grid(orig_img.view(-1, 1, 28, 28))
            val_recon_grid = torchvision.utils.make_grid(recon_img.view(-1, 1, 28, 28))
            tb.add_image("original_val", orig_grid, global_step=self.current_epoch)
            tb.add_image("reconstruction_val", val_recon_grid, global_step=self.current_epoch)
        elif self._config.dataset == "case":
            n_frames = self._config.case.frames_per_pd_sample
            n_features = len(self._config.case.feature_list)
            # Flattened rows are frame-major (window rows are frames, columns features):
            # (B, frames * features) -> (B, frames, features) -> (B, features, frames).
            orig_view = orig_batch.reshape(-1, n_frames, n_features).transpose(1, 2)
            recon_view = recon_batch.reshape(-1, n_frames, n_features).transpose(1, 2)
            n_plots = min(orig_view.shape[0], 4)
            # squeeze=False keeps axs 2-D, so indexing also works when n_plots == 1.
            fig, axs = plt.subplots(n_plots, 1, squeeze=False)
            feature_idx = 0  # only feature 0: mixed feature scales make shared axes uninformative
            for sample_idx in range(n_plots):
                ax = axs[sample_idx, 0]
                orig_feature = orig_view[sample_idx, feature_idx]
                recon_feature = recon_view[sample_idx, feature_idx]
                ax.plot(orig_feature.detach().cpu().numpy(), label=f'orig{feature_idx},sample{sample_idx}')
                ax.plot(recon_feature.detach().cpu().numpy(), label=f'recon{feature_idx},sample{sample_idx}')
                ax.legend(loc='upper left')
            tb.add_figure("recon_vs_orig", fig, close=True)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; cache a batch for epoch-end plotting."""
        step = self._shared_step(batch, batch_idx)
        loss = step["loss"]
        tb = self.logger.experiment
        tb.add_scalars("losses", {"val_loss": loss}, global_step=self.current_epoch)
        # Prefer batches with more than 2 samples, but always keep at least one so
        # plotting still works when the validation loader uses batch_size=1.
        if len(step["orig_batch"]) > 2 or not self._last_val_batch:
            self._last_val_batch = {key: value for key, value in step.items() if key != "loss"}
            self._last_val_batch["batch_idx"] = batch_idx
        self.log("val_loss", loss, on_epoch=True)
        return {"val_loss": loss, "label_img": None}

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        """Plot the cached validation batch and log the loss of the epoch's last validation batch."""
        if self._last_val_batch:
            last = self._last_val_batch
            self._plot_batches(
                last["orig_batch"], last["noisy_batch"], last["label_batch"], last["batch_idx"],
                last["recon_batch"], last["mu"], last["logvar"],
            )
            # Start each epoch empty so a stale batch is never re-plotted.
            self._last_val_batch = {}
        if outputs:
            # outputs[-1] is the *last* validation batch, not an epoch average.
            last_batch_loss = outputs[-1]["val_loss"]
            self.log(name="VAEFC_val_loss_epoch_end", value={"val_loss": last_batch_loss})

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss; return what test_epoch_end needs for the embedding."""
        step = self._shared_step(batch, batch_idx)
        loss = step["loss"]
        tb = self.logger.experiment
        tb.add_scalars("losses", {"test_loss": loss}, global_step=self.global_step)
        return {"test_loss": loss, "mus": step["mu"], "labels": step["label_batch"], "images": step["orig_batch"]}

    def test_epoch_end(self, outputs: List):
        """Log the mean of the per-batch test losses; for MNIST also log a latent-space embedding."""
        if not outputs:
            return
        tb = self.logger.experiment
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        self.log(name="test_epoch_end", value={"test_loss_avg": avg_loss})
        if self._config.dataset == "mnist":
            tb.add_embedding(
                mat=torch.cat([o["mus"] for o in outputs]),
                metadata=torch.cat([o["labels"] for o in outputs]).detach().cpu().numpy(),
                # add_embedding expects label_img as (N, C, H, W).
                label_img=torch.cat([o["images"] for o in outputs]).view(-1, 1, 28, 28),
                global_step=self.global_step,
            )

    def configure_optimizers(self):
        """Adam over all parameters with ``config.learning_rate``."""
        return optim.Adam(self.parameters(), lr=self._config.learning_rate)

    def forward(self, x):
        """Encode ``x``, sample (training) or take the mean (eval) of z, and decode.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self._encode(x)
        z = self._reparametrize(mu, logvar)
        return self._decode(z), mu, logvar
def train_vae(config, datamodule, latent_dim, dec_layer_sizes, enc_layer_sizes):
    """Build a ``VAEFC``, fit it on ``datamodule`` and return the fitted trainer.

    Args:
        config: experiment configuration; read here: ``is_deterministic`` and
            ``num_epochs`` (``VAEFC`` reads further fields).
        datamodule: LightningDataModule providing the train/val/test loaders.
        latent_dim: NOTE(review): not used by this function; the latent size is
            whatever ``enc_layer_sizes[-1]`` is. Kept for caller compatibility.
        dec_layer_sizes: forwarded to ``VAEFC`` as ``decoder_layer_sizes``.
        enc_layer_sizes: forwarded to ``VAEFC`` as ``encoder_layer_sizes``.

    Returns:
        The ``Trainer`` after ``fit``, so ``run_trained_vae`` can call ``trainer.test()``.
    """
    model = VAEFC(config=config, encoder_layer_sizes=enc_layer_sizes, decoder_layer_sizes=dec_layer_sizes)
    logger = TensorBoardLogger(save_dir=tb_logs_folder, name='VAEFC', default_hp_metric=False)
    logger.hparams = config
    checkpoint_callback = ModelCheckpoint(dirpath=vae_checkpoints_path)
    # One keyword per line: an inline "#" on a shared line would comment out
    # every argument after it (max_epochs, logger, callbacks, gpus, ...).
    trainer = Trainer(
        deterministic=config.is_deterministic,
        # auto_lr_find=config.auto_lr_find,
        # log_gpu_memory='all',
        # min_epochs=99999,
        max_epochs=config.num_epochs,
        default_root_dir=vae_checkpoints_path,
        logger=logger,
        callbacks=[checkpoint_callback],
        gpus=1,
    )
    trainer.fit(model, datamodule=datamodule)
    best_model_path = checkpoint_callback.best_model_path
    print("done training vae with lightning")
    print(f"best model path = {best_model_path}")
    return trainer
def run_trained_vae(trainer):
    """Run the test loop of an already-fitted ``trainer`` and return ``None``.

    Guide: https://pytorch-lightning.readthedocs.io/en/latest/test_set.html
    """
    # Alternatives described in that guide:
    #   trainer.test(ckpt_path=None)                            -> latest in-memory weights
    #   trainer.test(ckpt_path='/path/to/my_checkpoint.ckpt')   -> a specific checkpoint
    #   trainer.test(model)                                     -> explicit model, no checkpoint
    # The no-argument form below loads the best checkpoint tracked by Lightning.
    trainer.test()
    return None
参数
无论我(手动)尝试哪种参数组合,得到的结果都非常相似。也许还有什么我没尝试到的。
num_epochs: 40
batch_size: 32
learning_rate: 0.0001
auto_lr_find: False
noise_factor: 0.1
regularization_factor: 0.0
train_size: 0.8
val_size: 0.1
num_workers: 1
dataset: "case" # toy, mnist, case
mnist:
enc_layer_sizes: [784,512,]
dec_layer_sizes: [512,784]
latent_dim: 25
n_classes: 10
classifier_layers: [20,10]
toy:
enc_layer_sizes: [2,200,200]
dec_layer_sizes: [200,2]
latent_dim: 8
centers_radius: 4.0
n_clusters: 10
cluster_size: 5000
case:
#enc_layer_sizes: [ 1800,600,300,100 ]
#dec_layer_sizes: [ 100,1800 ]
#frames_per_pd_sample: 600
enc_layer_sizes: [ 10,300 ]
dec_layer_sizes: [ 600,10 ]
frames_per_pd_sample: 10
latent_dim: 300
n_classes: 10
classifier_layers: [ 20,10 ] # unused right now.
feature_list:
#- V_0_0 # 0,X
#- V_0_1 # 0,Y
#- V_0_2 # 0,Z
- pads_0
enc_kernel_sizes: [] # for conv
end_strides: []
dec_kernel_sizes: []
dec_strides: []
is_deterministic: False
real_data_pd_dir: "D:/pressure_pd"
case_dir: "real_case_20_min"
case_file: "pressure_data_0.pkl"
数据
对于 Mnist,一切正常。
换成我的具体数据时,结果如上。
数据是一个具有多个特征的时间序列。为简单起见,我只使用其中一个特征,将其切成等长的块,再把每个块作为一个向量送入输入层。
数据是时间序列这一点将来建模时也许能派上用场,但目前我只想把它当作一个个独立的数据块来处理,我相信我现在就是这样做的。
代码:
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torch
from Testing.Research.config.ConfigProvider import ConfigProvider
import os
import pickle
import pandas as pd
from typing import Tuple
import numpy as np
class CaseDataset(Dataset):
    """Map-style dataset of fixed-length windows cut from a pickled case recording.

    The pickle at ``path`` holds a dict with ``"full"``, ``"subsampled"`` and
    ``"misc"`` DataFrames. The columns listed in ``config.case.feature_list`` are
    taken from ``"full"`` and min-max normalized to [-1, 1] per column. They are
    then served as consecutive, non-overlapping windows of
    ``config.case.frames_per_pd_sample`` rows; a trailing partial window is dropped.

    Each item is ``(window, 0.0)``: ``window`` is a NumPy array of shape
    ``(frames_per_pd_sample, n_features)``; the label is a constant placeholder.
    """

    def __init__(self, path):
        super(CaseDataset, self).__init__()
        self._path = path
        self._config = ConfigProvider.get_config()
        self.frames_per_pd_sample = self._config.case.frames_per_pd_sample
        self._load_case_from_pkl()
        # Integer division discards the last, incomplete window.
        self.__len = len(self._full) // self.frames_per_pd_sample

    @staticmethod
    def _normalize_to_unit_range(column: pd.Series) -> pd.Series:
        """Linearly map ``column`` onto [-1, 1]; a constant column maps to all zeros."""
        col_min = column.min()
        span = column.max() - col_min
        if span == 0:
            # Avoid 0/0 -> NaN for a flat signal.
            return pd.Series(0.0, index=column.index, name=column.name)
        return ((column - col_min) / span) * 2 - 1

    def _load_case_from_pkl(self):
        """Load the case pickle and build the normalized feature table.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
        """
        if not os.path.isfile(self._path):
            raise FileNotFoundError(f"case file not found: {self._path}")
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(self._path, "rb") as f:
            p = pickle.load(f)
        self._full: pd.DataFrame = p["full"]
        self._subsampled: pd.DataFrame = p["subsampled"]
        self._misc: pd.DataFrame = p["misc"]
        feature_list = self._config.case.feature_list
        # NOTE(review): min/max come from the whole recording, i.e. before any
        # train/val/test split, and a few extreme values compress the rest of the
        # signal into a narrow band (which keeps an MSE loss small).
        self._features_df = self._full[feature_list].apply(self._normalize_to_unit_range)

    def __len__(self):
        """Number of complete windows."""
        return self.__len

    def __getitem__(self, index: int) -> Tuple[np.ndarray, float]:
        """Return window ``index`` and a constant ``0.0`` label.

        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is out of range. Plain iteration relies on
                this to stop.
        """
        n_windows = len(self)
        if index < 0:
            index += n_windows
        if not 0 <= index < n_windows:
            raise IndexError(f"index {index} out of range for CaseDataset of length {n_windows}")
        start = index * self.frames_per_pd_sample
        data_item = self._features_df.iloc[start:start + self.frames_per_pd_sample, :].values
        return data_item, 0.0
每批次的时间步数似乎不影响收敛。
训练/验证/测试集划分
划分方式如下:
import os
from pytorch_lightning import LightningDataModule
import torchvision.datasets as datasets
from torchvision.transforms import transforms
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from Testing.Research.config.paths import mnist_data_download_folder
from Testing.Research.datasets.real_cases.CaseDataset import CaseDataset
from typing import Optional
class CaseDataModule(LightningDataModule):
    """LightningDataModule that randomly splits a ``CaseDataset`` into train/val/test subsets.

    The split fractions come from ``config.train_size`` and ``config.val_size``;
    the test subset gets the remainder. All loaders use ``config.batch_size`` and
    ``config.num_workers``.
    """

    def __init__(self, config, path):
        super().__init__()
        self._config = config
        self._path = path
        self._full_dataset: Optional[CaseDataset] = None
        self._train_dataset: Optional[Subset] = None
        self._val_dataset: Optional[Subset] = None
        self._test_dataset: Optional[Subset] = None

    def prepare_data(self):
        # Nothing to download; the case file is read directly in setup().
        pass

    def setup(self, stage=None):
        """Load the dataset and split it into train/val/test subsets."""
        full_dataset = CaseDataset(self._path)
        n_total = len(full_dataset)
        train_size = int(self._config.train_size * n_total)
        val_size = int(self._config.val_size * n_total)
        test_size = n_total - train_size - val_size
        # NOTE(review): windows are assigned at random, so neighbouring windows of the
        # same time series can end up in different splits.
        train, val, test = torch.utils.data.random_split(full_dataset, [train_size, val_size, test_size])
        self._full_dataset = full_dataset
        self._train_dataset = train
        self._val_dataset = val
        self._test_dataset = test

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the configured batch size and worker count."""
        # batch_size must be passed explicitly: DataLoader defaults to 1, which made the
        # val/test losses per-sample and starved the model's >2-sample plotting cache.
        return DataLoader(
            dataset,
            batch_size=self._config.batch_size,
            shuffle=shuffle,
            num_workers=self._config.num_workers,
        )

    def train_dataloader(self):
        # random_split fixes one permutation; shuffling re-mixes batches every epoch.
        return self._make_loader(self._train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self._val_dataset)

    def test_dataloader(self):
        return self._make_loader(self._test_dataset)
问题
- 我认为验证损失始终低于训练损失说明这里有严重的问题,但我无法确定问题出在哪里,也想不出该如何验证这一点。
- 我怎样才能让模型正确地自动编码数据?基本上,我希望它学习恒等函数,并让损失反映这一点。
- 损失似乎并不反映重建。我认为这可能是最根本的问题
我的想法
- 尝试用卷积网络代替 FC?也许它能够更好地学习特征?
- 除此之外就没有思路了 :(
如有缺少的信息,我会补充提供。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。