如何解决二元分类回归模型的准确度计算
我的准确率方法:
def accuracy(outputs,labels):
_,preds = torch.max(outputs,dim=1)
return torch.tensor(torch.sum(preds == labels).item() / len(preds))
我的类定义模型:
class PulsarLogisticRegression(nn.Module):
def __init__(self):
super().__init__()
self.linear= nn.Linear(input_size,output_size)
def forward(self,xb):
xb = xb.view(xb.size(0),-1)
out= self.linear(xb)
return out
def training_step(self,batch):
inputs,targets = batch
# Generate predictions
out = self(inputs)
# Calcuate loss
loss = F.cross_entropy(out,targets)
return loss
def validation_step(self,targets = batch
# Generate predictions
out = self(inputs)
# Calculate loss
loss = F.cross_entropy(out,targets)
acc = accuracy(out,targets) # Calculate accuracy
return {'val_loss': loss,'val_acc': acc} # fill this
def validation_epoch_end(self,outputs):
batch_losses = [x['val_loss'] for x in outputs]
epoch_loss = torch.stack(batch_losses).mean() # Combine losses
batch_accs = [x['val_acc'] for x in outputs]
epoch_acc = torch.stack(batch_accs).mean() # Combine accuracies
return {'val_loss': epoch_loss.item(),'val_acc': epoch_acc.item()}
def epoch_end(self,epoch,result,num_epochs):
# Print result every 20th epoch
print("Epoch [{}],val_loss: {:.4f},val_acc: {:.4f}".format(epoch,result['val_loss'],result['val_acc']))
RuntimeError Traceback (most recent call last)
<ipython-input-88-cd9b8a9a3b02> in <module>()
----> 1 result = evaluate(model,val_loader) # Use the the evaluate function
2 print(result) 4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input,target,weight,size_average,ignore_index,reduce,reduction)
2262 .format(input.size(0),target.size(0)))
2263 if dim == 2:
-> 2264 ret = torch._C._nn.nll_loss(input,_Reduction.get_enum(reduction),ignore_index)
2265 elif dim == 4:
2266 ret = torch._C._nn.nll_loss2d(input,ignore_index)
RuntimeError: 1D target tensor expected,multi-target not supported
我对机器学习非常陌生,我正在尝试制作一个模型,该模型基于 5 列预测一列数据。列中的值是0和1。所以它基本上是一个二元分类模型。
我尝试过的: 正如我所说,我对这个领域相当陌生,一些解释建议使用挤压函数以某种方式将目标张量的形状减少到一维,但这似乎会在该类的其他方法中引发一些其他错误。
我正在寻找可以帮助我获得正确准确度的误差函数。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。