微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

logit 和 sklearn 管道的一种热编码

如何解决logit 和 sklearn 管道的一种热编码

我正在尝试使用 Python 中的 Dalex 包来可视化二进制 logit 模型的某些特征。

我从示例书中复制了一段代码here (整个第五个代码单元)但现在我不太确定应该如何解释结果......

在我使用 statsmodels 创建的基本 logit 模型中,我为每个类别手动选择了一个参考级变量以避免多重共线性(这意味着模型的所有结果都被解释为到参考水平)。

但是当我使用上面链接中的一段代码(也在这文章下面复制)时,它首先在 sklearn 中创建了一些管道对象,对分类变量进行单热编码,然后管道对象是拟合数据并在Dalex解释器中用作待解释模型。

问题是,当我在 Dalex 中使用 model_profile() 之类的函数时,它应该输出一个图表,显示变量对预测的ceteris paribus 影响,我不知道如何解释结果,因为看起来好像所有 一个类变量中的值都包含在图中。

例如,该模型显示了“性别”分类变量对男性和女性平均预测的影响...

这也显示了一条名为“平均预测”的水平线,但“平均预测”是什么?是根据男性作为参考水平计算的,还是女性?

我真的对结果的含义感到困惑......有人可以澄清一下吗?我尝试使用的函数 model_profile() 在笔记本中也有说明。谢谢!

我复制的一段代码

    numerical_features = ['age','fare','sibsp','parch']
    numerical_transformer = Pipeline(
        steps=[
            ('imputer',SimpleImputer(strategy='median')),('scaler',StandardScaler())
        ]
    )
    
    categorical_features = ['gender','class','embarked']
    categorical_transformer = Pipeline(
        steps=[
            ('imputer',SimpleImputer(strategy='constant',fill_value='missing')),('onehot',OneHotEncoder(handle_unkNown='ignore'))
        ]
    )
    
    preprocessor = ColumnTransformer(
        transformers=[
            ('num',numerical_transformer,numerical_features),('cat',categorical_transformer,categorical_features)
        ]
    )
    
    classifier = MLPClassifier(hidden_layer_sizes=(150,100,50),max_iter=500,random_state=0)
    
    clf = Pipeline(steps=[('preprocessor',preprocessor),('classifier',classifier)])
    clf.fit(X,y)
    exp = dx.Explainer(clf,X,y)

解决方法

为什么会这样?

发生这种情况是因为默认情况下,sklearnOneHotEncoder 会对数据中的每个类别进行一次热转换。然而,对于像 logit 这样的线性模型,通常最好将其中一个类别排除在外,以避免多重共线性并使您的结果可解释参考点。在这种情况下,您需要更改编码器的默认设置。

示例

您可以通过设置 drop="first" 来实现这一点,这会删除一个热编码过程的第一个类别。下面的示例说明了如何在一个简单的示例上工作。在这里,“女性”类别从一个热编码中删除,只有“男性”类别被编码,这将返回您期望的结果。请注意,这也适用于非二进制功能。

from sklearn.preprocessing import OneHotEncoder
X = pd.DataFrame({"gender":["male","female","male"]})
OHE = OneHotEncoder(drop="first")
OHE.fit_transform(X).toarray()
#[[1.],# [0.],# [1.]]
OHE.get_feature_names()
#['x0_male']

你需要做什么

因此,您只需在代码中更改管道定义中的以下行:

'onehot',OneHotEncoder(drop='first',handle_unknown='ignore')

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。