微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何存储包含 Gridsearchcv 对象的管道?

如何解决如何存储包含 Gridsearchcv 对象的管道?

我构建了以下模型:

def gen_binary_lightgbm(X_train,y_train,round_num,metric):
params = {
'learning_rate' : [0.001,0.01,0.05,0.08,0.1],'reg_lambda' : [0,0.5,1]
}

gkf = KFold(n_splits=5,shuffle=True,random_state=42).split(X_train,y_train)
lgb_estimator = lgb.LGBMClassifier(objective='binary',num_boost_round=round_num)

gsearch = Pipeline([('scaler',StandardScaler()),('model',gridsearchcv(
        estimator=lgb_estimator,param_grid=params,n_jobs=-1,scoring = metric,refit= metric,cv=gkf,verbose=-1,pre_dispatch=8,error_score=-999,return_train_score=True
        ))])

lgb_model = gsearch.fit(X_train,y_train)
return lgb_model

当我尝试腌制模型时:

with open(filename,'wb') as file:
    joblib.dump(lightgbm_model,file)

它出现了:

TypeError: can't pickle generator objects

谁能帮助我存储可以检索和进行预测的最佳模型?

非常感谢。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。