微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 numpy 和列表​​理解评估函数时,绕过网格网格的工作速度更快

如何解决使用 numpy 和列表​​理解评估函数时,绕过网格网格的工作速度更快

在这个线程中,我通过usethedeathstar找到了一种在使用简单的numpy方程时绕过meshgrid的方法numpy - evaluate function on a grid of points

我遇到了类似的问题,但在等式中使用列表理解并尝试试一试,我认为它行不通:

import numpy as np
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def rastrigin(dim,x,A=10):
    return  dim + np.sum([(xi**2 - A * np.cos(2 * math.pi * xi)) for xi in x])

def main():
    x = np.linspace(-4,4,100)    
    y = np.linspace(-4,100)

    # Option 1 - bypass meshgrid - MAGIC!
    #https://stackoverflow.com/questions/22774726/numpy-evaluate-function-on-a-grid-of-points/22778484#22778484
    Z = rastrigin(2,[x[:,None],y[None,:]])

    # Option 2 - Traditional way using meshgrid
    X,Y = np.meshgrid(x,y)
    Z = np.array( [rastrigin(2,[x,y]) for x,y in zip(np.ravel(X),np.ravel(Y))] ).reshape(X.shape) 
    
    # timeit shows Option 1 is ridiculously faster than Option 2
    import timeit
    t1 = timeit.timeit(lambda: np.array( [rastrigin(2,np.ravel(Y))] ).reshape(X.shape),number=100)
    t2 = timeit.timeit(lambda: rastrigin(2,:]]),number=100)
    print(t1,t2)

    fig = plt.figure()
    ax = fig.gca(projection='3d')

    ax.plot_surface(X,Y,Z,rstride=1,cstride=1,cmap=plt.cm.plasma,linewidth=0,antialiased=False)   
    plt.show()


if __name__ == "__main__":
    main()

这不仅适用于列表理解,我认为即使是原作者也不打算这样做,但速度非常快。选项 1 在 0.003 秒内运行 timeit,选项 2 在 8.7 秒内运行。

我的问题是:如何?我不明白为什么这种方法适用于列表理解。

我知道这会生成两个数组,一个 (100,1) 和另一个 (1,100): [(xi**2 - A * np.cos(2 * math.pi * xi)) for xi in x] 。然后 numpy.sum 正在传播总和并生成 (100,100) 结果?这是预期的行为吗?

解决方法

numpy - evaluate function on a grid of points 中关于绕过 meshgrid 的讨论有点误导。

考虑在 (1000,1000) 值网格上评估演示函数:

使用 meshgrid 制作一对 (1000,1000) 数组:

In [136]: X,Y = np.meshgrid(np.linspace(-4,4,1000),np.linspace(-4,1000))
In [137]: timeit np.sin(X*Y)
48.1 ms ± 32.8 µs per loop (mean ± std. dev. of 7 runs,10 loops each)

现在对稀疏设置执行相同的操作:

In [138]: X,sparse=True)
In [139]: timeit np.sin(X*Y)
47.4 ms ± 204 µs per loop (mean ± std. dev. of 7 runs,10 loops each)
In [140]: X.shape
Out[140]: (1,1000)

X 与执行 np.linspace(-4,1000)[None,:] 相同。

时差可以忽略不计。 “稀疏”数组使用较少的内存,但是一旦我们执行 X*Y,结果就是 (1000,1000),并且 sin 的计算点总数相同。我更喜欢使用

np.sin(x[:,None],y[None,:])

风格,但速度优势来自于在编译的 numpy 方法中做同样多的事情,而不是来自“绕过网格网格”。缓慢的事情是对一对标量进行 sin(x*y) 计算 1000*1000 次:

双循环:

In [144]: np.array([[np.sin(x*y) for x in np.linspace(-4,1000)] for y in np.linspace(-4,1000)]).shape
Out[144]: (1000,1000)
In [145]: timeit np.array([[np.sin(x*y) for x in np.linspace(-4,1000)])
3.26 s ± 5.75 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

或者使用散乱的 meshgrid 数组:

In [146]: X,1000))
In [147]: timeit np.array([np.sin(x*y) for x,y in zip(X.ravel(),Y.ravel())]).reshape(X.shape)
3.35 s ± 5.32 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

我意识到您的 rastrigin 函数比 test 更复杂,但是如果您想最充分地使用 numpy,您需要避免这些 Python 级别的迭代。

如果您必须迭代,使用列表(和 math 函数)而不是 numpy 通常更快:

In [148]: import math
In [149]: timeit np.array([math.sin(x*y) for x,y in zip(X.ravel().tolist(),Y.ravel().tolist())]).
    reshape(X.shape)
435 ms ± 3.96 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

实际上在这种情况下,使用 math.sin 进行标量计算可以节省大量时间。

链接中的其他答案之一建议使用 fromitermap。节省的时间不多:

In [153]: timeit np.fromiter(map(lambda x,y:math.sin(x*y),X.ravel().tolist(),Y.ravel().tolist()),float).
     ...: reshape(X.shape)
377 ms ± 7.43 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
,

现在让我们看看您的 rastrigin

In [154]: def rastrigin(dim,x,A=10):
     ...:     return  dim + np.sum([(xi**2 - A * np.cos(2 * math.pi * xi)) for xi in x])
 

用 2 个标量调用它:

In [155]: rastrigin(2,[1.2,3.4])
Out[155]: 19.99999999999999

现在创建 2 个数组 - 为了清楚起见,我会让它们不同:

In [156]: x = np.array([1.2,1.3,1.4])
In [157]: y = np.array([3.4,3.5])

产生 (3,2) 结果的双循环解决方案。 [0,0] 项匹配 [155]:

In [158]: np.array([[rastrigin(2,[i,j]) for j in y] for i in x])
Out[158]: 
array([[20.,22.59983006],[26.43033989,29.03016994],[31.70033989,34.30016994]])

并采用您的网格方法

In [159]: X,Y = np.meshgrid(x,y,indexing='ij')   # NOTE indexing
In [160]: X
Out[160]: 
array([[1.2,1.2],[1.3,1.3],[1.4,1.4]])
In [161]: Y
Out[161]: 
array([[3.4,3.5],[3.4,3.5]])

这 2 个数组是 (3,2) 形状。

In [162]: Z = np.array( [rastrigin(2,[x,y]) for x,y in zip(np.ravel(X),np.ravel(Y))] ).reshape(X.shape)
In [163]: Z
Out[163]: 
array([[20.,34.30016994]])

我在之前的回答中表明,这两种方法的时间大致相同。

In [164]: Z = rastrigin(2,[x[:,:]])
/usr/local/lib/python3.8/dist-packages/numpy/core/fromnumeric.py:87: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this,you must specify 'dtype=object' when creating the ndarray.
  return ufunc.reduce(obj,axis,dtype,out,**passkwargs)
In [165]: Z
Out[165]: 
array([[20.,34.30016994]])

但是对于二维数组参数,我们只能得到一个值:

In [166]: Z = rastrigin(2,[X,Y])
In [167]: Z
Out[167]: 154.06084971874733

当我第一次看到 rastrigin 时,我想知道您为什么要迭代 x,认为 x 是一个数组。但是这个 x 实际上是一个包含 2 个值的列表,即调用表达式中的 [x,y]。所以对于 2 个变量,它可以写成:

def rastrigin1(dim,A=10):
    return dim + (x**2 - A * np.cos(2 * np.pi * x)) +\
                 (y**2 - A * np.cos(2 * np.pi * y))

这可以用标量、meshgrid 数组和 broadcasted 数组(当然还有任何一种迭代)调用:

In [181]: rastrigin1(2,1.2,3.4)
Out[181]: 19.99999999999999
In [182]: rastrigin1(2,X,Y)
Out[182]: 
array([[20.,34.30016994]])
In [183]: rastrigin1(2,x[:,:])
Out[183]: 
array([[20.,34.30016994]])

您的 rastrigin 的数组输入有问题,因为 np.sum 步骤必须首先从列表推导中生成数组。

In [186]: np.array([x[:,:]])
<ipython-input-186-ae2f6203b4cd>:1: VisibleDeprecationWarning: ...
Out[186]: 
array([array([[1.2],[1.3],[1.4]]),array([[3.4,3.5]])],dtype=object)
In [187]: np.array([X,Y])
Out[187]: 
array([[[1.2,1.4]],[[3.4,3.5]]])

[186] 收到参差不齐的数组警告,因为输入 a (3,1) 和 (1,2),形状不兼容。 [187] 从 2 个形状相同的数组中创建一个 3d 数组。

应用于 [186] 的

sum 是否将应用于 (2,3,2) [187] 的“外部”broadcasted sum producing a (3,2) result. sum` 展平并产生一个值。

np.sum([X,Y],axis=0) 的作用与 X+Y 相同。

这个版本的 rastrigin(未经测试)应该让我们两全其美,接受数组列表(不仅仅是 2),并获得正确的总和。

def rastrigin(dim,alist,A=10):
    res = dim
    for xi in alist:
        res += (xi**2 - A * np.cos(2 * np.pi * xi))
    return res

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。