微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow Keras自定义指标错误在update_state上

如何解决Tensorflow Keras自定义指标错误在update_state上

上下文

使用带有tf.Estimator界面的TF 1.15来训练和评估模型。尝试使用tf.keras.metric.Metric来编写自定义TF指标。

问题

我编写了一个自定义指标,并将其包含在eval_metrics_ops中(以下示例)。如果我使用指标定义估算器,则会出现以下错误

ValueError: Please call update_state(...) on the "<metric_name>" metric 

错误的措词看起来很清楚(我必须致电update_state()),但是我不确定在度量标准上应该在哪里致电update_state()(不确定我是否应该致电)。这不是一个最小的示例,但这是我编写的一个演示指标。

class MyMetric(tf.keras.metrics.Metric):
    def __init__(self,name="my_metric",**kwargs):
        super(MyMetric,self).__init__(name=name,**kwargs)

    def update_state(self,y_true,y_pred,sample_weight=None):
        self.true_samples = tf.reduce_sum(y_true)
    
    def result(self):
        return self.true_samples

创建一个dict,其中度量标准名称是键,而度量标准实例是值。 This提到了如何为dict创建eval_metrics_ops

metrics_ops = {"my_metric": MyMetric()}`. # The TensorFlow 1.15 documentation does not say we have to call `update_state(....) anywhere.`
estimator_spec = tf.estimator.EstimatorSpec(mode,model.loss,eval_metric_ops=metrics_ops)

有什么主意我该如何消除这个错误

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。