微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

tensorflow 2,梯度是一个空列表

如何解决tensorflow 2,梯度是一个空列表

我正在重新组织我的代码以使其易于阅读,但是在编译时它说: ValueError: No gradients provided for any variable: ['enc_conv_4/kernel:0'...。我知道我的损失函数是可微的,因为代码在接触它之前就可以工作,但现在缺少我的模型的梯度。

    @tf.function
    def train_disc(self,real_imgs,gen_imgs):
      with tf.GradientTape() as disc_tape:
        d_loss = self.wasserstein_loss(real_imgs,gen_imgs)
      gradients_d = disc_tape.gradient(d_loss,self.discriminator.trainable_variables)
      self.d_optimizer.apply_gradients(zip(gradients_d,self.discriminator.trainable_variables))
      return d_loss

    @tf.function
    def train_gen(self,real_img,gen_imgs,mask,img_feat,rot_feat_mean):
      with tf.GradientTape() as gen_tape:
        g_loss_param = self.generator_loss(mask,rot_feat_mean)
        g_loss = g_loss_param(real_img,gen_imgs)
      gradients_g = gen_tape.gradient(g_loss,self.generator.trainable_variables)
      print(gradients_g)
      self.g_optimizer.apply_gradients(zip(gradients_g,self.generator.trainable_variables))

如您所见,当我对判别器和生成器执行相同操作时,生成器给了我一个空的梯度列表。

gen_imgs = self.generator([real_img,mask],training=True)


d_loss = self.train_disc(real_img,gen_imgs[:,:,:-1])

if step%self.n_critic == 0:
  masked_images = real_img * mask
  idx = 3  # index of desired layer
  layer_input = Input(shape=(self.img_shape))  #
  x = layer_input
  for layer in self.generator.layers[idx:idx+12]:
      x = layer(x)
  model_feat = Model(inputs=layer_input,outputs=x)
  model_feat.trainable = False
  img_feat = model_feat(masked_images,training=False)
  rot_feat_mean = []
  for i in range(self.batch_size):
      rot = []
      for an in [180,155,130,105,80,55,20,10]:
          r = tf.keras.preprocessing.image.random_rotation(masked_images[i],an,row_axis=0,col_axis=1,channel_axis=2)
          rot.append(r)
      rot = np.array(rot)
      rot_feat_mean.append(np.mean(model_feat(rot,training=False),axis=0))
  rot_feat_mean = np.array(rot_feat_mean)
  g_loss = self.train_gen(real_img,:-1],rot_feat_mean)

最后一段代码的最后一行给了我一个错误。不知道这个错误是不是语义错误造成的。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?