微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

用嵌套的np.less短路np.all,以便在numpy中进行大型数组比较

如何解决用嵌套的np.less短路np.all,以便在numpy中进行大型数组比较

在我当前的代码(请参阅MWE)中,我遇到了一个瓶颈,在此情况下,对于大型2D数组,使用嵌套的np.all执行np.less。我知道,如果false中只有一个np.less值,我们可以停止检查,因为索引中的其余值会将代码评估为false(因为我 AND -将给定维度的单个索引中的所有值放在一起)。

使用numba或numpy是否可以利用这种“早期退出/短路”条件在此计算中产生有意义的加速?

MWE的倒数第二行是我要加快的速度。请注意,NM可能很大,但实际上只有很少的比较会得出true

import numpy as np

N = 10000
M = 10 # Reduced to small value to show that sometimes the comparisons evaluate to 'True'

array = np.random.uniform(low=0.0,high=10.0,size=(N,M))
comparison_array = np.random.uniform(low=0.0,size=(M))

# Can we apply an early exit condition on this?
mask = np.all(np.less(array,comparison_array),axis=-1)

print(f"Number of 'True' comparisons: {np.sum(mask)}")

解决方法

这里的numba版本开发得可以正常工作,而不一定经过优化:

@numba.njit
def foo(arr,carr):
    N,M = arr.shape
    mask = np.ones(N,dtype=np.bool_)
    for i in range(N):
        for j in range(M):
            if arr[i,j]>=carr[j]:
                mask[i]=False
                break
    return mask

测试:

In [178]: np.sum(foo(array,comparison_array))
Out[178]: 2
In [179]: np.sum(np.all(np.less(array,comparison_array),axis=1))
Out[179]: 2

时间:

In [180]: timeit np.sum(foo(array,comparison_array))
155 µs ± 6.36 µs per loop (mean ± std. dev. of 7 runs,10000 loops each)
In [181]: timeit np.sum(np.all(np.less(array,axis=1))
451 µs ± 5.19 µs per loop (mean ± std. dev. of 7 runs,1000 loops each)

这是一个不错的进步。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?