如何解决在 tf2.1.0-keras 中使用条件
我试图在我自己的 tf2.1.0-keras 模型中使用 bool 条件,下面是一个简单的例子:
import tensorflow as tf
class TestKeras:
def __init__(self):
pass
def build_graph(self):
x = tf.keras.Input(shape=(2),batch_size=1)
x_value = x[0,0]
y = tf.cond(x_value > 0,lambda :tf.add(x_value,0),0))
return tf.keras.models.Model(inputs=[x],outputs=[y])
if __name__ == "__main__":
tk = TestKeras()
model = tk.build_graph()
model.summary(line_length=100)
但它似乎不起作用并抛出异常:
using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
我已尝试将 tf.cond
替换为 tf.keras.backend.switch
,但仍然出现相同的错误。
我还尝试将代码 y = tf.cond(xxx)
拆分为单个函数并添加 @tf.funcion
装饰器:
@tf.function
def compute_y(self,x):
return tf.cond(x > 0,lambda :tf.add(x,0))
但它又出现了另一个错误:
Inputs to eager execution function cannot be Keras symbolic tensors,but found [<tf.Tensor 'strided_slice:0' shape=() dtype=float32>]
有谁知道条件如何在 tf2.1.0-keras 中工作?
解决方法
tf.keras.Input
是一个符号张量,用于定义 keras 模型的输入。每当您想在 keras 模型中应用自定义逻辑时,您应该继承 Layer
类,或者使用 Lambda
层。
例如,使用 Lambda
层:
class TestKeras:
def __init__(self):
pass
def build_graph(self):
x = tf.keras.Input(shape=(2),batch_size=1)
def custom_fct(x):
x_value = x[0,0]
return tf.cond(x_value > 0,lambda :tf.add(x_value,0),0))
y = tf.keras.layers.Lambda(custom_fct)(x)
return tf.keras.models.Model(inputs=[x],outputs=[y])
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。