如何解决Tensorflow Keras自定义指标错误在update_state上
上下文
使用带有tf.Estimator
界面的TF 1.15来训练和评估模型。尝试使用tf.keras.metric.Metric
来编写自定义TF指标。
问题
我编写了一个自定义指标,并将其包含在eval_metrics_ops
中(以下示例)。如果我使用指标定义估算器,则会出现以下错误。
ValueError: Please call update_state(...) on the "<metric_name>" metric
错误的措词看起来很清楚(我必须致电update_state()
),但是我不确定在度量标准上应该在哪里致电update_state()
(不确定我是否应该致电)。这不是一个最小的示例,但这是我编写的一个演示指标。
class MyMetric(tf.keras.metrics.Metric):
def __init__(self,name="my_metric",**kwargs):
super(MyMetric,self).__init__(name=name,**kwargs)
def update_state(self,y_true,y_pred,sample_weight=None):
self.true_samples = tf.reduce_sum(y_true)
def result(self):
return self.true_samples
创建一个dict
,其中度量标准名称是键,而度量标准实例是值。 This提到了如何为dict
创建eval_metrics_ops
。
metrics_ops = {"my_metric": MyMetric()}`. # The TensorFlow 1.15 documentation does not say we have to call `update_state(....) anywhere.`
estimator_spec = tf.estimator.EstimatorSpec(mode,model.loss,eval_metric_ops=metrics_ops)
有什么主意我该如何消除这个错误?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。