微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

python-tf.constant和tf.placeholder的行为不同

我想将tf.metrics封装在Sonnet模块上以测量每个批次的性能,以下是我所做的工作:

import tensorflow as tf
import sonnet as snt

class Metrics(snt.AbstractModule):
    def __init__(self, indicator, summaries = None, name = "metrics"):
        super(Metrics, self).__init__(name = name)
        self._indicator = indicator
        self._summaries = summaries

    def _build(self, labels, logits):
        if self._indicator == "accuracy":
            metric, metric_update = tf.metrics.accuracy(labels, logits)
            with tf.control_dependencies([metric_update]):
                outputs = tf.identity(metric)
        elif self._indicator == "precision":
            metric, metric_update = tf.metrics.precision(labels, logits)
            with tf.control_dependencies([metric_update]):
                outputs = tf.identity(metric)
        elif self._indicator == "recall":
            metric, metric_update = tf.metrics.recall(labels, logits)
            with tf.control_dependencies([metric_update]):
                outputs = tf.identity(metric)
        elif self._indicator == "f1_score":
            metric_recall, metric_update_recall = tf.metrics.recall(labels, logits)
            metric_precision, metric_update_precision = tf.metrics.precision(labels, logits)
            with tf.control_dependencies([metric_update_recall, metric_update_precision]):
                outputs = 2.0 / (1.0 / metric_recall + 1.0 / metric_precision)
        else:
            raise ValueError("unsupported metrics")

        if type(self._summaries) == list:
            self._summaries.append(tf.summary.scalar(self._indicator, outputs))

        return outputs

但是,当我想测试模块时,以下代码可以工作:

def test3():
    import numpy as np

    labels = tf.constant([1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], tf.int32)
    logits = tf.constant([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], tf.int32)

    metrics = Metrics("accuracy")
    accuracy = metrics(labels, logits)

    metrics2 = Metrics("f1_score")
    f1_score = metrics2(labels, logits)

    writer = tf.summary.FileWriter("utils-const", tf.get_default_graph())
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

        accu, f1 = sess.run([accuracy, f1_score])
        print(accu)
        print(f1)

    writer.close()

但是,以下代码不起作用:

def test4():
    from tensorflow.python import debug as tf_debug
    import numpy as np

    tf_labels = tf.placeholder(dtype=tf.int32, shape=[None])
    tf_logits = tf.placeholder(dtype=tf.int32, shape=[None])

    labels = np.array([1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], np.int32)
    logits = np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], np.int32)

    metrics = Metrics("accuracy")
    accuracy = metrics(tf_labels, tf_logits)

    metrics2 = Metrics("f1_score")
    f1_score = metrics2(tf_labels, tf_logits)

    writer = tf.summary.FileWriter("utils-Feed", tf.get_default_graph())
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

        accu, f1 = sess.run([accuracy, f1_score], Feed_dict = {tf_labels: labels, tf_logits: logits})
        print(accu)
        print(f1)

    writer.close()

test3()的输出正确,为0.88. test4()的输出错误,为0.0.但是,它们应该等效.

有人知道吗?

解决方法:

您确定不是tf.constant版本失败吗?我发现tf.metrics与tf.constant结合在一起具有怪异的行为:

import tensorflow as tf

a = tf.constant(1.)
mean_a, mean_a_uop = tf.metrics.mean(a)
with tf.control_dependencies([mean_a_uop]):
  mean_a = tf.identity(mean_a)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.local_variables_initializer().run()

for _ in range(10):
  print(sess.run(mean_a))

在GPU上运行时返回

0.0
2.0
1.5
1.3333334
1.25
1.2
1.1666666
1.1428572
1.125
1.1111112

而不是1s.似乎计数落后一. (我假设第一个值将是inf,但由于计数条件的限制,该值为零).另一方面,此代码的占位符版本正在按预期运行.

cpu上,行为甚至更奇怪,因为输出是不确定的.输出示例:

0.0
1.0
1.0
0.75
1.0
1.0
0.85714287
0.875
1.0
0.9

看起来像是您可以登录tensorflow’s github repo错误.(请注意,在常量上使用运行指标没有多用-但这仍然是错误).

编辑现在,我还偶然发现了一个带有tf.placeholder的怪异示例,看来tf.metrics有一个错误,不幸的是,该错误不仅限于与tf.constants一起使用.

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐