微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

这是在训练深度学习模型中实现k倍交叉验证的正确方法吗?

如何解决这是在训练深度学习模型中实现k倍交叉验证的正确方法吗?

我了解k折交叉验证的概念,并且正在尝试将其应用于此处。以下伪代码正确吗?

encoder = Encoder.load_model(enc_filepath)
decoder = Decoder.load_model(dec_filepath)
tokenizer = Tokenizer.load_from_file(tok_filepath)
optimizer = Adam()
loss_object = SparseCategoricalCrossentropy()

epoch_enc_weights = encoder.weights
epoch_dec_weights = decoder.weights

stop_condition_met = False
num_epochs = 100
curr_epoch = 0
best_loss = last_best_loss(enc_filepath) // best_loss from prevIoUs run is recorded in filename
k = 5

while not stop_condition_met and curr_epoch < num_epochs:
    fold_val_loss = []
    fold_weights = []

    for i in range(k):
        tr_x,tr_y,val_x,val_y = generate_data(folds,i)
        enc_weights,dec_weights,val_loss = 
                train(encoder,decoder,tokenizer,optimizer,loss_object,tr_x,val_y)
        fold_val_loss.append(val_loss)
        fold_weights.append((enc_weights,dec_weights))
        encoder.load_weights(epoch_enc_weights)
        decoder.load_weights(epoch_dec_weights)

    i,loss = get_best_loss(fold_val_loss)
    epoch_enc_weights = fold_weights[i][0]
    epoch_dec_weights = fold_weights[i][1]
    encoder.load_weights(epoch_enc_weights)
    decoder.load_weights(epoch_dec_weights)

    if loss < best_loss:
        best_loss = loss
        encoder.save_model('./enc_{datetime}_epoch{curr_epoch}_{best_loss:.3f}.tf')
        decoder.save_model('./dec_{datetime}_epoch{curr_epoch}_{best_loss:.3f}.tf')
    curr_epoch += 1
    stop_condition_met = check_stop_condition()

encoder = Encoder.load_model(enc_filepath)
decoder = Decoder.load_model(dec_filepath)
test_loss = eval(encoder,ts_x,ts_y)

这是怎么做的?基本上:

  1. 对于每一折,生成火车验证数据。
  2. 火车模型
  3. 将此折的损失和权重附加到该时期的数组变量中
  4. 重新初始化权重以进行下一折,从步骤1开始重复
  5. 完成所有折叠后,确定验证损失最好的迭代
  6. 为下一个时期加载该迭代的权重
  7. 如果这个时期的最佳损失要比有史以来的最佳损失要好,请将这些模型保存到文件
  8. 增加curr_epoch计数,检查停止条件
  9. 如果不停止训练,请从第1步开始重复
  10. 如果停止训练,请从文件中加载最后的最佳模型并根据测试数据进行评估
  11. 部署此模型进行生产

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。