微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

无法使用 Numba 优化 Fractal 代码

如何解决无法使用 Numba 优化 Fractal 代码

我正在编写代码来可视化 Mandelbrot 集和其他分形。下面是正在运行的代码片段。代码按原样运行得很好,但我正在尝试优化它以更快地制作更高分辨率的图像。我曾尝试在 fractal() 上使用缓存,以及来自 Numba 的 @jit@njit。缓存导致崩溃(我假设是内存溢出),@jit 只是将我的程序的执行速度减慢了 6 倍。我也知道有许多数学方法可以让我的代码运行更快,正如我在维基百科页面上看到的那样,但我想看看我是否可以获得上述方法之一或其他一些替代方法

为了连续创建多个图像(制作缩放动画,就像这个)我已经实现了多处理(似乎一次运行 9 个进程)但我不知道如何在创建中实现相同的单个高分辨率图像。

这是我的代码片段:

import numpy as np
import cv2
import cmath
import math

# pick the fractal
def fractal(z,c):
# Mandelbrot
    if fractal_type == 0:
        return z**d + c
# Burning Ship
    if fractal_type == 1:
        return complex(abs(z.real),abs(z.imag))**d + c

#naive escape time algorithm
def naive_escape(arr):
    h = arr[0]
    w = arr[1]
    d = arr[2]
    zoom = pow(1.5,arr[3]) * pow(10,int(np.log10(h)))
    x_cen = arr[4]
    y_cen = arr[5]

    for i in range(w):
        sys.stdout.write("\r{0:03}%".format(np.round(i/w * 100,4)))
        sys.stdout.flush()

        for j in range(h):
            it = 0
        #coordinates
            cx = i - int(w/2)
            cy = j - int(h/2)
        #scaling
            sx = (cx / (zoom)) + x_cen
            sy = (cy / (zoom)) - y_cen

            c = complex(sx,sy)
            z = complex(0.0,0.0)

            while ((z.real)**2 + (z.imag)**2 <= 2**d) and (it < max_it):
                z = fractal(z,c)
                it += 1

            img[j][i] = color_dict[it]

    sys.stdout.write("\n")

    name = "fractal"

    cv2.imwrite("{}.png".format(name),img)
    print("\n{} created!\n".format(name),fractal_type)


我应该澄清一下,着色函数 naive_escape() 接受数组输入的原因是因为我实现了多处理。由于多处理中的 map() 只允许我们用一个输入映射函数,所以我只传递一个包含所有输入值的数组。

上面粘贴的代码是来自一个更大文件的片段,所以请原谅任何语法错误

任何有助于加快我的代码速度的帮助将不胜感激!

解决方法

This older answer 专门处理矢量化,但可以进行一些额外的优化。

你可以从 Numpy 向量化开始,方便但不是很快:

@np.vectorize
def mandelbrot_numpy(c: complex,max_it: int) -> int:
    z = c
    for i in range(max_it):
        if abs(z) > 2:
            return i
        z = z**2 + c
    return 0

或者 Numba 向量化,将速度提高一个数量级:

@nb.vectorize([nb.u2(nb.c16,nb.i8)])
def mandelbrot_numba(c: complex,max_it: int) -> int:
    z = c
    for i in range(max_it):
        if abs(z) > 2:
            return i
        z = z**2 + c
    return 0

然后你可以应用一些常用的优化:

@nb.vectorize([nb.u2(nb.c16,nb.u2)])
def mandelbrot_numba_opt(c: complex,max_it: int) -> int:
    x = cx = c.real
    y = cy = c.imag
    for i in range(max_it):
        x2 = x*x
        y2 = y*y
        if x2 + y2 > 4:
            return i
        y = (x+x)*y + cy
        x = x2 - y2 + cx
    return 0

你也可以并行化它(在这个例子中按行):

@nb.njit([nb.u2[:,:](nb.c16[:,:],nb.u2)],parallel=True)
def mandelbrot_parallel(c: np.ndarray,max_it: int) -> np.ndarray:
    result = np.zeros_like(c,dtype=nb.u2)
    for row in nb.prange(len(c)):
        result[row] = mandelbrot_numba_opt(c[row],max_it)
    return result

1000x1000 阵列上的一些计时:

N = 1000
x = np.linspace(-2,2,N).reshape((1,-1))
y = x.T
c = x + 1j * y

%timeit mandelbrot_numpy(c,99)
1.59 s ± 40.9 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
%timeit mandelbrot_numba(c,99)
100 ms ± 406 µs per loop (mean ± std. dev. of 7 runs,10 loops each)
%timeit mandelbrot_numba_opt(c,99)
35 ms ± 140 µs per loop (mean ± std. dev. of 7 runs,10 loops each)
%timeit mandelbrot_parallel(c,99)
10.9 ms ± 64.3 µs per loop (mean ± std. dev. of 7 runs,100 loops each)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。