MNLogit returns NaNs in some folds of cross-validation


I wrote this code, which uses stratified k-fold to split the dataset, fits a multinomial logit, and then computes accuracy. My X is an array with 19 variables (the last one is a clustering variable), and y has 3 classes (0, 1, 2).

import numpy as np
import statsmodels.api as sm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, accuracy_score

X = np.asarray(df[[*all 19 columns here*]], dtype="float64")
y = np.asarray(df["categoric_var"], dtype="int")

acc_test = []
acc_train = []
skf = StratifiedKFold(n_splits=5, shuffle=True)
split_n = 0

for train_ix, test_ix in skf.split(X, y):
    split_n += 1
    X_train, X_valid = X[train_ix], X[test_ix]
    y_train, y_valid = y[train_ix], y[test_ix]
    cluster_groups = X_train[:, -1]
    X_train2 = X_train[:, :-1].astype("float64")  # remove clustering variable
    X_valid2 = X_valid[:, :-1].astype("float64")  # remove clustering variable

    mnl = sm.MNLogit(y_train, X_train2).fit(cov_type="cluster", cov_kwds={"groups": cluster_groups})
    print(mnl.summary())
    train_pred = mnl.predict(X_train2)

    # turn predicted probabilities into the final class labels, as a list
    pred_list_train = []
    for row in train_pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_train.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_train.append(1)
        else:
            pred_list_train.append(2)

    print('MNLogit Regression, training set, fold ', split_n, ': ', classification_report(y_train, pred_list_train))
    
    pred = mnl.predict(X_valid2)

    # turn predicted probabilities into the final class labels, as a list
    pred_list_test = []
    for row in pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_test.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_test.append(1)
        else:
            pred_list_test.append(2)

    # Measure of the fit of the model

    print('MNLogit Regression, validation set, fold ', split_n, ': ', classification_report(y_valid, pred_list_test))

    acc_test.append(accuracy_score(y_valid, pred_list_test))
    acc_train.append(accuracy_score(y_train, pred_list_train))
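As an aside, the two probability-to-class loops above can each be collapsed into a single `np.argmax` call over the rows of the predicted-probability matrix. A minimal sketch with a made-up probability matrix:

```python
import numpy as np

# Each row of MNLogit's predict() output holds the probabilities for
# classes 0, 1 and 2; the predicted class is the column with the
# highest probability, i.e. the argmax along axis 1.
train_pred = np.array([
    [0.7, 0.2, 0.1],   # class 0 most likely
    [0.1, 0.6, 0.3],   # class 1 most likely
    [0.2, 0.3, 0.5],   # class 2 most likely
])

pred_list_train = np.argmax(train_pred, axis=1)
print(pred_list_train)  # [0 1 2]
```

This returns an array rather than a list, which `classification_report` and `accuracy_score` accept directly.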

The problem is that I have two versions of y: one in which the classes are more imbalanced (version 1), and one in which they are more balanced (version 2).

When I run this code on version 1 of y, it works perfectly. But when I run it on version 2, some folds return a regression that is all NaNs... Here is an example (apologies for the length). This is the output of the first two folds:

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2251: RuntimeWarning: divide by zero encountered in log

  logprob = np.log(self.cdf(np.dot(self.exog,params)))

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2252: RuntimeWarning: invalid value encountered in multiply

  return np.sum(d * logprob)

Optimization terminated successfully.

         Current function value: nan

         Iterations 14

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in greater

  return (a < x) & (x < b)

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in less

  return (a < x) & (x < b)

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:1912: RuntimeWarning: invalid value encountered in less_equal

  cond2 = cond0 & (x <= _a)

                          MNLogit Regression Results                          

==============================================================================

Dep. Variable:                      y   No. Observations:                13852

Model:                        MNLogit   Df Residuals:                    13814

Method:                           MLE   Df Model:                           36

Date:                Thu,13 Aug 2020   Pseudo R-squ.:                     nan

Time:                        23:04:09   Log-Likelihood:                    nan

converged:                       True   LL-Null:                       -13943.

Covariance Type:              cluster   LLR p-value:                       nan

==============================================================================

       y=1       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1            -0.0012      0.009     -0.126      0.900      -0.020       0.017

x2             0.0001    1.8e-05      6.207      0.000    7.63e-05       0.000

x3            -0.6074      0.621     -0.978      0.328      -1.825       0.610

x4             8.5373      1.219      7.004      0.000       6.148      10.926

x5             0.0136      0.002      5.906      0.000       0.009       0.018

x6             0.0024      0.066      0.037      0.970      -0.127       0.131

x7            -0.0060      0.003     -1.972      0.049      -0.012   -3.76e-05

x8            -0.0263      0.015     -1.695      0.090      -0.057       0.004

x9            -0.0237      0.026     -0.926      0.355      -0.074       0.026

x10           -0.0008      0.002     -0.404      0.686      -0.005       0.003

x11            0.0713      0.031      2.308      0.021       0.011       0.132

x12        -9.272e-05   1.54e-05     -6.003      0.000      -0.000   -6.24e-05

x13           -0.0012      0.000     -4.696      0.000      -0.002      -0.001

x14          5.53e-05   1.06e-05      5.215      0.000    3.45e-05    7.61e-05

x15           -0.0007      0.000     -3.538      0.000      -0.001      -0.000

x16         7.334e-05   6.94e-05      1.056      0.291   -6.27e-05       0.000

x17           -0.0098      0.001     -9.659      0.000      -0.012      -0.008

x18           -0.0506      0.036     -1.409      0.159      -0.121       0.020

x19            0.0953      0.017      5.682      0.000       0.062       0.128

------------------------------------------------------------------------------

       y=2       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1             0.0354      0.025      1.411      0.158      -0.014       0.084

x2             0.0003      0.000      1.996      0.046    5.62e-06       0.001

x3             3.3663      3.177      1.060      0.289      -2.860       9.593

x4            16.6473      8.483      1.962      0.050       0.021      33.273

x5             0.0507      0.026      1.963      0.050    7.82e-05       0.101

x6             0.3423      0.278      1.232      0.218      -0.202       0.887

x7             0.0274      0.026      1.051      0.293      -0.024       0.079

x8             0.0998      0.071      1.397      0.162      -0.040       0.240

x9            -0.0231      0.049     -0.466      0.641      -0.120       0.074

x10            0.0126      0.006      1.969      0.049    5.65e-05       0.025

x11            0.2219      0.129      1.720      0.085      -0.031       0.475

x12           -0.0002    8.6e-05     -2.286      0.022      -0.000    -2.8e-05

x13           -0.0022      0.001     -2.591      0.010      -0.004      -0.001

x14            0.0001   5.35e-05      2.313      0.021    1.89e-05       0.000

x15           -0.0018      0.001     -2.209      0.027      -0.003      -0.000

x16         6.439e-05      0.000      0.468      0.640      -0.000       0.000

x17           -0.8636      0.047    -18.523      0.000      -0.955      -0.772

x18            1.7166      4.104      0.418      0.676      -6.328       9.761

x19            0.0713      0.052      1.375      0.169      -0.030       0.173

==============================================================================

MNLogit Regression,fold  21 :                precision    recall  f1-score   support

 

           0       0.89      0.78      0.83      3679

           1       0.76      0.83      0.80      2738

           2       0.97      1.00      0.98      7435

 

    accuracy                           0.91     13852

   macro avg       0.87      0.87      0.87     13852

weighted avg       0.91      0.91      0.90     13852

 

MNLogit Regression,fold  21 :                precision    recall  f1-score   support

 

           0       0.88      0.78      0.83       920

           1       0.77      0.82      0.79       685

           2       0.97      1.00      0.98      1859

 

    accuracy                           0.90      3464

   macro avg       0.87      0.86      0.87      3464

weighted avg       0.90      0.90      0.90      3464

 

shape xtrain:  (13853,19)

shape ytrain:  (13853,)

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2219: RuntimeWarning: overflow encountered in exp

  eXB = np.column_stack((np.ones(len(X)),np.exp(X)))

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2220: RuntimeWarning: invalid value encountered in true_divide

  return eXB/eXB.sum(1)[:,None]

C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\optimizer.py:300: RuntimeWarning: invalid value encountered in greater

  oldparams) > tol)):

Optimization terminated successfully.

         Current function value: nan

         Iterations 6

                          MNLogit Regression Results                         

==============================================================================

Dep. Variable:                      y   No. Observations:                13853

Model:                        MNLogit   Df Residuals:                    13815

Method:                           MLE   Df Model:                           36

Date:                Thu,13 Aug 2020   Pseudo R-squ.:                     nan

Time:                        23:04:10   Log-Likelihood:                    nan

converged:                       True   LL-Null:                       -13944.

Covariance Type:              cluster   LLR p-value:                       nan

==============================================================================

       y=1       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1                nan        nan        nan        nan         nan         nan

x2                nan        nan        nan        nan         nan         nan

x3                nan        nan        nan        nan         nan         nan

x4                nan        nan        nan        nan         nan         nan

x5                nan        nan        nan        nan         nan         nan

x6                nan        nan        nan        nan         nan         nan

x7                nan        nan        nan        nan         nan         nan

x8                nan        nan        nan        nan         nan         nan

x9                nan        nan        nan        nan         nan         nan

x10               nan        nan        nan        nan         nan         nan

x11               nan        nan        nan        nan         nan         nan

x12               nan        nan        nan        nan         nan         nan

x13               nan        nan        nan        nan         nan         nan

x14               nan        nan        nan        nan         nan         nan

x15               nan        nan        nan        nan         nan         nan

x16               nan        nan        nan        nan         nan         nan

x17               nan        nan        nan        nan         nan         nan

x18               nan        nan        nan        nan         nan         nan

x19               nan        nan        nan        nan         nan         nan

------------------------------------------------------------------------------

       y=2       coef    std err          z      P>|z|      [0.025      0.975]

------------------------------------------------------------------------------

x1                nan        nan        nan        nan         nan         nan

x2                nan        nan        nan        nan         nan         nan

x3                nan        nan        nan        nan         nan         nan

x4                nan        nan        nan        nan         nan         nan

x5                nan        nan        nan        nan         nan         nan

x6                nan        nan        nan        nan         nan         nan

x7                nan        nan        nan        nan         nan         nan

x8                nan        nan        nan        nan         nan         nan

x9                nan        nan        nan        nan         nan         nan

x10               nan        nan        nan        nan         nan         nan

x11               nan        nan        nan        nan         nan         nan

x12               nan        nan        nan        nan         nan         nan

x13               nan        nan        nan        nan         nan         nan

x14               nan        nan        nan        nan         nan         nan

x15               nan        nan        nan        nan         nan         nan

x16               nan        nan        nan        nan         nan         nan

x17               nan        nan        nan        nan         nan         nan

x18               nan        nan        nan        nan         nan         nan

x19               nan        nan        nan        nan         nan         nan

==============================================================================

__main__:42: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False,but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

__main__:44: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False,but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

  _warn_prf(average,modifier,msg_start,len(result))

__main__:54: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False,but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

__main__:56: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False,but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

  _warn_prf(average,len(result))

MNLogit Regression,fold  21 :                precision    recall  f1-score   support

 

           0       0.00      0.00      0.00      3679

           1       0.00      0.00      0.00      2739

           2       0.54      1.00      0.70      7435

 

    accuracy                           0.54     13853

   macro avg       0.18      0.33      0.23     13853

weighted avg       0.29      0.54      0.37     13853

 

MNLogit Regression,fold  21 :                precision    recall  f1-score   support

 

           0       0.00      0.00      0.00       920

           1       0.00      0.00      0.00       684

           2       0.54      1.00      0.70      1859

 

    accuracy                           0.54      3463

   macro avg       0.18      0.33      0.23      3463

weighted avg       0.29      0.54      0.38      3463

I have no idea what is going on here, since nothing really changed except the values of the dependent variable.
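For what it's worth, the `overflow encountered in exp` and `invalid value encountered in true_divide` warnings in the traceback point at the softmax inside MNLogit: when the linear predictor X·β gets very large during optimization, `np.exp` overflows to `inf`, and the row normalisation `inf/inf` produces NaN probabilities, which then poison the log-likelihood. A minimal sketch reproducing that mechanism (with made-up numbers, not this dataset):

```python
import numpy as np

# Reproduce the failure mode from statsmodels' discrete_model.py:
# a huge linear predictor makes exp() overflow to inf, and the
# row-wise normalisation inf/inf yields NaN probabilities.
with np.errstate(over="ignore", invalid="ignore"):
    xb = np.array([[800.0, 900.0]])  # linear predictors for y=1, y=2
    eXB = np.column_stack((np.ones(len(xb)), np.exp(xb)))
    probs = eXB / eXB.sum(1)[:, None]

print(probs)  # [[ 0. nan nan]]
```

If that is what is happening here, the usual suspects are unscaled regressors (note the very different magnitudes of the coefficients in the first fold's summary) or quasi-complete separation in the more balanced version of y, and centring/scaling X before fitting would be a first thing to try.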
