如何解决KL 散度在贝叶斯卷积神经网络上变为 NaN
我正在尝试在 Python 3.7 上使用 PyTorch 实现贝叶斯卷积神经网络,实现主要参考了 Shridhar 的开源实现(Shridhar's implementation)。当使用归一化后的 MNIST 数据运行我的 CNN 时,经过几次迭代后,KL 散度变为 NaN。我已经以相同的方式实现了线性层,它们工作得非常好。
我将数据标准化如下:
# MNIST data loaders.
# Fixes vs. the original snippet:
#  - transforms.normalize -> transforms.Normalize (the lowercase name does not
#    exist; the transform class is capitalized)
#  - repaired the unbalanced parentheses around the eval Compose call
#  - the eval loader now applies the same Normalize as the training loader;
#    otherwise the model is evaluated on a different input distribution than
#    it was trained on
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)),
                   ])),
    batch_size=BATCH_SIZE, shuffle=True, **LOADER_kwargs)
eval_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)),
                   ])),
    batch_size=EVAL_BATCH_SIZE, shuffle=False, **LOADER_kwargs)
我对 Conv-Layer 的实现如下所示:
class BayesianConv2d(nn.Module):
    """2D convolution with a fully-factorized Gaussian posterior over weights
    and biases, sampled via the local reparameterization trick in forward().

    Fixes vs. the original snippet:
      - ``torch.distributions.normal`` is a module, not the distribution
        class; use ``torch.distributions.Normal``
      - the weight parameters now have the 4-D shape ``F.conv2d`` requires,
        ``(out_channels, in_channels // groups, *kernel_size)``; the original
        omitted the input-channel dimension
      - the variance convolution forwards stride/padding/dilation/groups in
        the correct positions; the original passed ``groups`` where ``stride``
        goes and dropped the rest
      - sigma is clamped away from 0 before the KL term, so
        ``log(prior_sigma / sigma)`` cannot overflow to inf/NaN when rho
        drifts very negative — the NaN the question reports
      - ``Variable`` (deprecated) replaced by ``torch.randn_like``
      - ``kernel_size`` now also accepts a plain int (backward compatible)
    """

    def __init__(self, in_channels, out_channels, prior_sigma, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Standard normal, kept for parity with the linear layers.
        self.normal = torch.distributions.Normal(0, 1)
        # Conv hyperparameters; normalize an int kernel size to a pair.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # Variational posterior parameters for the weights (mu, rho).
        weight_shape = (out_channels, in_channels // groups, *self.kernel_size)
        self.weight_mu = nn.Parameter(torch.Tensor(*weight_shape).uniform_(0, 0.1))
        self.weight_rho = nn.Parameter(torch.Tensor(*weight_shape).uniform_(-3, 0.1))
        self.weight_sigma = 0
        self.weight = 0
        # Variational posterior parameters for the bias (mu, rho).
        self.bias_mu = nn.Parameter(torch.Tensor(out_channels).uniform_(0, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_channels).uniform_(-3, 0.1))
        self.bias_sigma = 0
        self.bias = 0
        # Std-dev of the zero-mean Gaussian prior over the weights.
        self.prior_sigma = prior_sigma

    def forward(self, input, sample=False, calculate_log_probs=False):
        # sigma = softplus(rho) = log(1 + e^rho) keeps sigma positive.
        self.weight_sigma = torch.log1p(torch.exp(self.weight_rho))
        self.bias_sigma = torch.log1p(torch.exp(self.bias_rho))
        # Local reparameterization trick (Kingma et al., 2015): sample the
        # pre-activations instead of the weights.  Mean and variance of the
        # activations come from two convolutions.
        # Move the input to wherever the parameters live (replaces the
        # module-level DEVICE global; parameters are assumed already placed).
        input = input.to(self.weight_mu.device)
        activations_mu = F.conv2d(input, self.weight_mu, self.bias_mu,
                                  self.stride, self.padding, self.dilation, self.groups)
        activations_var = F.conv2d(input ** 2, self.weight_sigma ** 2, self.bias_sigma ** 2,
                                   self.stride, self.padding, self.dilation, self.groups)
        # 1e-16 floor guards sqrt's gradient at 0.
        activations_sigma = torch.sqrt(1e-16 + activations_var)
        activation_epsilon = torch.randn_like(activations_sigma)
        outputs = activations_mu + activations_sigma * activation_epsilon
        if self.training or calculate_log_probs:
            # Closed-form KL(q || p) between the factorized Gaussian posterior
            # and the zero-mean Gaussian prior, summed over all parameters.
            # Clamping sigma prevents log(prior/sigma) -> inf when rho is
            # driven strongly negative during training.
            # NOTE(review): the original hard-codes 0.1 as the bias prior std;
            # kept for identical behaviour — consider using self.prior_sigma.
            weight_sigma = self.weight_sigma.clamp(min=1e-8)
            bias_sigma = self.bias_sigma.clamp(min=1e-8)
            self.kl_div = 0.5 * (
                (2 * torch.log(self.prior_sigma / weight_sigma) - 1
                 + (weight_sigma / self.prior_sigma).pow(2)
                 + (self.weight_mu / self.prior_sigma).pow(2)).sum()
                + (2 * torch.log(0.1 / bias_sigma) - 1
                   + (bias_sigma / 0.1).pow(2)
                   + (self.bias_mu / 0.1).pow(2)).sum())
        return outputs
相应的 Conv-Net 的实现如下所示:
class BayesianConvNetwork(nn.Module):
    """Three Bayesian conv blocks followed by two Bayesian linear layers.

    Fixes vs. the original snippet:
      - conv2/conv3 now pass the required ``prior_sigma`` and ``kernel_size``
        arguments (BayesianConv2d declares them without defaults, so the
        original raised a TypeError at construction)
      - pool2/pool3 get ``stride=2`` so the spatial size shrinks
        28 -> 14 -> 7 -> 4, matching the 4*4*64 flatten below (MaxPool2d's
        default stride equals kernel_size=3 and yields a different size)
      - ``sample_elbo`` now receives the input batch via a parameter; the
        original referenced the *builtin* ``input`` function and never saw
        the data
      - ``F.nll_loss`` is now given its target and uses ``reduction='sum'``
        instead of the removed ``size_average`` flag
    """

    def __init__(self):
        super().__init__()
        self.conv1 = layers.BayesianConv2d(1, 24, prior_sigma=0.1, kernel_size=(5, 5), padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = layers.BayesianConv2d(24, 48, prior_sigma=0.1, kernel_size=(5, 5), padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv3 = layers.BayesianConv2d(48, 64, prior_sigma=0.1, kernel_size=(5, 5), padding=2)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.fcl1 = layers.BayesianLinearWithLocalReparamTrick(4 * 4 * 64, 256, prior_sigma=0.1)
        self.fcl2 = layers.BayesianLinearWithLocalReparamTrick(256, 10, prior_sigma=0.1)

    def forward(self, x, sample=False):
        """Forward pass; returns per-class log-probabilities (log_softmax)."""
        x = self.pool1(F.relu(self.conv1(x, sample)))
        x = self.pool2(F.relu(self.conv2(x, sample)))
        x = self.pool3(F.relu(self.conv3(x, sample)))
        x = x.view(-1, 4 * 4 * 64)
        x = F.relu(self.fcl1(x, sample))
        return F.log_softmax(self.fcl2(x, sample), dim=1)

    def total_kl_div(self):
        """Sum the KL-divergence contributions of all Bayesian layers."""
        return (self.conv1.kl_div + self.conv2.kl_div + self.conv3.kl_div
                + self.fcl1.kl_div + self.fcl2.kl_div)

    def sample_elbo(self, target, batch_idx, nmbr_batches, samples=SAMPLES, data=None):
        """Monte-Carlo estimate of the (weighted) negative ELBO.

        Runs ``samples`` stochastic forward passes over ``data``, averages the
        per-sample KL divergences and the predictions, and combines the KL
        term with the negative log-likelihood.

        data: the input batch (added parameter — the original read the
        builtin ``input`` by mistake and never received the batch).
        """
        outputs = torch.zeros(samples, target.shape[0], CLASSES).to(DEVICE)
        kl_divs = torch.zeros(samples).to(DEVICE)
        for i in range(samples):
            # One prediction per sampled network realisation.
            outputs[i] = self(data, sample=True)
            kl_divs[i] = self.total_kl_div()
        kl_div = kl_divs.mean()
        negative_log_likelihood = F.nll_loss(outputs.mean(0), target, reduction='sum')
        return kl_weighting * kl_div + negative_log_likelihood
有没有人遇到过同样的问题或知道如何解决?
非常感谢!
解决方法
我发现这似乎是 SGD 优化器的问题。改用 Adam 作为优化器后问题消失了。一个可能的解释(有待验证)是:SGD 使用全局固定的学习率,某一次较大的梯度更新可能把 rho 推向极端负值,使 sigma = log(1+e^rho) 趋近于 0,从而让 KL 散度中的 log(prior_sigma/sigma) 项溢出为 NaN;而 Adam 的逐参数自适应步长限制了单步更新幅度,避免了这种情况。如果有人知道确切原因,欢迎评论补充。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。