如何在scikit中将f1_score参数传递给make_scorer学习如何与cross_val_score一起使用?

如何解决如何在scikit中将f1_score参数传递给make_scorer学习如何与cross_val_score一起使用?

我有一个分类问题(有很多标签),我想使用F1分数和'average'='weighted'。

虽然我做错了。这是我的代码

from sklearn.metrics import f1_score

from sklearn.metrics import make_scorer

f1 = make_scorer(f1_score,{'average' : 'weighted'})

np.mean(cross_val_score(model,X,y,cv=8,n_jobs=-1,scoring = f1))

---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\joblib\externals\loky\process_executor.py",line 418,in _process_worker
    r = call_item()
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\joblib\externals\loky\process_executor.py",line 272,in __call__
    return self.fn(*self.args,**self.kwargs)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\joblib\_parallel_backends.py",line 608,in __call__
    return self.func(*args,**kwargs)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\joblib\parallel.py",line 256,in __call__
    for func,args,kwargs in self.items]
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\joblib\parallel.py",in <listcomp>
    for func,kwargs in self.items]
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\model_selection\_validation.py",line 560,in _fit_and_score
    test_scores = _score(estimator,X_test,y_test,scorer)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\model_selection\_validation.py",line 607,in _score
    scores = scorer(estimator,y_test)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\metrics\_scorer.py",line 88,in __call__
    *args,**kwargs)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\metrics\_scorer.py",line 213,in _score
    **self._kwargs)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\utils\validation.py",line 73,in inner_f
    return f(**kwargs)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\metrics\_classification.py",line 1047,in f1_score
    zero_division=zero_division)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\utils\validation.py",line 1175,in fbeta_score
    zero_division=zero_division)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\utils\validation.py",line 1434,in precision_recall_fscore_support
    pos_label)
  File "C:\Users\Alienware\Anaconda3\envs\tf2\lib\site-packages\sklearn\metrics\_classification.py",line 1265,in _check_set_wise_labels
    % (y_type,average_options))
ValueError: Target is multiclass but average='binary'. Please choose another average setting,one of [None,'micro','macro','weighted'].
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-48-0323d7b23fbc> in <module>
----> 1 np.mean(cross_val_score(model,scoring = f1))

~\Anaconda3\envs\tf2\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,**kwargs)
     71                           FutureWarning)
     72         kwargs.update({k: arg for k,arg in zip(sig.parameters,args)})
---> 73         return f(**kwargs)
     74     return inner_f
     75 

~\Anaconda3\envs\tf2\lib\site-packages\sklearn\model_selection\_validation.py in cross_val_score(estimator,groups,scoring,cv,n_jobs,verbose,fit_params,pre_dispatch,error_score)
    404                                 fit_params=fit_params,405                                 pre_dispatch=pre_dispatch,--> 406                                 error_score=error_score)
    407     return cv_results['test_score']
    408 

~\Anaconda3\envs\tf2\lib\site-packages\sklearn\utils\validation.py in inner_f(*args,args)})
---> 73         return f(**kwargs)
     74     return inner_f
     75 

~\Anaconda3\envs\tf2\lib\site-packages\sklearn\model_selection\_validation.py in cross_validate(estimator,return_train_score,return_estimator,error_score)
    246             return_times=True,return_estimator=return_estimator,247             error_score=error_score)
--> 248         for train,test in cv.split(X,groups))
    249 
    250     zipped_scores = list(zip(*scores))

~\Anaconda3\envs\tf2\lib\site-packages\joblib\parallel.py in __call__(self,iterable)
   1015 
   1016             with self._backend.retrieval_context():
-> 1017                 self.retrieve()
   1018             # Make sure that we get a last message telling us we are done
   1019             elapsed_time = time.time() - self._start_time

~\Anaconda3\envs\tf2\lib\site-packages\joblib\parallel.py in retrieve(self)
    907             try:
    908                 if getattr(self._backend,'supports_timeout',False):
--> 909                     self._output.extend(job.get(timeout=self.timeout))
    910                 else:
    911                     self._output.extend(job.get())

~\Anaconda3\envs\tf2\lib\site-packages\joblib\_parallel_backends.py in wrap_future_result(future,timeout)
    560         AsyncResults.get from multiprocessing."""
    561         try:
--> 562             return future.result(timeout=timeout)
    563         except LokyTimeoutError:
    564             raise TimeoutError()

~\Anaconda3\envs\tf2\lib\concurrent\futures\_base.py in result(self,timeout)
    433                 raise CancelledError()
    434             elif self._state == FINISHED:
--> 435                 return self.__get_result()
    436             else:
    437                 raise TimeoutError()

~\Anaconda3\envs\tf2\lib\concurrent\futures\_base.py in __get_result(self)
    382     def __get_result(self):
    383         if self._exception:
--> 384             raise self._exception
    385         else:
    386             return self._result

ValueError: Target is multiclass but average='binary'. Please choose another average setting,'weighted']. 

解决方法

在查看documentation中给出的示例时,您会发现应该将score函数的参数(此处为f1_score)传递为dict,而不是传递为关键字参数:

f1 = make_scorer(f1_score,average='weighted')

np.mean(cross_val_score(model,X,y,cv=8,n_jobs=-1,scorin =f1))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?