微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在numpy数组中提取每行固定数量的元素

如何解决在numpy数组中提取每行固定数量的元素

假设我有一个数组 a一个布尔数组 b,我想从 a 的每一行中的有效元素中提取固定数量的元素。有效元素是由 b 指示的元素。

这是一个例子:

a = np.arange(24).reshape(4,6)
b = np.array([[0,1,0],[0,1],1]]).astype(bool)
x = []
for i in range(a.shape[0]):
    c = a[i,b[i]]
    d = np.random.choice(c,2)
    x.append(d)

这里我使用了一个 for 循环,如果这些数组很大并且是高维的,它会很慢。有没有更有效的方法来做到这一点?谢谢。

解决方法

  1. 生成形状为 a 的随机均匀 [0,1] 矩阵。
  2. 将此矩阵乘以掩码 b 以将无效元素设置为零。
  3. 从每行中选择 k 个最大索引(仅从该行中的有效元素模拟无偏随机 k 样本)。
  4. (可选)使用这些索引来获取元素。
a = np.arange(24).reshape(4,6)
b = np.array([[0,1,0],[0,1],1]])
k = 2

r = np.random.uniform(size=a.shape)
indices = np.argpartition(-r * b,k)[:,:k]

从索引中获取元素:

>>> indices
array([[3,2],[5,[3,[4,5]])
>>> a[np.arange(a.shape[0])[:,None],indices]
array([[ 3,[11,7],[15,14],[22,23]])

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。