微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 keras 中使用 train_on_batch 自定义 Loss fnc 以进行回放学习

如何解决在 keras 中使用 train_on_batch 自定义 Loss fnc 以进行回放学习

学习者,

我想用小批量 using a custom loss function 训练 NN。每个小批量包含 n new samples and m replay samples。重放样本用于重放,避免遗忘。

我的损失 fnc 看起来像:

loss=mse(new_samples_truth,new_samples_pred) + factor*mse(replay_samples_truth,replay_sampls_pred)

如您所见,损失是为新样本和重放样本分别计算的两个 mse 的加权和。 这意味着每当我想训练一个批次时,我都想将新数据点和重放数据点分开,并计算整个批次的标量损失。

如何在 Keras 中实现此损失函数并将其与 train_on_batch 一起使用? Keras 的 train_on_batch 方法似乎使用损失函数分别计算小批量中每个数据点的损失。由于我的批次包含新的和重放数据点,​​这将不起作用。 So how can I Keras make calculate the Loss for the entire batch at once and return only once scalar? 此外,似乎 Keras 会分别评估每个数据点的损失 fnc,并将每个样本的损失保存在一个数组中。但是,我想获得整个批次的损失。有人了解 Keras 如何实际处理批次的损失计算吗?

这是我的伪代码

    batch=pd.concat([new_samples,replay_samples]) #new_samples and replay_samples are pd.dataframes
    #len(batch) = 20

   def my_replay_loss(factor):
       def loss(y_true,y_pred): #y_true and y_pred come from keras 

           y_true_new_samples = y_true.head(10)
           y_pred_new_samples = y_pred.head(10)
           y_true_replay_samples = y_true.tail(10)
           y_pred_replay_samples = y_pred.tail(10)
           
           calc_loss = mse(y_true_new_samples,y_pred_new_samples) + factor*mse(y_true_replay_samples,y_pred_replay_samples)
           return calc_loss

        return loss

'''

解决方法

您可以像以前一样定义自定义损失函数,但您需要使用 tf 操作构建一个 tf 图,这里有一个示例:

def my_replay_loss(factor):
    def loss(y_true,y_pred):
        calc_loss = tf.math.add(tf.keras.losses.mse(y_true[:10,:],y_pred[:10,:]),tf.math.multiply(factor,tf.keras.losses.mse(y_true[-10:,y_pred[-10:,:])))
        return calc_loss
    return loss

然后你可以编译你的模型:

loss=my_replay_loss(factor=tf.constant(0.5,dtype=tf.float32))

但是,在您的情况下,我建议不要根据批次中数据点的顺序构建损失函数,但我建议构建具有 2 个输入和 2 个输出的模型 (最终共享模型架构)并用 2 个损失编译它们。 Keras 模型支持 loss_weights 参数,允许您根据自己的喜好对两种损失进行权衡。

loss_weights: 指定标量系数(Python 浮点数)的可选列表或字典,以对不同模型输出的损失贡献进行加权。模型将最小化的损失值将是所有单个损失的加权总和,由 loss_weights 系数加权。如果是列表,则预计与模型输出的映射为 1:1。如果是字典,则需要将输出名称(字符串)映射到标量系数。

这里有一个简单的例子:

import tensorflow as tf

def get_compiled_model(input_shape):
    
    input1 = tf.keras.Input(input_shape)
    input2 = tf.keras.Input(input_shape)

    conv_layer = tf.keras.layers.Conv2D(128,3,activation='relu')
    flatten_layer = tf.keras.layers.Flatten()
    dense_layer = tf.keras.layers.Dense(1,activation='softmax')

    x1 = conv_layer(input1)
    x1 = flatten_layer(x1)
    output1 = dense_layer(x1)

    x2 = conv_layer(input2)
    x2 = flatten_layer(x2)
    output2 = dense_layer(x2)
    
    model = tf.keras.Model(inputs=[input1,input2],outputs=[output1,output2])
    
    model.compile(optimizer=tf.keras.optimizers.Adam(),loss=[tf.keras.losses.mse,tf.keras.losses.mse],loss_weights=[1,0.5])
    return model

model = get_compiled_model(input_shape=x_train[0].shape)
model.fit([x_train1,x_train2],[y_train1,y_train2],epochs=5,batch_size=10)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?