微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

tensorflow 2.x 代码中的条件报错

如何解决tensorflow 2.x 代码中的条件报错

迁移到 tensorflow 2.x. Win 10,tf 版本为 2.3.1。基本上,

import tensorflow as tf

def do_nothing(x,y):
    m,n = x.shape
    if m==n:
        return x,y
    else:
        raise Exception('should never arrive here')

xys = [[tf.eye(2),tf.eye(3)],[tf.eye(4),tf.eye(5)],]

@tf.function
def foo():
    return [do_nothing(x,y) for (x,y) in xys]
    
ans = foo()

有效。然后我只是将条件 m==m 更改为 tf.equal(m,n) as

import tensorflow as tf

def do_nothing(x,n = x.shape
    if tf.equal(m,n):
        return x,y) in xys]
    
ans = foo()

编码器不再工作。真是纳闷了。是错误还是什么?

我尝试了额外的实验来用更少的代码重现这个问题。看起来,如果您使用 tf.equaltf.greater 之类的东西,那么 ifelse 子句必须返回相同类型和大小的张量。请参阅下面的代码

import tensorflow as tf

#this piece works
@tf.function  
def foo1(x):
    if tf.greater(len(x),0):
        return True
    else:
        return False
print(foo1(tf.zeros([1])))
print(foo1(tf.zeros([0])))

#this piece works too
@tf.function
def foo2(x):
    if len(x)>0: 
        return True
    else:
        raise Exception()
print(foo2(tf.zeros([1])))

#this piece no long works
@tf.function
def foo3(x):
    if tf.greater(len(x),0):
        return True
    else:
        raise Exception()
print(foo3(tf.zeros([1])))

解决方法

我认为原因是因为 tf 返回一个 bool 类型的 Tensor,而不是一个简单的 Bool。 http://tensorflow.biotecan.com/python/Python_1.8/tensorflow.google.cn/api_docs/python/tf/equal.html

参考我在 Google colab 中所做的测试:

https://colab.research.google.com/drive/1sR99ScE-IDsWz0rNCH6VsWVclw1wz5oE#scrollTo=FsMqxbpnJ-Xg&line=1&uniqifier=1

import tensorflow as tf

def do_nothing(x,y):
    m,n = x.shape
    print(x)
    print(y)
    print(m,n)
    print(m==m)
    print(n==n)
    print(m==n)
    print(tf.equal(m,m))
    print(tf.equal(n,n))
    print(tf.equal(m,n))
    if tf.equal(m,n):
        return x,y
    else:
        raise Exception('should never arrive here')

xys = [[tf.eye(2),tf.eye(3)],[tf.eye(4),tf.eye(5)],]

@tf.function
def foo():
    return [do_nothing(x,y) for (x,y) in xys]
    
ans = foo()



tf.Tensor(
[[1. 0.]
 [0. 1.]],shape=(2,2),dtype=float32)
tf.Tensor(
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]],shape=(3,3),dtype=float32)
2 2
True
True
True
Tensor("Equal:0",shape=(),dtype=bool)
Tensor("Equal_1:0",dtype=bool)
Tensor("Equal_2:0",dtype=bool)

---------------------------------------------------------------------------

Exception                                 Traceback (most recent call last)

<ipython-input-12-24121e0806b4> in <module>()
     24     return [do_nothing(x,y) in xys]
     25 
---> 26 ans = foo()

8 frames

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args,**kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e,"ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?