如何解决变分自编码器重建损失太高,精度太低
我是 Python 新手,并且有一个 VAE 项目。 输入图像为 128 * 128 * 3。 训练过程很慢,最大的问题是损失值太高,但准确率太低。这些值很荒谬,我认为问题可能出在 class VAE(tf.keras.Model) 中。所以我把这段代码放在这里: '''
类 VAE(tf.keras.Model):
def __init__(self,latent_dim):
super(VAE,self).__init__()
self.latent_dim = latent_dim
# Model
self.encoder = create_encoder(latent_dim)
self.decoder = create_decoder(latent_dim)
# Metrics
self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
self.accuracy_tracker = tf.keras.metrics.Accuracy(name="accuracy")
def train_step(self,data):
# Use GradientTape to record everything we need to compute the gradient
with tf.GradientTape() as tape:
# Data: input = output
X,t = data
# Encoder
z_mean,z_log_var,z = self.encoder(X)
# Decoder
y = self.decoder(z)
reconstruction_loss = tf.reduce_mean(
tf.reduce_sum(
tf.keras.losses.binary_crossentropy(X,y),axis=1
)
)
kl_loss = tf.reduce_mean(
tf.reduce_sum(
-0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),axis = 1
)
)
total_loss = reconstruction_loss + kl_loss
# Compute gradients
grads = tape.gradient(total_loss,self.trainable_weights)
# Apply gradients using the optimizer
self.optimizer.apply_gradients(zip(grads,self.trainable_weights))
# Update metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
true_image = tf.reshape(tf.argmax(X,axis=1),shape=(-1,1))
predicted_image = tf.reshape(tf.argmax(y,1))
self.accuracy_tracker.update_state(true_image,predicted_image)
# Return a dic mapping matric names to current value
return {
"loss": self.total_loss_tracker.result(),"recontruction_loss": self.reconstruction_loss_tracker.result(),"kl_loss": self.kl_loss_tracker.result(),"accuracy": self.accuracy_tracker.result()
}
@property
def metrics(self):
return [
self.total_loss_tracker,self.reconstruction_loss_tracker,self.kl_loss_tracker,self.accuracy_tracker
]
'''
这个截图是编码器和解码器的总结。每个 epoch 的输出重建损失为 88 轮,kl 损失超过 5,精度小于 0.03。这是不正常的,尝试了几种方法来摆脱这种情况,但没有成功。我希望我能从这里得到帮助。非常感谢你。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。