如何解决tf.where的行为不符合预期的张量
我尝试了以下代码:
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19),(1-x)*tf.math.log(1 - b + 1e-19))
产生的结果与以下结果不同:
a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)
这里x是2或0或1的二进制变量。b是0到1之间的实数值。
我缺少什么吗?
我比较两个答案的方式是tf.reduce_sum(a)
找到的解决方案:对于x = 0或x = 1,2确实等效。我使用的数据是2D张量,其某些位不是0或1。
这是通过以下方式发现的
tf.unique(tf.reshape(x,(-1,))
解决方法
代码示例:
# When x = 0.0
x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19),(1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472
# When x = 1.0
x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x,(1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + 0 = -0.6931472
# When x = 0.4
x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x,(1-x)*tf.math.log(1 - b + 1e-19)) # -0.41588834 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725 + -0.41588 = -0.6931471824645996.
问题中提到的两个代码都将产生相同结果的唯一情况是x = 1或0时。
tf.reduce_sum(a)
在这里,a是一个标量,因此不会改变a的值。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。