带有交叉验证的 statsmodels GLM 预测的形状未对齐错误

如何解决带有交叉验证的 statsmodels GLM 预测的形状未对齐错误

我遇到了一个问题,我在 statsmodels 中构建阶跃函数,同时首先使用交叉验证来确定理想的切割量。但是我遇到了一个问题,我就是不知道如何解决

在我使用 Sklearn 的 KFold 函数添加交叉验证循环后,我开始收到错误

ValueError: shapes (480,2) and (1,) not aligned: 2 (dim 1) != 1 (dim 0)

我不确定为什么现在会发生这种情况,就像在我开始使用交叉验证循环之前一样,它运行良好,没有任何问题。

如果有人可以查看我的代码块并指出此问题的根源,我将不胜感激。

进入前X_train和y_train的形状:

X_train: (2400,)  y_train: (2400,)

代码

import statsmodels.api as sm
from sklearn.model_selection import KFold

kf = KFold(n_splits=5,shuffle=True,random_state=1)

cuts = []
RMSE = []

for i in range(1,11):
  cuts.append(i)
  cross_val_rms = []
  for train_index,test_index in kf.split(X_train):
    train_x,test_x= X_train.iloc[train_index],X_train.iloc[test_index]
    train_y,test_y= y_train.iloc[train_index],y_train.iloc[test_index]
    
    df_cut,bins = pd.cut(train_x,i,retbins=True,right=True)
    df_steps = pd.concat([train_x,df_cut,train_y],keys=['age','age_cuts','wage'],axis = 1)
    df_steps_dummies = pd.get_dummies(df_cut)
    GLM_fitted = sm.GLM(df_steps.wage,df_steps_dummies).fit()
    bin_mapping = np.digitize(test_x,bins)
    X_valid = pd.get_dummies(bin_mapping)
    pred = GLM_fitted.predict(X_valid)
    rms = np.sqrt(mean_squared_error(test_y,pred))
    cross_val_rms.append(rms)
  mean_rms = sum(cross_vall_rms)/len(cross_vall_rms)
  RMSE.append(mean_rms)

cuts_df = pd.DataFrame()
cuts_df['Cuts'] = cuts
cuts_df['RMSE'] = RMSE

print('Cuts with lowest Root Mean Squared Error:',cuts_df.loc[cuts_df['RMSE'].idxmin],sep='\n')

错误



---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-166-a9794538c3e5> in <module>()
     21     bin_mapping = np.digitize(test_x,bins)
     22     X_valid = pd.get_dummies(bin_mapping)
---> 23     pred = GLM_fitted.predict(X_valid)
     24     rms = np.sqrt(mean_squared_error(test_y,pred))
     25     cross_val_rms.append(rms)

1 frames

/usr/local/lib/python3.7/dist-packages/statsmodels/genmod/generalized_linear_model.py in predict(self,params,exog,exposure,offset,linear)
    870             exog = self.exog
    871 
--> 872         linpred = np.dot(exog,params) + offset + exposure
    873         if linear:
    874             return linpred

<__array_function__ internals> in dot(*args,**kwargs)

ValueError: shapes (480,) not aligned: 2 (dim 1) != 1 (dim 0)

解决方法

我认为如果您解释您在回归中尝试做什么会有所帮助。你得到这个错误是因为如果你从训练折叠中得到 3 个 bin,这并不意味着你从测试折叠中得到了 3 个 bin,你可能会得到 2 个折叠,因为 1 个 bin 中没有值。

据我所知,您可以简单地先对值进行离散化,然后使用示例数据进行训练:

import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold


X_train = pd.Series(np.random.uniform(0,1,2400))
y_train = pd.Series(np.random.uniform(0,2400))

然后

for i in range(2,11):

  cross_val_rms = []
  df_steps_dummies = pd.get_dummies(pd.cut(X_train,i))
  
  for train_index,test_index in kf.split(X_train):
    train_x,test_x= df_steps_dummies.iloc[train_index,:],df_steps_dummies.iloc[test_index,:]
    train_y,test_y= y_train[train_index],y_train[test_index]
    
    GLM_fitted = sm.GLM(train_y,train_x).fit()
    pred = GLM_fitted.predict(test_x)
    rms = np.sqrt(mean_squared_error(test_y,pred))
    cross_val_rms.append(rms)

RMSE.append(np.array(cross_val_rms).mean())

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?