微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

onehot编码后如何使用预测模型?

如何解决onehot编码后如何使用预测模型?

我已经为这个数据集创建了一个预测模型

>>df.head()

    Service    Tasks Difficulty     Hours
0   ABC         24     1           0.833333
1   CDE         77     1           1.750000
2   SDE         90     3           3.166667
3   QWE         47     1           1.083333
4   ASD         26     3           1.000000

>>df.shape
(998,4)

>>X = df.iloc[:,:-1]
>>y = df.iloc[:,-1].values
>>from sklearn.compose import ColumnTransformer 
>>ct = ColumnTransformer([("cat",OneHotEncoder(),[0])],remainder="passthrough")
>>X = ct.fit_transform(X)  
>>x = X.toarray()
>>x = x[:,1:]

>>x.shape
(998,339)

>>from sklearn.ensemble import RandomForestRegressor
>>rf_model = RandomForestRegressor(random_state = 1)
>>rf_model.fit(x,y)

我如何使用这个模型来预测 Hours 以这种格式的用户输入[["SDE",90,3]]

我试过了

>>test_input = [["SDE",3]]
>>test_input = ct.fit_transform(test_input)  
>>test_input = test_input[[:,1:]

>>test_input[0]
array([24,1],dtype=object)


>>predict_hours = rf_model.predict(test_input)
ValueError

由于我的数据集有很多 categorical 值,因此无法输入 "SDE" 的编码值作为输入,我需要在收到输入后将 "SDE" 转换为 onehot encoded 格式[["SDE",3]]

我不知道该怎么做,有人可以帮忙吗?

解决方法

您可以使用 Pipeline 轻松处理预处理和分类阶段

import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

# I have created a dummy dataset
df = pd.read_csv('test.csv')

X = df.iloc[:,:-1]
y = df.iloc[:,-1].values

# preprocessor
preprocessor = ColumnTransformer([("cat",OneHotEncoder(handle_unknown='ignore'),[0])],remainder="passthrough")

# create a pipeline with preprocessor and classifier
pipeline = Pipeline([('preprocessor',preprocessor),('classifier',RandomForestRegressor(random_state = 1))
                      ])
# split dataset
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.5,random_state=0)

# train the pipelime
pipeline.fit(X_train,y_train)

# predict
print(pipeline.predict(X_test))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。