如何解决用于直方图匹配的简单 Pytorch 模型不更新其参数
我目前正在尝试拟合一个极其简单的模型,该模型基本上应该为直方图匹配找到最佳直方图。我写了一个超级简单的模型,只有一个我直接使用的 Parameter 对象:
import torch.nn.functional as F
class AutoHist(pl.LightningModule):
def __init__(self,channel=1,bins=255):
super().__init__()
self.hist = torch.nn.Parameter(torch.rand((1,channel,bins),requires_grad=True))
self.eps = 1e-5
def b_distance(self,h1,h2):
distance = 1
distance -= 1/(torch.sqrt(torch.mean(h1,axis=2)*torch.mean(h2,axis=2)*h1.size(2)**2))
distance *= torch.sum(torch.sqrt(h1*h2 + self.eps),axis=2)
return torch.sqrt(distance + self.eps)
def training_step(self,batch,batch_idx):
# training_step defined the train loop.
# It is independent of forward
x,y = batch
hist = self.hist / self.hist.sum()
distances = self.b_distance(hist,x)
loss = F.binary_cross_entropy(distances[:,0],y)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(),lr=1e-3)
return optimizer
但是由于某种原因,反向传播没有通过,参数也没有更新。有人知道问题可能出在哪里吗?梯度实际上存在并且从批次到批次发生变化。我使用pytorch Lightning删除样板代码,但关键在于我写的代码。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。