如何解决如何使TFMA / Beam可以使用自定义指标?
我创建了一个自定义Keras指标,类似于下面的演示实现:
import tensorflow as tf
class MyMetric(tf.keras.metrics.Mean):
def __init__(self,name='my_metric',dtype=None):
super(MyMetric,self).__init__(name=name,dtype=dtype)
def update_state(self,y_true,y_pred,sample_weight=None):
return super(MyMetric,self).update_state(
y_pred,sample_weight=sample_weight)
我已将实现转换为带有init / main文件的Python模块,并将路径添加到系统的PYTHONPATH
。
训练Keras模型时可以使用指标。
不幸的是,我还没有找到一种使TensorFlow模型分析(TFMA)可以使用自定义指标的方法。
在我的交互式上下文笔记本中,我可以在创建eval_config
时加载指标。
import tensorflow as tf
import tensorflow_model_analysis as tfma
from mymetric.metric import MyMetric
metrics = [MyMetric()]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
eval_config = tfma.EvalConfig(
model_specs=[tfma.ModelSpec(label_key='label_xf')],metrics_specs=metrics_specs,slicing_specs=[tfma.SlicingSpec()]
)
evaluator = Evaluator(
examples=example_gen.outputs['examples'],model=trainer.outputs['model'],baseline_model=model_resolver.outputs['model'],eval_config=eval_config)
当我尝试执行evaluator
时,该度量标准如度量标准规范中列出
metrics_specs {
metrics {
class_name: "MyMetric"
config: "{\"dtype\": \"float32\",\"name\": \"my_metric\"}"
threshold {
}
}
}
但是执行失败并显示错误
ValueError: Unknown metric function: MyMetric
由于度量标准计算是通过Apache Beam的executor.Do
函数执行的,因此我假设Beam找不到模块(即使它在PYTHONPATH上)。如果是这种情况,我该如何在PYTHONPATH配置之外使该模块可用于Apache Beam?
跟踪:
/usr/local/lib/python3.6/dist-packages/tensorflow_model_analysis/metrics/metric_specs.py in _deserialize_tf_metric(metric_config,custom_objects)
741 cls_name,cfg = _tf_class_and_config(metric_config)
742 with tf.keras.utils.custom_object_scope(custom_objects):
--> 743 return tf.keras.metrics.deserialize({'class_name': cls_name,'config': cfg})
744
745
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py in deserialize(config,custom_objects)
3441 module_objects=globals(),3442 custom_objects=custom_objects,-> 3443 printable_module_name='metric function')
3444
3445
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier,module_objects,custom_objects,printable_module_name)
345 config = identifier
346 (cls,cls_config) = class_and_config_for_serialized_keras_object(
--> 347 config,printable_module_name)
348
349 if hasattr(cls,'from_config'):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config,printable_module_name)
294 cls = get_registered_object(class_name,module_objects)
295 if cls is None:
--> 296 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
297
298 cls_config = config['config']
ValueError: Unknown metric function: MyMetric
解决方法
您需要指定模块,以便TFX知道在哪里可以找到MyMetric类。一种方法是将其指定为度量标准的一部分:
BEGIN TRY
CREATE DATABASE Testing;
END TRY
BEGIN CATCH
IF ERROR_NUMBER() = 1801
PRINT N'Database already exists.';
ELSE
THROW;
END CATCH
from tensorflow_model_analysis import config
metric_config = [config.MetricConfig(class_name='MyMetric',module='mymodule.mymetric')]
您还需要创建一个名为metrics_specs = [config.MetricsSpec(metrics=metric_config)]
的模块,并将您的mymodule
类放入MyMetric
中,以使其正常工作。还要确保从执行代码的位置可以访问该模块(如果已将其添加到PYTHONPATH中,则应为这种情况)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。