
How to fix a variational autoencoder whose reconstruction loss is too high and accuracy too low

I am new to Python and working on a VAE project. The input images are 128 * 128 * 3. Training is very slow, and the biggest problem is that the loss is too high while the accuracy is too low. The values are absurd, and I think the problem may be in class VAE(tf.keras.Model), so I am posting that code here: '''

class VAE(tf.keras.Model):

def __init__(self,latent_dim):
  super(VAE,self).__init__()
  self.latent_dim = latent_dim

  # Model
  self.encoder = create_encoder(latent_dim)

  self.decoder = create_decoder(latent_dim)

  # Metrics 
  self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
  self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
  self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
  self.accuracy_tracker = tf.keras.metrics.Accuracy(name="accuracy")

def train_step(self,data):

  # Use GradientTape to record everything we need to compute the gradient
  with tf.GradientTape() as tape:

     # Data: input = output
     X,t = data

     # Encoder
     z_mean,z_log_var,z = self.encoder(X)

     # Decoder
     y = self.decoder(z)

     reconstruction_loss = tf.reduce_mean(
          tf.reduce_sum(
              tf.keras.losses.binary_crossentropy(X,y),axis=1 
          )
     )

     kl_loss = tf.reduce_mean(
         tf.reduce_sum(
             -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),axis = 1
         )
     )

     total_loss = reconstruction_loss + kl_loss

  # Compute gradients
  grads = tape.gradient(total_loss,self.trainable_weights)

  # Apply gradients using the optimizer
  self.optimizer.apply_gradients(zip(grads,self.trainable_weights))

  # Update metrics 
  self.total_loss_tracker.update_state(total_loss)
  self.reconstruction_loss_tracker.update_state(reconstruction_loss)
  self.kl_loss_tracker.update_state(kl_loss)
  true_image = tf.reshape(tf.argmax(X, axis=1), shape=(-1, 1))
  predicted_image = tf.reshape(tf.argmax(y, axis=1), shape=(-1, 1))
  self.accuracy_tracker.update_state(true_image, predicted_image)

  # Return a dict mapping metric names to current values
  return {
      "loss": self.total_loss_tracker.result(),
      "reconstruction_loss": self.reconstruction_loss_tracker.result(),
      "kl_loss": self.kl_loss_tracker.result(),
      "accuracy": self.accuracy_tracker.result(),
  }

@property
def metrics(self):
  return [
      self.total_loss_tracker,
      self.reconstruction_loss_tracker,
      self.kl_loss_tracker,
      self.accuracy_tracker,
  ]

'''
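While preparing this question I tried to sanity-check the loss scale with plain numpy (a sketch under my own assumptions, not the TensorFlow code above). As I understand it, `tf.keras.losses.binary_crossentropy` averages over the last (channel) axis, so on a (batch, 128, 128, 3) input it returns a (batch, 128, 128) tensor, and summing over `axis=1` only, as my `train_step` does, collapses just one of the two spatial axes:

```python
import numpy as np

# Per-pixel binary cross-entropy for a batch of 128x128x3 "images".
# Keras' binary_crossentropy averages over the last (channel) axis,
# so its per-sample output shape here is (batch, 128, 128).
rng = np.random.default_rng(0)
batch, h, w, c = 4, 128, 128, 3
x = rng.uniform(0.01, 0.99, size=(batch, h, w, c))
y = rng.uniform(0.01, 0.99, size=(batch, h, w, c))
bce = -(x * np.log(y) + (1 - x) * np.log(1 - y)).mean(axis=-1)

# Summing over axis=1 only leaves a (batch, 128) tensor, and the final
# mean then averages over the remaining spatial axis as well.
loss_axis1 = bce.sum(axis=1).mean()

# Summing over both spatial axes gives one scalar per image, the usual
# per-image reconstruction term -- exactly 128x larger here.
loss_full = bce.sum(axis=(1, 2)).mean()

print(loss_axis1, loss_full)
```

So depending on which axes are reduced, the "reconstruction loss" number changes by orders of magnitude, which may explain why my values look so strange.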

This screenshot shows the encoder and decoder summaries. The output of every epoch shows a reconstruction loss around 88, a KL loss above 5, and an accuracy below 0.03. This is not normal, and I have tried several approaches to get rid of it without success. I hope I can get some help here. Thank you very much.
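I also wonder whether `tf.keras.metrics.Accuracy` is meaningful here at all: it counts exact matches between integer labels, and the `argmax` indices I feed it come from continuous image tensors. As a possible replacement I sketched a per-pixel mean-squared-error metric in plain numpy (`reconstruction_mse` is my own hypothetical helper, not part of my model yet):

```python
import numpy as np

def reconstruction_mse(x, y):
    """Per-pixel squared error, averaged over pixels, channels and batch."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return ((x - y) ** 2).mean()

# Every reconstructed pixel is off by 0.5, so the metric is 0.25.
x = np.zeros((2, 4, 4, 3))
y = np.full((2, 4, 4, 3), 0.5)
print(reconstruction_mse(x, y))  # 0.25
```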

[screenshot: encoder-decoder summary]
