如何解决Pytorch 3D卷积网络导致Google Colab中的RAM崩溃
我对DL(以及一般而言的编码)非常陌生。我正在pytorch中松散地基于V-Net实现一个非常简单的3D CNN(下面的代码)。目前,我可以使代码在小型测试图像(形状:(1,1,128,32))上运行,但是例如,如果我输入更大的值(1,256,128),该模型将使Colab崩溃。可以预料,或者如果我在这里犯了一些明显的错误,请先谢谢。
class Test3D(nn.Module):
def __init__(self):
super(Test3D,self).__init__()
self.input_layer = self._conv_input()
self.conv_layer1 = self.ResidualBlock(32,40)
self.conv_layer2 = self.ResidualBlock(40,48)
self.conv_layer3 = self.ResidualBlock(48,56)
self.conv_layer4 = self.ResidualBlock(56,48)
self.conv_layer5 = self.ResidualBlock(48,40)
self.conv_layer6 = self.ResidualBlock(40,32)
self.conv_layer7 = self._conv_output()
def _conv_input(self):
conv_layer= nn.Sequential(
nn.Conv3d(1,32,kernel_size=(3,3,3),stride=2,padding=1)
)
return conv_layer
def _conv_output(self):
conv_layer= nn.Sequential(
nn.ConvTranspose3d(32,6,kernel_size=2,padding=0)
)
return conv_layer
def ResidualBlock(self,in_c,out_c):
conv_layer=nn.Sequential(
nn.Conv3d(in_c,out_c,padding=1),nn.Conv3d(out_c,ContBatchnorm3d(out_c),nn.ReLU()
)
return conv_layer
def forward(self,x):
out = self.input_layer(x)
print(out.shape)
out = self.conv_layer1(out)
print(out.shape)
out = self.conv_layer2(out)
print(out.shape)
out = self.conv_layer3(out)
print(out.shape)
out = self.conv_layer4(out)
print(out.shape)
out = self.conv_layer5(out)
print(out.shape)
out = self.conv_layer6(out)
print(out.shape)
out = self.conv_layer7(out)
print(out.shape)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。