PyTorch: understanding "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"
Here is the problem setup.
A function that converts RGB to HSV:
import torch
import torch.nn as nn

def rbg2hsv(img_rgb):
    batch_size, channel, height, width = img_rgb.size()
    r, g, b = img_rgb[:, 0, :, :], img_rgb[:, 1, :, :], img_rgb[:, 2, :, :]
    t = torch.min(img_rgb, dim=1, keepdim=False)[0]
    v = torch.max(img_rgb, dim=1, keepdim=False)[0]
    s = (v - t) / (v + 1e-6)
    s[v == 0] = 0
    # v == r
    hr = 60 * (g - b) / (v - t + 1e-6)
    # v == g
    hg = 120 + 60 * (b - r) / (v - t + 1e-6)
    # v == b
    hb = 240 + 60 * (r - g) / (v - t + 1e-6)
    h = torch.zeros(batch_size, height, width, requires_grad=False)
    if torch.cuda.is_available():
        h = h.cuda()
    h = h.flatten()
    hr = hr.flatten()
    hg = hg.flatten()
    hb = hb.flatten()
    h[(v == b).flatten()] = hb[(v == b).flatten()]
    h[(v == g).flatten()] = hg[(v == g).flatten()]
    h[(v == r).flatten()] = hr[(v == r).flatten()]
    h[h < 0] += 360
    h = torch.reshape(h, (batch_size, height, width))
    # stack to (3, B, H, W), then permute to (B, 3, H, W)
    img_hsv = torch.stack([h, s, v])
    img_hsv = img_hsv.permute(1, 0, 2, 3)
    return img_hsv
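Note that this function contains several in-place writes (s[v == 0] = 0, the h[...] = ... masked assignments, h[h < 0] += 360), which I suspect matter for the error below. My current (possibly wrong) understanding is that autograd tracks such writes with a per-tensor version counter; a minimal sketch of that mechanism, using the internal ._version attribute I came across while debugging:

import torch

x = torch.rand(4, requires_grad=True)
v = x * 2           # non-leaf intermediate, starts at version 0
print(v._version)   # 0
v[v < 1] = 0        # masked assignment writes in place, like s[v == 0] = 0 above
print(v._version)   # 1 -- backward errors if a saved tensor's version has changed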
Here is minimal code to reproduce the problem:
img_rgb = torch.rand(2, 3, 32, 32)
img_rgb.requires_grad = True
instance_norm = nn.InstanceNorm2d(1, affine=False)
gamma = 2
beta = 1
img_hsv = rbg2hsv(img_rgb)
# 1: OK
# img_hsv[:, 2:3, :, :] = img_hsv[:, 2:3, :, :] * gamma + beta
# 2: FAIL
img_hsv[:, 2:3, :, :] = instance_norm(img_hsv[:, 2:3, :, :])
# 3: OK
# img_hsv[:, 2:3, :, :] = instance_norm(img_hsv[:, 2:3, :, :].clone())
img_hsv.mean().backward()  # fails on this line
print('img_rgb.grad.size()', img_rgb.grad.size())
Variant 2 fails with the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1,32]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
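For reference, the hint refers to enabling anomaly mode before the forward pass, which makes the backward error include a traceback to the forward operation that recorded the modified tensor. A minimal sketch of how I understand it is meant to be used (not output from an actual run):

import torch

torch.autograd.set_detect_anomaly(True)   # global switch

# or scoped to a single forward/backward pass:
with torch.autograd.detect_anomaly():
    img_hsv = rbg2hsv(img_rgb)
    img_hsv[:, 2:3, :, :] = instance_norm(img_hsv[:, 2:3, :, :])
    img_hsv.mean().backward()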
If I change rbg2hsv to a dummy add_fn, all three variants work:
def add_fn(x):
    x = x + 1
    return x

img_rgb = torch.rand(2, 3, 32, 32)
img_rgb.requires_grad = True
instance_norm = nn.InstanceNorm2d(1, affine=False)
gamma = 2
beta = 1
img_hsv = add_fn(img_rgb)
# 1: OK
# img_hsv[:, 2:3, :, :] = img_hsv[:, 2:3, :, :] * gamma + beta
# 2: OK
img_hsv[:, 2:3, :, :] = instance_norm(img_hsv[:, 2:3, :, :])
# 3: OK
# img_hsv[:, 2:3, :, :] = instance_norm(img_hsv[:, 2:3, :, :].clone())
img_hsv.mean().backward()
print('img_rgb.grad.size()', img_rgb.grad.size())
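Note that x = x + 1 inside add_fn is out-of-place: it rebinds x to a brand-new tensor rather than writing into the existing storage. A quick sketch of the distinction as I understand it (again using the internal ._version attribute):

import torch

x = torch.rand(3, requires_grad=True)
y = x + 1          # out-of-place: allocates a new tensor, version 0
print(y._version)  # 0
y += 1             # in-place: writes into y's storage, bumps the version counter
print(y._version)  # 1
# x += 1           # would raise: "a leaf Variable that requires grad
#                  # is being used in an in-place operation"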
So my questions are:
- Why does variant 1 not fail? It looks like an in-place modification of the tensor as well.
- Why does changing rbg2hsv to add_fn make all three variants work? In other words, is there something wrong inside rbg2hsv?

There is a similar question at https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836, but I still do not understand the rule: when do we need .clone() and when do we not (i.e., which operations count as in-place)?
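To make that last question concrete, here is a minimal pair (my own sketch, not from the linked thread) where the same in-place write is harmless in one case and fatal in the other; my guess is that the difference is whether the gradient formula needs the tensor that was overwritten:

import torch

# Case A: dy/dx = 2 is a constant, so autograd does not save y for backward,
# and overwriting y in place is harmless.
x = torch.rand(3, requires_grad=True)
y = x * 2
y[0] = 0
y.sum().backward()  # works

# Case B: dz/dx = z, so autograd saves z for backward; overwriting z in place
# invalidates that saved tensor and backward raises the RuntimeError above.
x = torch.rand(3, requires_grad=True)
z = x.exp()
z[0] = 0
# z.sum().backward()  # RuntimeError: ... modified by an inplace operation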