在多种情况下尽早停车

如何解决在多种情况下尽早停车

我正在为推荐器系统(项目推荐)进行多类分类,并且我目前正在使用sparse_categorical_crossentropy损失来训练我的网络。因此,通过监视我的验证损失EarlyStopping来执行val_loss是合理的:

tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=10)

可正常使用。但是,网络(推荐系统)的性能由“平均精度为10”来衡量,并在训练过程中作为指标average_precision_at_k10进行跟踪。因此,我还可以使用以下指标来提前停止:

tf.keras.callbacks.EarlyStopping(monitor='average_precision_at_k10',patience=10)

也可以按预期工作。

我的问题: 有时,验证损失会增加,而平均精度为10的情况会有所改善,反之亦然。因此,我需要监视两者,并尽早停止,当且仅当都恶化时。我想做什么:

tf.keras.callbacks.EarlyStopping(monitor=['val_loss','average_precision_at_k10'],patience=10)

显然不起作用。有什么想法可以做到吗?

解决方法

在上面Gerry P的指导下,我设法创建了自己的自定义EarlyStopping回调,并以为可以在其他人希望实现类似功能的情况下将其发布在此处。

如果两者验证损失 均 em的平均精度对于{{1 }}个时代,尽早停止。

patience

然后用作:

class CustomEarlyStopping(keras.callbacks.Callback):
    def __init__(self,patience=0):
        super(CustomEarlyStopping,self).__init__()
        self.patience = patience
        self.best_weights = None
        
    def on_train_begin(self,logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best_v_loss = np.Inf
        self.best_map10 = 0

    def on_epoch_end(self,epoch,logs=None): 
        v_loss=logs.get('val_loss')
        map10=logs.get('val_average_precision_at_k10')

        # If BOTH the validation loss AND map10 does not improve for 'patience' epochs,stop training early.
        if np.less(v_loss,self.best_v_loss) and np.greater(map10,self.best_map10):
            self.best_v_loss = v_loss
            self.best_map10 = map10
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)
                
    def on_train_end(self,logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
,

您可以通过创建自定义回调来实现。有关如何执行此操作的信息位于here.下面是一些代码,这些代码说明了您可以在自定义回调中执行的操作。我参考的文档显示了许多其他选项。

class LRA(keras.callbacks.Callback): # subclass the callback class
# create class variables as below. These can be accessed in your code outside the class definition as LRA.my_class_variable,LRA.best_weights
    my_class_variable=something  # a class variable
    best_weights=model.get_weights() # another  class variable
# define an initialization function with parameters you want to feed to the callback
    def __init__(self,param1,param2,etc):
        super(LRA,self).__init__()
        self.param1=param1
        self.param2=param2
        etc for all parameters
        # write any initialization code you need here

    def on_epoch_end(self,logs=None):  # method runs on the end of each epoch
        v_loss=logs.get('val_loss')  # example of getting log data at end of epoch the validation loss for this epoch
        acc=logs.get('accuracy') # another example of getting log data 
        LRA.best_weights=model.get_weights() # example of setting class variable value
        print(f'Hello epoch {epoch} has just ended') # print a message at the end of every epoch
        lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) # get the current learning rate
        if v_loss > self.param1:
           new_lr=lr * self.param2
           tf.keras.backend.set_value(model.optimizer.lr,new_lr) # set the learning rate in the optimizer
        # write whatever code you need
,

我建议您创建自己的回调。 在下面的内容中,我添加了一个同时监视准确性和损失的解决方案。您可以使用自己的指标替换acc:

class CustomCallback(keras.callbacks.Callback):
    acc = {}
    loss = {}
    best_weights = None
    
    def __init__(self,patience=None):
        super(CustomCallback,self).__init__()
        self.patience = patience
    
    def on_epoch_end(self,logs=None):
        epoch += 1
        self.loss[epoch] = logs['loss']
        self.acc[epoch] = logs['accuracy']
    
        if self.patience and epoch > self.patience:
            # best weight if the current loss is less than epoch-patience loss. Simiarly for acc but when larger
            if self.loss[epoch] < self.loss[epoch-self.patience] and self.acc[epoch] > self.acc[epoch-self.patience]:
                self.best_weights = self.model.get_weights()
            else:
                # to stop training
                self.model.stop_training = True
                # Load the best weights
                self.model.set_weights(self.best_weights)
        else:
            # best weight are the current weights
            self.best_weights = self.model.get_weights()

请记住,如果要控制监视数量(即min_delta)的最小变化,则必须将其集成到代码中。

以下是有关如何建立自定义回调的文档:custom_callback

,

这时,进行自定义循环并仅使用if语句会更加简单。例如:

<script>

这是使用此方法的简单自定义训练循环:

def main(epochs=50):
    for epoch in range(epochs):
        fit(epoch)

        if test_acc.result() > .8 and topk_acc.result() > .9:
            print(f'\nEarly stopping. Test acc is above 80% and TopK acc is above 90%.')
            break

if __name__ == '__main__':
    main(epochs=100)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?