微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

仅将函数应用于满足条件的数组切片 (NumPy)

如何解决仅将函数应用于满足条件的数组切片 (NumPy)

我有一个数组 A[i,j]。最后一个索引包含函数 myfunc 的各种输入值,该函数将应用于每个 i 并产生输出 B[i]。但是,由 j 索引的许多值不会对 B 做出贡献,因此我想避免不必要的 myfunc 调用。这可以通过使用条件索引(例如 C = C[C>mythreshold] 和 for 循环相对容易地切出相关值来实现,如下面的 MWE 所示:

def myfunc(X):
    return np.square(X).sum()
A = np.floor(np.random.rand(3,4)*100)
mythreshold = 10
(N1,N2) = A.shape
B = np.zeros(N1)
for i in range(N1):
    C = A[i,:]
    C = C[C>mythreshold]
    B[i] = myfunc(C)

我不得不把它分解成 for 循环,这样我就可以删除 A 的切片而不删除完整数组的切片。这是因为我无法删除一个 A[i,:]i 元素而不删除一个 i 的相应元素。然而,为了速度,我想尽可能进行矢量化 - 避免 for 循环并一次性对所有 i 执行此操作。我该怎么做?

注意:那是一个 MWE;实际情况具有更大的数组维度,因此我的数组将是 A[i,j,k,l]B[i,j],因此 for 循环示例类似于下面的代码。我认为额外的维度不会使事情复杂化,但值得一提以防万一。

(N1,N2,N3,N4) = A.shape
for i in range(N1):
    for j in range(N2):
        C = A[i,:,:].flatten()
        C = C[C>mythreshold]
        B[i,j] = myfunc(C)

解决方法

In [10]: A = np.floor(np.random.rand(3,4)*2*mythreshold)
In [11]: A
Out[11]: 
array([[14.,4.,1.,8.],[11.,11.,2.],[ 8.,6.,18.,12.]])
In [12]: (N1,N2) = A.shape
    ...: B = np.zeros(N1)
    ...: for i in range(N1):
    ...:     C = A[i,:]
    ...:     C = C[C>mythreshold]
    ...:     B[i] = myfunc(C)
    ...: 
In [13]: B
Out[13]: array([196.,242.,468.])

对整个阵列的阈值测试:

In [14]: A>mythreshold
Out[14]: 
array([[ True,False,False],[ True,True,[False,True]])

制作一个副本,并将其他值设置为 0(或无害的):

In [15]: A1 = A.copy(); A1[A<=mythreshold]=0
In [16]: np.square(A1).sum(axis=1)
Out[16]: array([196.,468.])

这并没有避免将函数应用于所有元素,但它避免了对行进行迭代。通常避免 python 级循环加速 numpy 代码。但是,如果您的函数不能像我使用 axis 参数那样“矢量化”,或者它非常复杂以致于包含那些“0”值的代价很高,那么这不是可行的方法。>

如果您的函数包含 ufunc,您可以使用其 where 参数

In [17]: mask = A>mythreshold
In [18]: out = np.zeros_like(A)
In [19]: np.square(A,out=out,where=mask)
Out[19]: 
array([[196.,0.,0.],[121.,121.,[  0.,324.,144.]])
In [20]: _.sum(axis=1)
Out[20]: array([196.,468.])

通常当某些值给出不好的结果时使用这个 where,例如除以 0 或负数的对数。我不认为这可以节省时间,但还没有确定时间来确认这一点。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。