How to fix training loss exploding when resuming from a checkpoint
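I am trying to implement a feature in my algorithm that lets me resume training from a checkpoint. The problem is that when I resume training, my loss jumps by many orders of magnitude, from around 0.001 to around 1000. I suspect that the learning rate is not being set correctly when training resumes.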
Here is my training function:
def train_gray(epoch, data_loader, device, model, criterion, optimizer, i, path):
    train_loss = 0.0
    num_samples = 0
    for data in data_loader:
        img, _ = data
        img = img.to(device)
        stand_dev = 0.0392  # noise standard deviation (10/255 is about 0.0392)
        noisy_img = add_noise(img, stand_dev, device)
        output = model(noisy_img, stand_dev)
        output = output[:, 0:1, :, :]  # keep only the first output channel
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its batch size ...
        train_loss += loss.item() * img.size(0)
        num_samples += img.size(0)
    # ... so the average must be taken over samples, not over batches.
    train_loss = train_loss / num_samples
    print('Epoch: {} Complete \tTraining Loss: {:.6f}'.format(
        epoch, train_loss
    ))
    return train_loss
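Here is my main function, which initializes the variables, loads the checkpoint, calls the training function, and saves a checkpoint after each epoch of training: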
import os
from datetime import datetime

import torch
import torch.nn as nn

# Assumption: `device` is defined at module level in the full script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    now = datetime.now()
    current_time = now.strftime("%H_%M_%S")
    path = "/home/bledc/my_remote_folder/denoiser/models/{}_sigma_10_session2".format(current_time)
    os.mkdir(path)
    width = 256
    # height = 256
    num_epochs = 25
    batch_size = 4
    learning_rate = 0.0001
    data_loader = load_dataset(batch_size, width)
    model = UNetWithresnet50Encoder().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-5)
    ############################################################################################
    # UNCOMMENT CODE BELOW TO RESUME TRAINING FROM A MODEL
    model_path = "/home/bledc/my_remote_folder/denoiser/models/resnet_sigma_10/model_epoch_10.pt"
    save_point = torch.load(model_path)
    model.load_state_dict(save_point['model_state_dict'])
    optimizer.load_state_dict(save_point['optimizer_state_dict'])
    epoch = save_point['epoch']
    train_loss = save_point['train_loss']
    model.train()
    ############################################################################################
    for i in range(epoch, num_epochs + 1):
        # Pass every argument that train_gray's signature expects.
        train_loss = train_gray(i, data_loader, device, model, criterion, optimizer, i, path)
        checkpoint(i, train_loss, model, optimizer, path)
    print("end")

Finally, here is my function for saving checkpoints:

def checkpoint(epoch, train_loss, model, optimizer, path):
    # model, optimizer and train_loss are passed in explicitly because
    # they are local to main().
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
    }, path + "/model_epoch_{}.pt".format(epoch))
    print("Epoch saved")
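For comparison, here is a minimal sketch of a save/resume pair that round-trips the model, the optimizer, and the epoch counter. The helper names `save_checkpoint` and `load_checkpoint` and the tiny `nn.Linear` model used to exercise them are hypothetical stand-ins, not part of the code above. Passing `map_location` and resuming at `epoch + 1` are assumptions about what is wanted, not confirmed fixes for the loss spike.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(epoch, train_loss, model, optimizer, path):
    """Write everything needed to resume training to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
    }, path)


def load_checkpoint(path, model, optimizer, device):
    """Restore model and optimizer in place; return (next_epoch, last_loss)."""
    save_point = torch.load(path, map_location=device)
    model.load_state_dict(save_point['model_state_dict'])
    optimizer.load_state_dict(save_point['optimizer_state_dict'])
    # Assumption: the saved epoch was completed, so resume at the next one.
    return save_point['epoch'] + 1, save_point['train_loss']


device = torch.device("cpu")
model = nn.Linear(4, 1).to(device)  # hypothetical stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# One dummy step so the optimizer has state worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "model_epoch_10.pt")
    save_checkpoint(10, loss.item(), model, optimizer, ckpt_path)

    # Fresh objects, as they would be after restarting the script.
    resumed_model = nn.Linear(4, 1).to(device)
    resumed_optimizer = torch.optim.Adam(
        resumed_model.parameters(), lr=1e-4, weight_decay=1e-5)
    start_epoch, last_loss = load_checkpoint(
        ckpt_path, resumed_model, resumed_optimizer, device)

# Every tensor in the model state dict should survive the round trip.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(),
                    resumed_model.state_dict().values())
)
print(start_epoch, weights_match)
```

If the weights match after a round trip like this but the loss still jumps on the real model, the checkpoint file itself is probably not the culprit, and the difference is more likely in how the model is constructed or in which parameters are handed to the optimizer.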
If the problem is that I am not saving the learning rate, how would I go about saving it?
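On the learning-rate question specifically: in PyTorch the learning rate is stored in the optimizer's `param_groups`, which are included in `optimizer.state_dict()`, so a checkpoint that saves the optimizer state already carries it. The sketch below checks this on a hypothetical `nn.Linear` model, using an in-memory buffer in place of a checkpoint file; the deliberately wrong `lr=1e-2` on the second optimizer is only there to show that loading overwrites it.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # hypothetical stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# The hyperparameters of each parameter group are part of the state dict.
saved_lr = optimizer.state_dict()['param_groups'][0]['lr']

# Serialize to an in-memory buffer in place of a checkpoint file.
buffer = io.BytesIO()
torch.save(optimizer.state_dict(), buffer)
buffer.seek(0)

# Build a new optimizer with a deliberately different learning rate.
new_optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
new_optimizer.load_state_dict(torch.load(buffer))

# After loading, the hyperparameters come from the saved state.
restored_lr = new_optimizer.param_groups[0]['lr']
restored_wd = new_optimizer.param_groups[0]['weight_decay']
print(saved_lr, restored_lr, restored_wd)
```

Printing `optimizer.param_groups[0]['lr']` right after `load_state_dict` in the real script is a quick way to rule the learning rate in or out. The code above uses no learning-rate scheduler; if one were added, it would have its own `state_dict()` that needs saving as well.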
Any help would be greatly appreciated, Clement
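Update: I am fairly sure the problem lies in my pretrained model. I save the optimizer at every epoch, but the optimizer only stores information about the trainable layers. I hope to solve this soon, and I will post a more thorough answer once I figure out how to save and load the entire model.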
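To illustrate the distinction drawn in the update, the sketch below compares what the two state dicts cover when part of a network is frozen. The small `nn.Sequential` model is a hypothetical stand-in for a pretrained encoder plus a trainable head, not the actual `UNetWithresnet50Encoder`, and whether this is the real cause of the loss spike is unconfirmed. It shows that `model.state_dict()` contains every parameter and buffer (including BatchNorm running statistics), while the optimizer state only refers to the parameters it was given.

```python
import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in: a frozen "encoder" followed by a trainable head.
    return nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 1, kernel_size=3, padding=1),
    )


model = build_model()
for layer in (model[0], model[1]):
    for p in layer.parameters():
        p.requires_grad = False  # freeze the "pretrained" part

# Only trainable parameters are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)

# One training step in train mode.
model.train()
loss = model(torch.randn(4, 1, 16, 16)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

num_params = len(list(model.parameters()))
num_in_optimizer = len(optimizer.state_dict()['param_groups'][0]['params'])
state_keys = list(model.state_dict().keys())

# Reload into a fresh model and let PyTorch report any mismatched keys.
fresh = build_model()
result = fresh.load_state_dict(model.state_dict(), strict=False)
print(num_params, num_in_optimizer, len(state_keys))
print(result.missing_keys, result.unexpected_keys)
```

Calling `load_state_dict(..., strict=False)` on the real model and inspecting `missing_keys` and `unexpected_keys` is one way to check whether any pretrained layers are being silently skipped when the checkpoint is loaded.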