微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

将 Pytorch Float 模型转换为 Double

如何解决将 Pytorch Float 模型转换为 Double

我正在尝试解决 Gym 中的 Cartpole。事实证明,状态是双浮点精度,而 pytorch 认以单浮点精度创建模型。

class QNetworkMLP(Module):
    def __init__(self,state_dim,num_actions):
        super(QNetworkMLP,self).__init__()
        self.l1 = Linear(state_dim,64)
        self.l2 = Linear(64,64)
        self.l3 = Linear(64,128)
        self.l4 = Linear(128,num_actions)
        self.relu = ReLU()
        self.lrelu = LeakyReLU()
    
    def forward(self,x) :
        x = self.lrelu(self.l1(x))
        x = self.lrelu(self.l2(x))
        x = self.lrelu(self.l3(x))
        x = self.l4(x)
        return x

我试图通过

转换它
model = QNetworkMLP(4,2).double()

但它仍然不起作用我得到同样的错误

File ".\agent.py",line 117,in update_online_network
    predicted_Qval = self.online_network(states_batch).gather(1,actions_batch)
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py",line 722,in _call_impl
    result = self.forward(*input,**kwargs)
  File "C:\Users\27abh\Desktop\OpenAI Gym\Cartpole\agent_model.py",line 16,in forward
    x = self.lrelu(self.l1(x))
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py",**kwargs)
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\linear.py",line 91,in forward
    return F.linear(input,self.weight,self.bias)
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\functional.py",line 1674,in linear
    ret = torch.addmm(bias,input,weight.t())
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat1' in call to _th_addmm

解决方法

你能在初始化模型后试试这个吗:

 model.to(torch.double)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。