Sklearn: Transformer parameters not received when trying to cross-validate a pipeline


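I am trying to use the cross_validate function on a pipeline. The pipeline works fine when I train it normally, but I get an error when I use cross_validate. Basically, when I use cross_validate, the parameters I pass to the transformer in the pipeline arrive as NoneType. Why does this happen, and how can I fix it? I tried to put together a minimal example here: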

from sklearn.model_selection import cross_validate
from simpletransformers.classification import ClassificationModel
from sklearn.ensemble import RandomForestClassifier
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_20newsgroups

import pandas as pd

class Transformer(BaseEstimator, TransformerMixin):

    def __init__(self, model, mtype, mname, num_labels: int):
        self._model = model
        self._inst_model = None
        self._num_labels = num_labels
        self._type = mtype
        self._name = mname

    def fit(self, train_input, y=None):
        self._create_model()
        self._train_model(train_input)
        return self

    def transform(self, eval_df, y=None):
        result, model_outputs, wrong_predictions = self._inst_model.eval_model(eval_df=eval_df)
        return model_outputs

    def _create_model(self):
        self._inst_model = self._model(self._type, self._name, args={"output_dir": 'min_ex'}, num_labels=self._num_labels)

    def _train_model(self, train_input):
        train_df, eval_df = train_test_split(train_input, test_size=0.20)
        return self._inst_model.train_model(train_df, eval_df=eval_df)

if __name__ == '__main__':

    categories = ['sci.med', 'sci.space']
    X_t, y_t = fetch_20newsgroups(random_state=1, subset='train', categories=categories, remove=('footers', 'quotes'), return_X_y=True)
    X = pd.DataFrame({'text': X_t, 'labels': y_t})
    y = y_t
    transformer_grid = {
        "model": ClassificationModel, "num_labels": 14, "mtype": "electra", "mname": "german-nlp-group/electra-base-german-uncased"
    }

    classifier_grid = {
        'n_estimators': 100, 'random_state': 42
    }

    pipe = Pipeline([
        ('feats', FeatureUnion([
            ('transformer', Pipeline([
                ('transformer', Transformer(**transformer_grid)),
            ])),
        ])),
        ('classifier', RandomForestClassifier(**classifier_grid))
    ])
    # pipe.fit(X, y)  # fitting the pipeline directly works
    # With cross_validate, the Transformer receives None for its parameters
    cv_results = cross_validate(pipe, X, y, cv=5, scoring='accuracy', n_jobs=1)

This is the error I get:

2021-03-04 16:05:54.861544: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
/root/complex_semantics/lib/python3.8/site-packages/sklearn/base.py:209: FutureWarning: From version 0.24, get_params will raise an AttributeError if a parameter cannot be retrieved as an instance attribute. Previously it would return None.
  warnings.warn('From version 0.24, get_params will raise an '
/root/complex_semantics/lib/python3.8/site-packages/sklearn/model_selection/_validation.py:548: FitFailedWarning: Estimator fit failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "/root/complex_semantics/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 531, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/root/complex_semantics/lib/python3.8/site-packages/sklearn/pipeline.py", line 330, in fit
    Xt = self._fit(X, **fit_params_steps)
  File "/root/complex_semantics/lib/python3.8/site-packages/sklearn/pipeline.py", line 292, in _fit
    X, fitted_transformer = fit_transform_one_cached(
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/memory.py", line 352, in __call__
    return self.func(*args, **kwargs)
  File "/root/complex_semantics/lib/python3.8/site-packages/sklearn/pipeline.py", line 740, in _fit_transform_one
    res = transformer.fit_transform(X, line 953, in fit_transform
    results = self._parallel_func(X, fit_params, _fit_transform_one)
  File "/root/complex_semantics/lib/python3.8/site-packages/sklearn/pipeline.py", line 978, in _parallel_func
    return Parallel(n_jobs=self.n_jobs)(delayed(func)(
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/parallel.py", line 1029, in __call__
    if self.dispatch_one_batch(iterator):
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/parallel.py", line 847, in dispatch_one_batch
    self._dispatch(tasks)
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/parallel.py", line 765, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 572, in __init__
    self.results = batch()
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/parallel.py", line 252, in __call__
    return [func(*args, **kwargs)
  File "/root/complex_semantics/lib/python3.8/site-packages/joblib/parallel.py", in <listcomp>
    return [func(*args, line 376, in fit_transform
    return last_step.fit_transform(Xt, **fit_params_last_step)
  File "/root/complex_semantics/lib/python3.8/site-packages/sklearn/base.py", line 693, in fit_transform
    return self.fit(X, **fit_params).transform(X)
  File "min_ex.py", line 23, in fit
    self._create_model()
  File "min_ex.py", line 32, in _create_model
    self._inst_model = self._model(self._type,
TypeError: 'NoneType' object is not callable

  warnings.warn("Estimator fit failed. The score on this train-test"
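The FutureWarning near the top of the log looks like the key clue. cross_validate does not reuse the pipeline object; it calls sklearn.base.clone on it for every split, and clone rebuilds each estimator from get_params(), which looks up an instance attribute with the same name as each __init__ argument. The Transformer above stores its arguments as self._model, self._type and so on, so there is no self.model attribute: per the warning text, versions before 0.24 return None (the clone is then built with model=None, which would explain the 'NoneType' object is not callable error once fit runs), while 0.24 and later raise AttributeError. A plain pipe.fit appears to work because it uses the original object rather than a clone. The sketch below reproduces only that mechanism without simpletransformers; the class names Misnamed and WellNamed and the helper model_after_clone are hypothetical, and dict merely stands in for ClassificationModel.

```python
from sklearn.base import BaseEstimator, clone


class Misnamed(BaseEstimator):
    """Reduced version of the question's pattern: `model` is stored as `self._model`."""

    def __init__(self, model):
        self._model = model


class WellNamed(BaseEstimator):
    """Same class, but the argument is stored under its own name."""

    def __init__(self, model):
        self.model = model


def model_after_clone(est):
    """Return the `model` parameter of clone(est), or the AttributeError raised."""
    try:
        return clone(est).get_params()["model"]
    except AttributeError as exc:
        # Newer scikit-learn raises here instead of silently returning None.
        return exc


# dict is only a placeholder for a callable such as ClassificationModel.
print(model_after_clone(WellNamed(model=dict)))  # <class 'dict'>
print(model_after_clone(Misnamed(model=dict)))   # None on old versions, an AttributeError on newer ones
```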

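Edit:

I changed the code to use a separate variable when instantiating the model, but the problem is the same. I removed the print statements for readability. The only thing missing to run the code is the data loading. It still gives me the same error, and only when cross-validating.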
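Renaming the variable that holds the instantiated model would not be expected to change anything, as far as scikit-learn is concerned: clone() copies only the constructor parameters reported by get_params() and drops everything set during fit(). A small sketch of that behaviour, using a hypothetical Wrapper class with list standing in for the model factory:

```python
from sklearn.base import BaseEstimator, TransformerMixin, clone


class Wrapper(BaseEstimator, TransformerMixin):
    """Hypothetical minimal wrapper; `factory` stands in for ClassificationModel."""

    def __init__(self, factory):
        self.factory = factory

    def fit(self, X, y=None):
        # Whatever attribute name is used here, clone() never copies it.
        self.inst_model_ = self.factory()
        return self

    def transform(self, X):
        return X


fitted = Wrapper(factory=list).fit([[1.0], [2.0]])
fresh = clone(fitted)

print(hasattr(fitted, "inst_model_"))  # True
print(hasattr(fresh, "inst_model_"))   # False: only __init__ parameters are carried over
print(fresh.factory is list)           # True
```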

Edit 2:

Created a minimal reproducible example by adding a sample dataset.
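A possible fix, following scikit-learn's convention for custom estimators: store each __init__ argument unchanged under an attribute of the same name, and create derived state (the instantiated model) inside fit() under a trailing-underscore name. The sketch keeps the question's method bodies but replaces ClassificationModel with a hypothetical StandInModel that only mimics the train_model/eval_model calls used here, so it runs without simpletransformers; the real model's behaviour is assumed, not checked.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.model_selection import train_test_split


class StandInModel:
    """Hypothetical stand-in for ClassificationModel, mimicking only the calls used below."""

    def __init__(self, model_type, model_name, args=None, num_labels=2):
        self.model_type = model_type
        self.model_name = model_name
        self.num_labels = num_labels  # `args` is accepted but ignored here

    def train_model(self, train_df, eval_df=None):
        self.trained_on_ = len(train_df)

    def eval_model(self, eval_df):
        # Same tuple shape the question unpacks: (result, model_outputs, wrong_predictions)
        outputs = [[0.0] * self.num_labels for _ in range(len(eval_df))]
        return {}, outputs, []


class Transformer(BaseEstimator, TransformerMixin):
    def __init__(self, model, mtype, mname, num_labels: int):
        # Each argument is stored as-is under the same name as the parameter,
        # which is what BaseEstimator.get_params() and clone() rely on.
        self.model = model
        self.mtype = mtype
        self.mname = mname
        self.num_labels = num_labels

    def fit(self, train_input, y=None):
        # Fitted state is created here, with a trailing underscore.
        self.inst_model_ = self.model(self.mtype, self.mname,
                                      args={"output_dir": "min_ex"},
                                      num_labels=self.num_labels)
        train_df, eval_df = train_test_split(train_input, test_size=0.20)
        self.inst_model_.train_model(train_df, eval_df=eval_df)
        return self

    def transform(self, eval_df, y=None):
        _, model_outputs, _ = self.inst_model_.eval_model(eval_df=eval_df)
        return model_outputs


df = pd.DataFrame({"text": list("abcdefghij"), "labels": [0, 1] * 5})
cloned = clone(Transformer(StandInModel, "electra", "some-model-name", num_labels=2))
print(cloned.get_params()["model"] is StandInModel)  # True
print(len(cloned.fit(df).transform(df)))             # 10
```

With this naming, clone() (and therefore cross_validate) should rebuild the Transformer with the same model, mtype, mname and num_labels on every split; whether the rest of the question's pipeline then runs as it does with a plain fit is an assumption not verified against simpletransformers here.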
