微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 tf.cond() 检查自定义层的调用方法中的条件

如何解决使用 tf.cond() 检查自定义层的调用方法中的条件

我正在 tensorflow 2.x 中实现一个自定义层。我的要求是,程序应该在返回输出之前检查条件。

class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self,M1,M2,fi=tf.nn.tanh,disp_name=True):
        super(SimpleRNN_cell,self).__init__()
        pass        
    def call(self,X,hidden_state,return_state=True):
        y = tf.constant(5)
        if return_state == True:
            return y,self.h
        else:
            return y

我的问题是:我应该继续使用当前代码(假设 tape.gradient(Loss,self.trainable_weights) 可以正常工作)还是应该使用 tf.cond()。 此外,如果可能,请说明在何处使用 tf.cond() 以及在何处不使用。我没有找到太多关于这个主题内容

解决方法

tf.cond 仅在基于可微计算图中的数据执行条件评估时才相关。 (https://www.tensorflow.org/api_docs/python/tf/cond) 这在默认图形模式的 TF 1.0 中尤为必要。对于急切模式,GradientTape 系统还允许使用诸如 if ...: (https://www.tensorflow.org/guide/autodiff#control_flow)

之类的 Python 构造进行条件数据流

然而,仅仅根据配置参数提供不同的行为,不依赖于计算图的数据并且在模型运行时固定,使用简单的 python if 语句是正确的。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?