如何解决如何在 PyTorch 交叉熵损失函数中获取和打印每个类的权重
我有使用 PyTorch 进行多类分割的代码。输入是图像及其地面实况掩码。这是我的一段代码:
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
net.train()
epoch_loss = 0
for batch in train_loader:
true_masks = batch['mask']
imgs=batch['image']
imgs = imgs.to(device=device,dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device,dtype=mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred,true_masks)
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(net.parameters(),0.1)
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1
现在我想知道每个时代每个班级的权重是多少。我有 6 个班级:[0,1,2,3,4,5]。例如,我想获得如下信息:
1 级重量=...
class 2 weight=...
3 级重量=...
class 4 weight=...
...
如何获取和打印这些重量?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。