微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

分布式训练时如何在 TensorFlow 中应用梯度裁剪?

如何解决分布式训练时如何在 TensorFlow 中应用梯度裁剪?

我想知道在分布式训练时如何在 TensorFlow 中应用梯度裁剪。 这是我的代码

    @lazy_property
    def optimize(self):
        # train_vars = ...
        optimizer = tf.train.AdamOptimizer(self._learning_rate)
        self.syn_op = tf.train.SyncReplicasOptimizer(optimizer,replicas_to_aggregate=self.gradient_merge,total_num_replicas=self.worker_count,use_locking=True)
        self.sync_replicas_hook = self.syn_op.make_session_run_hook(is_chief=self.is_chief)
        return self.syn_op.minimize(self.cost,var_list=train_vars,global_step=self.global_step)

我已阅读此答案:How to apply gradient clipping in TensorFlow。 下面是答案中渐变剪辑的代码

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        gvs = optimizer.compute_gradients(cost)
        capped_gvs = [(tf.clip_by_value(grad,-1.,1.),var) for grad,var in gvs]
        train_op = optimizer.apply_gradients(capped_gvs)

我应该在哪里更改以在我的情况下使用它?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?