如何解决使用 tf.cond() 检查自定义层的调用方法中的条件
我正在 tensorflow 2.x 中实现一个自定义层。我的要求是,程序应该在返回输出之前检查条件。
class SimpleRNN_cell(tf.keras.layers.Layer):
def __init__(self,M1,M2,fi=tf.nn.tanh,disp_name=True):
super(SimpleRNN_cell,self).__init__()
pass
def call(self,X,hidden_state,return_state=True):
y = tf.constant(5)
if return_state == True:
return y,self.h
else:
return y
我的问题是:我应该继续使用当前代码(假设 tape.gradient(Loss,self.trainable_weights)
可以正常工作)还是应该使用 tf.cond()
。
此外,如果可能,请说明在何处使用 tf.cond()
以及在何处不使用。我没有找到太多关于这个主题的内容。
解决方法
tf.cond
仅在基于可微计算图中的数据执行条件评估时才相关。 (https://www.tensorflow.org/api_docs/python/tf/cond) 这在默认图形模式的 TF 1.0 中尤为必要。对于急切模式,GradientTape
系统还允许使用诸如 if ...:
(https://www.tensorflow.org/guide/autodiff#control_flow)
然而,仅仅根据配置参数提供不同的行为,不依赖于计算图的数据并且在模型运行时固定,使用简单的 python if
语句是正确的。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。