如何解决Mandelbrot Numba/Numpy 矢量化?
我使用 kivy 在 Python 中编写了一个交互式 mandelbrot 渲染器,您可以在其中使用鼠标指针进行缩放,并正在尽我所能对其进行优化。我目前使用这个实现来渲染设置/缩放(这是一个小片段,只是用来渲染它的两个函数):
import numba as nb
import numpy as np
@nb.njit(cache= True,parallel = True)
def mandelbrot(c_r,c_i,maxIt): #mandelbrot function
z_r = 0
z_i = 0
z_r2 = 0
z_i2= 0
for x in nb.prange(maxIt):
z_i = 2 * z_r * z_i + c_i
z_r = z_r2 - z_i2 + c_r
z_r2 = z_r * z_r
z_i2 = z_i * z_i
if z_r2 + z_i2 > 4:
return x
return maxIt
@nb.njit(cache= True,parallel = True)
def DrawSet(W,H,xStart,xdist,yStart,ydist,maxIt):
array = np.zeros((H,W,3),dtype=np.uint8) #array that holds 'hsv' tuple for every pixel
for x in nb.prange(0,W):
c_r = (x/W)* xdist + xStart #some math to calculate real part
for y in range (0,H):
c_i = -((y/H) * ydist + yStart) #some more math to calculate imaginary part
cIt = mandelbrot(c_r,maxIt)
color = int((255 * cIt) / maxIt)
array[y,x] = (color,255,255) #adds hue value
return array #returns hsv array,gets later displayed using PIL
我目前的表现相当不错。它可以在大约 0.08 - 0.09 秒内渲染一个 500x500 的区域,其中每个点都有界(所以基本上是黑色图片,最坏的情况),迭代 300 次。我将 Numba JIT 与并行范围函数“prange()”一起使用,这有很大帮助。
但是,我听说矢量化通常是渲染此类分形的最快方法。经过大量研究(我对矢量化很陌生),我设法将这个实现放在一起:
import numba as nb
import numpy as np
def DrawSet(W,xEnd,yEnd,maxIt):
array = np.zeros((H,dtype = np.uint8) # 3D array containing 'hsv' tuple (hue,saturation,value) of each pixel
x = np.linspace(xStart,W).reshape((1,W)) #scaling horizontal pixels to x-axis
y = np.linspace(yStart,H).reshape((H,1)) #scaling vertical pixels to y-axis
c = x + 1j * y #creating complex plane out of x axis (real) and y axis (imaginary)
z = np.zeros(c.shape,dtype= np.complex128)
div_time = np.zeros(z.shape,dtype= int)
m = np.full(c.shape,True,dtype= bool)
div_time = loop(z,c,div_time,m,maxIt)
array[:,:,0] = (div_time/maxIt) * 255 -20 #adding 'hue' value
array[:,1] = 255 #adding 'saturation' value
array[:,2] = 255 #adding 'value'
return array
@nb.vectorize(nb.int64[:,:](nb.complex128[:,:],nb.complex128[:,nb.int64[:,nb.boolean[:,nb.int64))
def loop(z,maxIt):
for i in range(maxIt):
z[m] = z[m]**2 + c[m]
diverged = np.greater(np.abs(z),2,out=np.full(c.shape,False),where=m)
div_time[diverged] = i
m[np.abs(z) > 2] = False
return div_time
没有@nb.vectorize 装饰器,它运行得非常慢。 (500x500 的最坏情况为 4 秒,300 It。)。使用 @nb.vectorize 装饰器,我收到此错误:
Traceback (most recent call last):
File "Mandelbrot.py",line 13,in <module>
from test import DrawSet
File "C:\Users\User\Documents\Code\Python\Mandelbrot-GUI\test.py",line 25,in <module>
def loop(z,maxIt):
File "C:\Users\User\AppData\Local\Programs\Python\python38\lib\site-packages\numba\np\ufunc\decorators.py",line 119,in wrap
for sig in ftylist:
TypeError: 'Signature' object is not iterable
我做错了什么?我是否以正确的方式定义了所有的 numba 签名? 这种矢量化方法会超过我当前的实现吗?
我会感谢每一个建议!提前致谢。
解决方法
您的实现已经矢量化了!
矢量化的想法是创建 universal functions 对数组进行元素操作。您只需定义对单个元素执行的操作,向量化机制将允许使用数组调用您的函数。
该函数计算单个点 c:
def mandelbrot_point(c,max_it):
z = 0j
for i in range(max_it):
z = z**2 + c
if abs(z) > 2:
return i
return 0
您可以使用 Numpy 对其进行矢量化:
@np.vectorize
def mandelbrot_numpy(c,max_it):
z = 0j
for i in range(max_it):
z = z**2 + c
if abs(z) > 2:
return i
return 0
或者您可以使用 Numba 对其进行矢量化。请注意,函数的签名描述了如何处理单个点:
@nb.vectorize([nb.int64(nb.complex128,nb.int64)])
def mandelbrot_numba(c,max_it):
z = 0j
for i in range(max_it):
z = z**2 + c
if abs(z) > 2:
return i
return 0
然后您可以使用任意维数的标量或数组调用向量化函数:
>>> p = 0.4+0.4j
>>> mandelbrot_point(p,99)
8
>>> mandelbrot_numpy(p,99)
array(8)
>>> mandelbrot_numba(p,99)
8
>>> x = np.linspace(-2,2,11)
>>> mandelbrot_numpy(x,99)
array([0,6,1,1])
>>> mandelbrot_numba(x,1])
>>> x = np.atleast_2d(x)
>>> y = x.T
>>> c = x + 1j * y
>>> mandelbrot_numpy(c,99)
array([[ 0,0],[ 0,3,5,17,8,1],0]])
>>> mandelbrot_numba(c,0]])
Numpy 的 vectorize 极大地简化了您的代码,但正如文档所说,它主要是为了方便,而不是为了性能。实现本质上是一个 for 循环。
根据我的测量,Numpy 向量化版本比您的原始实现略快,而 Numba 向量化版本快一个数量级。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。