如何解决类型错误:在对 pytorch 模型使用 summary() 时,不能将序列乘以“元组”类型的非整数
因为我正在设计一个多模式架构,有一个 3dcnn 和 2dcnn,所以我想像在 Keras 中一样查看模型架构。输入大小如图所示,当我尝试打印结构时,我不断收到此错误,但在运行它以进行训练和模型评估时没有问题。
class CombineModel(nn.Module): #A multimodal network(3D and 2D CNN),with multIoUtputs (2 outputs)
def __init__(self,num_classes):
super(CombineModel,self).__init__()
self.CNN3D = CNN3DModel() #output size is (512,1)
self.Unet = UNet11() #output size is (512,1)
self.fc1 = nn.Linear(1024,512)
self.fc2 = nn.Linear(512,num_classes)
self.fc3 = nn.Linear(512,1)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()
self.softmax = nn.softmax()
def forward(self,x1,x2):
out1 = self.CNN3D(x1)
out2 = self.Unet(x2)
outa = torch.cat((out1,out2),dim=1)
outa = self.fc1(outa)
outa = self.relu(outa)
outa = self.softmax(self.fc2(outa))
outb = self.fc3(self.relu(out2))
return outa,outb
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。