How can I fix very blurry image reconstructions from a VAE?
I am very new to machine learning and built a VAE from the Keras VAE code example, changing only a few layers in the model. I trained it on the Kaggle cats-vs-dogs dataset and then tried to reconstruct some images. All of the reconstructions look like these: Reconstructed Images. What could be the cause? Is the model itself bad, was the training too short, or is there a mistake in how I reconstruct the images?
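One thing worth noting up front about the code below: with `latent_dim = 2`, every 328x328x3 image must pass through a bottleneck of just two numbers. A quick back-of-the-envelope check (plain Python, using only the shapes that appear in the code) shows how extreme that compression is:

```python
# Values per input image vs. values in the latent code.
input_values = 328 * 328 * 3   # height * width * channels
latent_dim = 2                 # as set in the encoder below

print(input_values)                # 322752 values per image
print(input_values // latent_dim)  # compression factor: 161376
```

A bottleneck this tight can only retain very coarse image structure, which by itself tends to produce blurry reconstructions regardless of training time.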
Encoder model:
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 2

encoder_inputs = keras.Input(shape=(328, 328, 3))
x = layers.Conv2D(32, 3, strides=2, padding="same")(encoder_inputs)  # 328 -> 164
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, strides=2, padding="same")(x)               # 164 -> 82
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x)              # 82 -> 41, new layer
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])  # reparameterization layer from the Keras VAE example
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
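The code above relies on the `Sampling` layer from the Keras VAE example. For completeness, a sketch of that layer (the standard reparameterization trick, reproduced from the example; compare against the original if your copy differs):

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z via the reparameterization trick."""
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # z = mu + sigma * epsilon, with epsilon ~ N(0, I)
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
```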
Decoder model:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(41 * 41 * 128, activation="relu")(latent_inputs)
x = layers.Reshape((41, 41, 128))(x)
x = layers.Conv2DTranspose(128, 3, strides=2, activation="relu", padding="same")(x)  # 41 -> 82
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(64, 3, strides=2, activation="relu", padding="same")(x)   # 82 -> 164
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(32, 3, strides=2, activation="relu", padding="same")(x)   # 164 -> 328
x = layers.BatchNormalization()(x)
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
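Similarly, `VAE(encoder, decoder)` in the training code refers to the custom model class from the Keras VAE example. A minimal sketch of it, assuming the standard loss from that example (binary cross-entropy summed over pixels plus the KL term) and extended with a small tuple check so it also accepts the `(batch, batch)` pairs that `fixed_generator` yields:

```python
import tensorflow as tf
from tensorflow import keras

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        if isinstance(data, tuple):  # fixed_generator yields (x, x); only x is needed
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Per-pixel BCE, summed over the image, averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2)))
            # KL divergence between q(z|x) and the unit Gaussian prior.
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                    axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {"loss": total_loss,
                "reconstruction_loss": reconstruction_loss,
                "kl_loss": kl_loss}
```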
Training:
train_data_dir = '/content/Petimages'
nb_train_samples = 200
nb_epoch = 50
batch_size = 32
img_width = 328
img_height = 328

def fixed_generator(generator):
    for batch in generator:
        yield (batch, batch)

train_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode=None)

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(
    fixed_generator(train_generator),
    steps_per_epoch=nb_train_samples // batch_size,  # steps_per_epoch counts batches, not samples
    epochs=nb_epoch)
And reconstructing the images:
import matplotlib.pyplot as plt

test2_datagen = ImageDataGenerator(rescale=1. / 255)
test2_generator = test2_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),  # without this the generator yields 256x256 images
    batch_size=10,
    class_mode=None)

sample_img = next(test2_generator)

# encoder.predict returns [z_mean, z_log_var, z]; decode the sampled z
z_mean, z_log_var, z_points = vae.encoder.predict(sample_img)
reconst_images = vae.decoder.predict(z_points)

fig = plt.figure(figsize=(10, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
n_to_show = 2

for i in range(n_to_show):
    img = sample_img[i].squeeze()
    sub = fig.add_subplot(2, n_to_show, i + 1)
    sub.axis('off')
    sub.imshow(img)

for i in range(n_to_show):
    img = reconst_images[i].squeeze()
    sub = fig.add_subplot(2, n_to_show, i + n_to_show + 1)
    sub.axis('off')
    sub.imshow(img)