如何解决如何允许复杂的输入和复杂的权重到 Pytorch 模型?
假设即使是最简单的模型(取自 here)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(1,32,3,1)
self.conv2 = nn.Conv2d(32,64,1)
self.fc1 = nn.Linear(9216,128)
self.fc2 = nn.Linear(128,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x,2)
x = torch.flatten(x,1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x,dim=1)
return output
向模型提供复杂数据时,
output = model(data.complex())
它给了
ret = torch.addmm(bias,input,weight.t())
RuntimeError: expected scalar type Float but found ComplexDouble
(为了简单起见,我没有复制整个堆栈跟踪,也没有复制整个训练代码)
在模型的 self.complex()
之后做 __init__
,就像我通常会做的那样 self.double()
,不起作用,
torch.nn.modules.module.ModuleAttributeError: 'Net' object has no attribute 'complex'
编辑:
同时,我发现 this paper。还在读。
解决方法
正如您通常所做的self.double()
,我从https://pytorch.org/docs/stable/generated/torch.nn.Module.html
self.type(dst_type)
就我而言,self.type(torch.complex64)
对我有用。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。