如何解决使用 skorch 和 sklearn 管道的多输出回归由于 dtype 导致运行时错误
我想用skorch做多输出回归。我创建了一个小玩具示例,如下所示。在这个例子中,NN 应该预测 5 个输出。我还想使用使用 sklearn 管道合并的预处理步骤(在此示例中使用 PCA,但它可以是任何其他预处理器)。执行此示例时,我在 Torch 的 Variable._execution_engine.run_backward 步骤中收到以下错误:
RuntimeError: Found dtype Double but expected Float
我是不是忘记了什么?我怀疑,必须在某个地方投射某些东西,但是由于 skorch 处理了很多 pytorch 的东西,我不知道是什么以及在哪里。
示例:
import torch
import skorch
from sklearn.datasets import make_classification,make_regression
from sklearn.pipeline import Pipeline,make_pipeline
from sklearn.decomposition import PCA
X,y = make_regression(n_samples=1000,n_features=40,n_targets=5)
X = X.astype('float32')
class RegressionModule(torch.nn.Module):
def __init__(self,input_dim=80):
super().__init__()
self.l0 = torch.nn.Linear(input_dim,10)
self.l1 = torch.nn.Linear(10,5)
def forward(self,X):
y = self.l0(X)
y = self.l1(y)
return y
class InputShapeSetter(skorch.callbacks.Callback):
def on_train_begin(self,net,X,y):
net.set_params(module__input_dim=X.shape[-1])
net = skorch.NeuralNetRegressor(
RegressionModule,callbacks=[InputShapeSetter()],)
pipe = make_pipeline(PCA(n_components=10),net)
pipe.fit(X,y)
print(pipe.predict(X))
编辑 1:
从这个例子中可以看出,在开始时将 X 转换为 float32 对每个预处理器都不起作用:
import torch
import skorch
from sklearn.datasets import make_classification,make_regression
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from category_encoders import OneHotEncoder
X,n_targets=5)
X = pd.DataFrame(X,columns=[f'feature_{i}' for i in range(X.shape[1])])
X['feature_1'] = pd.qcut(X['feature_1'],3,labels=["good","medium","bad"])
y = y.astype('float32')
class RegressionModule(torch.nn.Module):
def __init__(self,)
pipe = make_pipeline(OneHotEncoder(cols=['feature_1'],return_df=False),y)
print(pipe.predict(X))
解决方法
默认情况下,OneHotEncoder
返回 dtype=float64
的 numpy 数组。因此,当输入模型的 X
时,可以简单地转换输入数据 forward()
:
class RegressionModule(torch.nn.Module):
def __init__(self,input_dim=80):
super().__init__()
self.l0 = torch.nn.Linear(input_dim,10)
self.l1 = torch.nn.Linear(10,5)
def forward(self,X):
X = X.to(torch.float32)
y = self.l0(X)
y = self.l1(y)
return y
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。