微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Pickle 自定义对象

如何解决Pickle 自定义对象

以下代码

from sklearn.preprocessing import LabelBinarizer
lb = LabelBinarizer()
lb.fit_transform(['yes','no','yes'])

返回二进制类的向量,如此处所述http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelBinarizer.html

array([[1],[0],[1]])

虽然二进制类的理想情况是(如 MultiLabelBinarizer):

array([[1,0],[0,1],[1,0]])

使用@applecider 中的以下类 sklearn LabelBinarizer returns vector when there are 2 classes

import numpy as np
from sklearn.preprocessing import LabelBinarizer


class LabelBinarizer2:

    def __init__(self):
        self.lb = LabelBinarizer()

    def fit(self,X):
        # Convert X to array
        X = np.array(X)
        # Fit X using the LabelBinarizer object
        self.lb.fit(X)
        # Save the classes
        self.classes_ = self.lb.classes_

    def fit_transform(self,X):
        # Convert X to array
        X = np.array(X)
        # Fit + transform X using the LabelBinarizer object
        Xlb = self.lb.fit_transform(X)
        # Save the classes
        self.classes_ = self.lb.classes_
        if len(self.classes_) == 2:
            Xlb = np.hstack((Xlb,1 - Xlb))
        return Xlb

    def transform(self,X):
        # Convert X to array
        X = np.array(X)
        # Transform X using the LabelBinarizer object
        Xlb = self.lb.transform(X)
        if len(self.classes_) == 2:
            Xlb = np.hstack((Xlb,1 - Xlb))
        return Xlb

    def inverse_transform(self,Xlb):
        # Convert Xlb to array
        Xlb = np.array(Xlb)
        if len(self.classes_) == 2:
            X = self.lb.inverse_transform(Xlb[:,0])
        else:
            X = self.lb.inverse_transform(Xlb)
        return X

然后,我对数据进行 fit_transform,这可以解决问题,但现在无法使用 pickle 并存储它,以便稍后加载编码器进行转换并将其用于测试数据。

encoder_leadsourcecode = preprocessing.MultiLabelBinarizer()
feature_leadsourcecode = encoder_leadsourcecode.fit_transform(df["Lead Source Code"])
feature_leadsourcecode = pd.DataFrame(feature_leadsourcecode,columns=encoder_leadsourcecode.classes_)

什么时候,我尝试腌制它:

LeadSourceCodeEnc = pk.dumps(encoder_leadsourcecode)

我得到以下信息:

AttributeError: Can't pickle local object 'transform.<locals>.LabelBinarizer2'

那么我们如何腌制自定义对象?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。