基于python中的多个特征的训练-测试分割的分层交叉验证或抽样

如何解决基于python中的多个特征的训练-测试分割的分层交叉验证或抽样

sklearn 的 train_test_splitStratifiedShuffleSplitStratifiedKFold 都基于类标签(y 变量或 target_column)进行分层。如果我们想基于特征列(x 变量)而不是不基于目标列进行采样怎么办。如果它只有一个特征,那么基于该单列进行分层会很容易,但是如果特征列很多并且我们想要保留所选样本中的总体比例怎么办?

下面我创建了一个 df,它使人口倾斜,其中低收入者较多,女性较多,CA 最少,MA 最多。我希望所选样本具有这些特征,即更多的低收入者,更多的女性,来自 CA 的人数最少,来自 MA 的人数最多

import random
import string
import pandas as pd
N = 20000 # Total rows in data
names    = [''.join(random.choices(string.ascii_uppercase,k = 5)) for _ in range(N)]
incomes  = [random.choices(['High','Low'],weights=(30,70))[0] for _ in range(N)]
genders  = [random.choices(['M','F'],weights=(40,60))[0] for _ in range(N)]
states   = [random.choices(['CA','IL','FL','MA'],weights=(10,20,30,40))[0] for _ in range(N)]
targets_y= [random.choice([0,1]) for _ in range(N)]

df = pd.DataFrame(dict(
        name     = names,income   = incomes,gender   = genders,state    = states,target_y = targets_y
    ))

如果对于某些特征,我们的示例很少,并且我们希望在所选示例中至少包含 n 个示例,则会出现更复杂的情况。考虑这个例子:

single_row = {'name' : 'ABC','income' : 'High','gender' : 'F','state' : 'NY','target_y' : 1}

df = df.append(single_row,ignore_index=True)

df

enter image description here

.

我希望添加的这一行始终包含在测试拆分中(此处为 n=1)。

解决方法

这可以使用 pandas groupby 来实现:

我们先来看看人口特征:

grps = df.groupby(['state','income','gender'],group_keys=False)
grps.count()

enter image description here

接下来让我们用 20% 的原始数据创建一个测试集

test_proportion = 0.2
at_least = 1
test = grps.apply(lambda x: x.sample(max(round(len(x)*test_proportion),at_least)))
test

enter image description here

测试集特征:

test.groupby(['state','gender']).count()

enter image description here

接下来我们创建训练集作为原始 df 和测试集的差异

print('No. of samples in test  =',len(test))
train = set(df.name) - set(test.name)
print('No. of samples in train =',len(train))

没有。测试样本数 = 4000

没有。训练中的样本数 = 16001

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?