如何解决使用自定义指标获取 keras 序列模型中 sample_weight 的当前纪元数
我正在使用 Keras、Tensorflow 2 顺序模型。我想计算一个 Revenue@k 作为我的指标。所以我写了一个自定义指标如下。我需要附加参数 _sample_weight 来给出产品的价格。
def revenue_1(_sample_weight):
def revenue(y_true,y_pred):
m = keras_metrics.Precision(top_k=1)
# select the batch size of the sample weight (price)
m.update_state(y_true,y_pred,sample_weight=_sample_weight)
return m.result().numpy()
return revenue
我调用这个函数如下:
self.model.compile(loss=self.loss,optimizer=self.optimizer_method,run_eagerly=True,metrics[custom_metrics.revenue_1(self.input_data['train_price'],self.model)])
问题是y_true和y_pred形状是batch的大小,而sample_weight是完整的数据大小。我不知道是否可以访问内部数据以使用当前批次对 sample_weight 进行切片。我该怎么做?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。