微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

序数回归算法似乎将预测转移了一类

如何解决序数回归算法似乎将预测转移了一类

对于具有 3 个类 (1,2,3) 的序数回归问题,我正在运行以下算法:

class OrdinalClassifier():

    def __init__(self,clf):
        self.clf = clf
        self.clfs = {}

    def fit(self,X,y):
        self.unique_class = np.sort(np.unique(y))
        if self.unique_class.shape[0] > 2:
            for i in range(self.unique_class.shape[0]-1):
                # for each k - 1 ordinal value we fit a binary classification problem
                binary_y = (y > self.unique_class[i]).astype(np.uint8)
                clf = clone(self.clf)
                clf.fit(X,binary_y)
                self.clfs[i] = clf

    def predict_proba(self,X):
        clfs_predict = {k:self.clfs[k].predict_proba(X) for k in self.clfs}
        predicted = []
        for i,y in enumerate(self.unique_class):
            if i == 0:
                # V1 = 1 - Pr(y > V1)
                predicted.append(1 - clfs_predict[y-1][:,1])
            #elif y in clfs_predict:
            elif y < self.unique_class.shape[0]:
                # Vi = Pr(y > Vi-1) - Pr(y > Vi)
                predicted.append(clfs_predict[y-2][:,1] - clfs_predict[y-1][:,1])
            else:
                # Vk = Pr(y > Vk-1)
                predicted.append(clfs_predict[y-2][:,1])
        return np.vstack(predicted).T

    def predict(self,X):
        return np.argmax(self.predict_proba(X),axis=1)

我通过给它一个 clf 来称呼它:

clf = RandomForestClassifier()
forest = OrdinalClassifier(clf)

并通过调用 fit 来训练它:

forest.fit(X_train,y_train)

最后我通过调用得到预测:

#add 1 so 0->1,1 -> 2,2 -> 3
pred  = forest.predict(y_test) + 1

我相信这个算法应该按照我的意图工作。然而,在使用不同的超参数集运行模型时,这些类似乎以某种方式混淆了。对于几乎所有超参数集,我都发现了相同的模式。

  1. 预测为第 1 类的类占实际第 2 类的百分比最大
  2. 预测为第 2 类的类占实际第 3 类的百分比最大
  3. 预测为第 3 类的类占实际第 1 类的百分比最大

我觉得实际上似乎找到正确序数的一两个超参数组合更多是由运气造成的,而不是实际上产生了一个好的模型。总之,我的算法似乎确实在数据中找到了序数关系,但这些关系不正确/解码不正确。

问题:我的模型有问题吗?我应该只选择在验证集上按预期执行的超参数的确切组合吗?或者我应该将我的预测解码为 3 类 => 1 类,1 类 => 2 类,2 类 => 3 类?

预先感谢您的帮助!

编辑:对于任何感兴趣的人,我的 LGBM 模型开始以正确顺序看到模式的参数是当学习率变得非常高(接近或等于 1)而参数为l1 正则化 >0.

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?