微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

python – 使用自定义Estimator的Tensorflow指标

我有一个卷积神经网络,我最近重构使用Tensorflow的Estimator API,主要是在this tutorial之后.但是,在训练期间,我添加到EstimatorSpec的指标没有显示在Tensorboard上,并且似乎没有在tfdbg中进行评估,尽管写入Tensorboard的图表中显示名称范围和指标.

model_fn的相关位如下:

 ...

 predictions = tf.placeholder(tf.float32, [num_classes], name="predictions")

 ...

 with tf.name_scope("metrics"):
    predictions_rounded = tf.round(predictions)
    accuracy = tf.metrics.accuracy(input_y, predictions_rounded, name='accuracy')
    precision = tf.metrics.precision(input_y, predictions_rounded, name='precision')
    recall = tf.metrics.recall(input_y, predictions_rounded, name='recall')

if mode == tf.estimator.ModeKeys.PREDICT:
    spec = tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions)
elif mode == tf.estimator.ModeKeys.TRAIN:

    ...

    # if we're doing softmax vs sigmoid, we have different metrics
    if cross_entropy == CrossEntropyType.softmax:
        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall
        }
    elif cross_entropy == CrossEntropyType.SIGMOID:
        metrics = {
            'precision': precision,
            'recall': recall
        }
    else:
        raise NotImplementedError("Unrecognized cross entropy function: {}\t Available types are: softmax, SIGMOID".format(cross_entropy))
    spec = tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics)
else:
    raise NotImplementedError('ModeKey provided is not supported: {}'.format(mode))

return spec

任何人都有任何想法为什么这些没有写?我正在使用Tensorflow 1.7和Python 3.5.我试过通过tf.summary.scalar明确地添加它们,虽然它们确实以这种方式进入Tensorboard,但是在第一次通过图形之后它们永远不会更新.

解决方法:

metrics API有一个扭曲,让我们以tf.metrics.accuracy为例(所有tf.metrics.*工作相同).这将返回2个值,精度指标和upate_op,这看起来像是您的第一个错误.你应该有这样的东西:

accuracy, update_op = tf.metrics.accuracy(input_y, predictions_rounded, name='accuracy')

精确度只是您期望计算的值,但请注意,您可能希望在多次调用sess.run时计算准确性,例如,当您计算不完全适合的大型测试集的准确性时记忆.这就是update_op的用武之地,它会产生结果,因此当你要求准确性时,它会为你提供一个运行记录.

update_op没有依赖项,因此您需要在sess.run中显式运行它或添加依赖项.例如,您可以将其设置为依赖于成本函数,以便在计算成本函数时计算update_op(导致运行计数以更新准确性):

with tf.control_dependencies(cost):
  tf.group(update_op, other_update_ops, ...)

您可以使用局部变量初始值设定项重置度量标准的值:

sess.run(tf.local_variables_initializer())

您需要使用tf.summary.scalar(精度)向tensorboard添加精度,如您所提到的那样(尽管看起来您添加错误的东西).

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐