微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

类型错误:在对 pytorch 模型使用 summary() 时,不能将序列乘以“元组”类型的非整数

如何解决类型错误:在对 pytorch 模型使用 summary() 时,不能将序列乘以“元组”类型的非整数

因为我正在设计一个多模式架构,有一个 3dcnn 和 2dcnn,所以我想像在 Keras 中一样查看模型架构。输入大小如图所示,当我尝试打印结构时,我不断收到此错误,但在运行它以进行训练和模型评估时没有问题。

enter image description here

class CombineModel(nn.Module): #A multimodal network(3D and 2D CNN),with multIoUtputs (2 outputs)
  def __init__(self,num_classes):
    super(CombineModel,self).__init__()
    self.CNN3D = CNN3DModel() #output size is (512,1)
    self.Unet = UNet11() #output size is (512,1)
    self.fc1 = nn.Linear(1024,512)
    self.fc2 = nn.Linear(512,num_classes)
    self.fc3 = nn.Linear(512,1)
    self.relu = nn.ReLU()
    self.flatten = nn.Flatten() 
    self.softmax = nn.softmax()

  def forward(self,x1,x2):
    out1 = self.CNN3D(x1)
    out2 = self.Unet(x2)
    outa = torch.cat((out1,out2),dim=1)
    outa = self.fc1(outa)
    outa = self.relu(outa)
    outa = self.softmax(self.fc2(outa))
    outb = self.fc3(self.relu(out2))

    return outa,outb

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。