微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在具有自定义损失的Pytorch中训练模型如何设置优化器并进行训练?

如何解决在具有自定义损失的Pytorch中训练模型如何设置优化器并进行训练?

我是pytorch的新手,我正在尝试运行找到的github模型并对其进行测试。因此,作者提供了模型和损失函数

像这样:

#1. Inference the model
model = PhysNet_padding_Encoder_Decoder_MAX(frames=128)
rPPG,x_visual,x_visual3232,x_visual1616 = model(inputs)

#2. normalized the Predicted rPPG signal and GroundTruth BVP signal
rPPG = (rPPG-torch.mean(rPPG)) /torch.std(rPPG)     # normalize
BVP_label = (BVP_label-torch.mean(BVP_label)) /torch.std(BVP_label)     # normalize

#3. Calculate the loss
loss_ecg = Neg_Pearson(rPPG,BVP_label)

数据加载

    train_loader = torch.utils.data.DataLoader(train_set,batch_size = 20,shuffle = True)

    batch = next(iter(train_loader))

    data,label1,label2 = batch

    inputs= data

假设我想训练这个模型15个纪元。 所以这就是我到目前为止: 我正在尝试设置优化程序和训练,但是我不确定如何将自定义损失和数据加载与模型联系起来并正确设置15个时期的训练。

optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

for epoch in range(15):
  ....

有什么建议吗?

解决方法

我假设BVP_label是train_loader的label

train_loader = torch.utils.data.DataLoader(train_set,batch_size = 20,shuffle = True)

# Using GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = PhysNet_padding_Encoder_Decoder_MAX(frames=128)
model.to(device)

optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

for epoch in range(15):
    model.train()
    for inputs,label1,label2 in train_loader:
        rPPG,x_visual,x_visual3232,x_visual1616 = model(inputs)
        BVP_label = label1 # assumed BVP_label is label1

        rPPG = (rPPG-torch.mean(rPPG)) /torch.std(rPPG)
        BVP_label = (BVP_label-torch.mean(BVP_label)) /torch.std(BVP_label)
        
        loss_ecg = Neg_Pearson(rPPG,BVP_label)
        
        optimizer.zero_grad()
        loss_ecg.backward()
        optimizer.step()

PyTorch培训步骤如下。

  • 创建DataLoader
  • 初始化模型和优化器
  • 创建设备对象并将模型移至设备

在火车圈中

  • 选择一个小批量数据
  • 使用模型进行预测
  • 计算损失
  • loss.backward()更新模型的梯度
  • 使用优化器更新参数

如您所知,您还可以查看PyTorch教程。

Learning PyTorch with Examples

What is torch.nn really?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?