如何解决TensorFlow 2:检测某个节点属于哪个图
我刚开始(自学)学习 TensorFlow,我决定学习我在当地图书馆找到的书 "Learning TensorFlow"。 不幸的是,书中他们使用的是 TensorFlow 1.x,而我想使用 2.4 版本。
我在复制第3章的例子时遇到了一些麻烦。代码的重点是创建一个新的空计算图,创建一个节点(即在这种情况下是一个常量),然后确定该节点是否属于默认图形或新创建的图形。 这是书中的代码,它应该可以与 TensorFlow1 配合使用:
import tensorflow as tf
print(tf.get_default_graph())
g = tf.Graph() # This creates a new empty graph
a = tf.constant(5) # This creates a node
print(a.graph is g)
print(a.graph is tf.get_default_graph())
我确实意识到属性 get_default_graph() 在 TensorFlow 2 中不再可用,我用 tf.compat.v1.get_default_graph() 代替,但我仍然得到以下错误:
AttributeError: Tensor.graph 在启用 Eager Execution 时毫无意义。
任何帮助将不胜感激!提前致谢!
解决方法
导入 tensorflow 后需要禁用 Eager Execution,如下所示:
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
print(tf.compat.v1.get_default_graph())
g = tf.Graph() # This creates a new empty graph
a = tf.constant(5) # This creates a node
print(a.graph is g)
print(a.graph is tf.compat.v1.get_default_graph())
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。