PyTorch boolean gates - how to stop backpropagation?
I need to build a neural network in which binary gates zero out certain tensors, namely the outputs of disabled circuits.
To improve runtime, I was hoping to use torch.bool binary gates to stop backpropagation through the disabled circuits in the network. However, I set up a small experiment based on the official PyTorch example for the CIFAR-10 dataset, and the runtime is exactly the same for any value of gate_A and gate_B (which means the idea doesn't work):
from random import randint

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)
        gate_A = torch.tensor(choice, dtype=torch.bool)
        gate_B = torch.tensor(1 - choice, dtype=torch.bool)
        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))
        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))
        a *= gate_A
        b *= gate_B
        x = torch.cat([a, b], dim=1)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
How can gate_A and gate_B be defined so that backpropagation effectively stops when they are zero?
PS. Changing the concatenation dynamically at runtime would also change which weights are assigned to each module (for example, the weights associated with a could end up assigned to b in another round, breaking how the network operates).
Solution
You can use torch.no_grad (the code below could probably be made more concise):
def forward(self, x):
    # Only one gate is supposed to be enabled at random
    # However, for the experiment, I fixed the values to [1,0] and [1,1]
    choice = randint(0, 1)
    gate_A = torch.tensor(choice, dtype=torch.bool)
    gate_B = torch.tensor(1 - choice, dtype=torch.bool)
    if choice:
        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))
        a *= gate_A
        with torch.no_grad():  # disable gradient computation
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            b *= gate_B
    else:
        with torch.no_grad():  # disable gradient computation
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            a *= gate_A
        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))
        b *= gate_B
    x = torch.cat([a, b], dim=1)
    x = torch.flatten(x, 1)  # flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
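As a quick sanity check (a standalone sketch, independent of the network above): tensors produced inside torch.no_grad() are detached from the autograd graph, so no gradient flows back through the disabled branch:

```python
import torch

w = torch.randn(3, requires_grad=True)

# Active branch: tracked by autograd
a = (w * 2).sum()

# Disabled branch: computed without gradient tracking
with torch.no_grad():
    b = (w * 5).sum()

print(a.requires_grad)  # True
print(b.requires_grad)  # False

(a + b).backward()
# Only the active branch contributes to w.grad: [2., 2., 2.]
print(w.grad)
```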
On second look, I think the following is a simpler solution for this specific problem:
def forward(self, x):
    if choice:
        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))
        b = torch.zeros(shape_of_conv_output)  # replace shape of conv output here
    else:
        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))
        a = torch.zeros(shape_of_conv_output)  # replace shape of conv output here
    x = torch.cat([a, b], dim=1)
    x = torch.flatten(x, 1)  # flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
A simple solution: just define a tensor of zeros whenever a or b is disabled :)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        if randint(0, 1):
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            b = torch.zeros_like(a)
        else:
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            a = torch.zeros_like(b)
        x = torch.cat([a, b], dim=1)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
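To confirm the zeros trick actually blocks gradients, here is a minimal check with toy layers (hypothetical shapes, not the exact network above): after backward(), the parameters of the convolution in the inactive branch still have grad=None, because torch.zeros_like produces a tensor outside the graph:

```python
import torch
import torch.nn as nn

conv_a = nn.Conv2d(3, 4, 3)
conv_b = nn.Conv2d(3, 4, 3)

x = torch.randn(1, 3, 8, 8)

# Simulate the gated forward: only branch "a" is active this round
a = conv_a(x)
b = torch.zeros_like(a)  # branch "b" never runs, so conv_b is untouched

out = torch.cat([a, b], dim=1).sum()
out.backward()

print(conv_a.weight.grad is None)  # False: the active branch got gradients
print(conv_b.weight.grad is None)  # True: the inactive branch received none
```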
PS. I came up with this over coffee.