如何解决PyTorch Lightning 是否是整个 epoch 的平均指标?
我正在查看 PyTorch-Lightning
官方文档 https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html 中提供的示例。
这里的损失和度量是根据混凝土批次计算的。但是当记录一个人对特定批次的准确性不感兴趣时,它可能很小而且不具有代表性,而是所有时期的平均值。我是否理解正确,有一些代码对所有批次执行平均,通过 epoch?
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
class ClassificationTask(pl.LightningModule):
def __init__(self,model):
super().__init__()
self.model = model
def training_step(self,batch,batch_idx):
x,y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat,y)
return pl.TrainResult(loss)
def validation_step(self,batch_idx):
x,y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat,y)
acc = FM.accuracy(y_hat,y)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict({'val_acc': acc,'val_loss': loss})
return result
def test_step(self,batch_idx):
result = self.validation_step(batch,batch_idx)
result.rename_keys({'val_acc': 'test_acc','val_loss': 'test_loss'})
return result
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(),lr=0.02)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。