How to handle the backward pass of the gradient-penalty loss in WGAN-GP
The critic's loss term in WGAN-GP is:
L = D(G(z)) - D(x) + λ * (norm(gradient(x')) - 1)^2
where
x = real image,
G(z) = generated image, and
x' = ε * real_image + (1 - ε) * generated_image
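As a quick sanity check of the interpolation x' above (the shapes here are illustrative, not from the question), note that for 4D image batches epsilon needs trailing singleton dimensions so it broadcasts per-image:

```python
import torch

# Hypothetical batch: 4 RGB "images" of size 8x8
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)
# One epsilon per image; the trailing singleton dims broadcast over C, H, W
epsilon = torch.rand(4, 1, 1, 1)
mixed = epsilon * real + (1 - epsilon) * fake
assert mixed.shape == real.shape  # each mixed image lies between its real/fake pair
```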
My code in PyTorch:
def get_gradient(crit, real, fake, epsilon):
    '''
    Return the gradient of the critic's scores with respect to mixes of real and fake images.
    Parameters:
        crit: the critic model
        real: a batch of real images
        fake: a batch of fake images
        epsilon: a vector of the uniformly random proportions of real/fake per mixed image
    Returns:
        gradient: the gradient of the critic's scores with respect to the mixed images
    '''
    # Mix the images together
    mixed_images = real * epsilon + fake * (1 - epsilon)
    # Calculate the critic's scores on the mixed images
    mixed_scores = crit(mixed_images)
    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient
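get_gradient can be sanity-checked with a toy linear critic (the critic and shapes below are illustrative, not the asker's model): for a linear D the gradient of the score with respect to every pixel is just the weight tensor, so every sample's gradient is identical.

```python
import torch
from torch import nn

# Toy linear critic: flattens the image and outputs a single score
crit = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real = torch.randn(5, 3, 8, 8)
fake = torch.randn(5, 3, 8, 8)
epsilon = torch.rand(5, 1, 1, 1)

# Same computation as get_gradient above
mixed = epsilon * real + (1 - epsilon) * fake
mixed.requires_grad_(True)
scores = crit(mixed)
grad = torch.autograd.grad(
    outputs=scores,
    inputs=mixed,
    grad_outputs=torch.ones_like(scores),
    create_graph=True,
)[0]

# For a linear critic, d(score)/d(pixel) equals the weight tensor for every sample
w = crit[1].weight.view(3, 8, 8)
assert grad.shape == (5, 3, 8, 8)
assert torch.allclose(grad[0], w, atol=1e-6)
```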
def gradient_penalty(gradient):
    '''
    Return the gradient penalty, given a gradient.
    Given a batch of image gradients, calculate the magnitude of each image's gradient
    and penalize the mean squared distance of each magnitude from 1.
    Parameters:
        gradient: the gradient of the critic's scores with respect to the mixed images
    Returns:
        penalty: the gradient penalty
    '''
    # Flatten the gradients so that each row captures one image
    gradient = gradient.view(len(gradient), -1)
    # Calculate the L2 norm of every row
    gradient_norm = gradient.norm(2, dim=1)
    # Penalize the mean squared distance of the gradient norms from 1
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty
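gradient_penalty itself can be checked on hand-crafted gradients (shapes illustrative): a batch of unit-norm gradients should incur zero penalty, and doubling them should give (2 - 1)^2 = 1.

```python
import torch

# 5 fake "gradients": each image's gradient has a single 1.0, so its L2 norm is 1
g = torch.zeros(5, 3, 8, 8)
g[:, 0, 0, 0] = 1.0

# Same computation as gradient_penalty above
flat = g.view(len(g), -1)
penalty = torch.mean((flat.norm(2, dim=1) - 1) ** 2)
assert penalty.item() == 0.0  # unit-norm gradients are not penalized

flat2 = (2 * g).view(len(g), -1)
penalty2 = torch.mean((flat2.norm(2, dim=1) - 1) ** 2)
assert abs(penalty2.item() - 1.0) < 1e-6  # norm 2 -> penalty (2 - 1)^2 = 1
```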
# This is how the critic is updated on each iteration
crit_opt.zero_grad()
fake_noise = get_noise(cur_batch_size, z_dim, device=device)
fake = gen(fake_noise)
# Detach the fake images so the generator is not updated by the critic's step
crit_fake_pred = crit(fake.detach())
crit_real_pred = crit(real)
# One epsilon per image; the trailing singleton dims broadcast over C, H, W
epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
gradient = get_gradient(crit, real, fake.detach(), epsilon)
gp = gradient_penalty(gradient)
crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)
# Keep track of the average critic loss in this batch
mean_iteration_critic_loss += crit_loss.item() / crit_repeats
# Update gradients
crit_loss.backward(retain_graph=True)
# Update optimizer
crit_opt.step()
critic_losses += [mean_iteration_critic_loss]
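get_crit_loss is not shown in the question; under the WGAN-GP objective above it would presumably look like this (a sketch, not the asker's actual implementation):

```python
import torch

def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    # The critic minimizes E[D(fake)] - E[D(real)] plus the weighted gradient penalty
    return torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
```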
I am trying to understand what PyTorch computes during the backward pass of the gradient-penalty term. Since the loss contains norm(gradient(x')), how is the double gradient computed in the backward pass, and how is it used to obtain the gradients with respect to the critic's network weights and epsilon?
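The mechanics can be seen in a minimal scalar sketch (my own example, not from the question): because get_gradient passes create_graph=True, the operations that produced the gradient are themselves recorded in the autograd graph, so calling backward() on the penalty differentiates through the gradient computation, i.e. a double backward involving second derivatives.

```python
import torch

# f(x) = x^3, so df/dx = 3x^2, and the "penalty" (df/dx - 1)^2 has derivative
# 2 * (3x^2 - 1) * 6x
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
# create_graph=True records the gradient computation in the autograd graph,
# making g a differentiable function of x
g = torch.autograd.grad(y, x, create_graph=True)[0]  # g = 3 * 2^2 = 12
penalty = (g - 1) ** 2                               # (12 - 1)^2 = 121
penalty.backward()                                   # double backward through g
# Analytically: 2 * (12 - 1) * (6 * 2) = 264
assert abs(x.grad.item() - 264.0) < 1e-4
```

In the WGAN-GP case the same thing happens tensor-wise: backward() through norm(gradient(x')) triggers a second differentiation of the critic, producing gradients for the critic's weights and, through mixed_images, for epsilon as well.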