微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在Numpy中向量化块式操作的建议

如何解决在Numpy中向量化块式操作的建议

我正在尝试实施一系列统计操作,并且需要帮助向量化我的代码

这个想法是从两个图像中提取NxN个补丁,计算这两个补丁之间的距离度量。

为此,我首先使用以下循环构建补丁:

params = []
for i in range(0,patch1.shape[0],1):
    for j in range(0,patch1.shape[1],1):
        window1 = np.copy(imga[i:i+N,j:j+N]).flatten()
        window2 = np.copy(imgb[i:i+N,j:j+N]).flatten()
        params.append((window1,window2))
print(f"We took {time()- t0:2.2f} seconds to prepare {len(params)/1e6} million patches.")

这大约需要10秒钟才能完成,并且我对预处理时间并不过分。接下来的步骤是我要优化的步骤。

在此之后,为了加快处理速度,我使用了多池计算实际结果。包含实际计算的函数如下:

@njit
def cauchy_schwartz(imga,imgb):
    p,_ = np.histogram(imga,bins=10)
    p = p/np.sum(p)
    q,_ = np.histogram(imgb,bins=10)
    q = q/np.sum(q)

    n_d = np.array(np.sum(p * q)) 
    d_d = np.array(np.sum(np.power(p,2) * np.power(q,2)))
    return -1.0 * np.log10( n_d,d_d)

我使用此结构来处理所有补丁:

def f(param):
    return cauchy_schwartz(*param)

with Pool(4) as p:
    r = list(tqdm.tqdm(p.imap(f,params),total=len(params)))

我确信必须有一些更优雅的方法,因为如果我将整个10Kpx x 10Kpx的图像发送到cauchy_schwartz函数中,它将在一秒钟内处理所有内容,但是即使是在4个核心需要很长时间。

我的思维模式是blockproc在Matlab中的工作方式-我最终以这种模式编写了这段代码。对于改善此代码性能的任何建议,我将不胜感激。

解决方法

通过使用apply_along_axis,您可以摆脱cauchy_schwartz。由于您不太担心预处理时间,因此假设您已获得包含扁平化补丁的数组params

params = np.random.rand(3,2,100)

您可以看到params的形状为(3,100),只是随机选择三个数字3、2和100来创建一个辅助数组,以演示使用{{1} }。 3对应于您拥有的色块数量(由色块形状和图像大小确定),2对应于两个图像,100对应于展平的色块。因此,apply_along_axis的轴是params,这与您的代码创建的列表(idx of patches,idx of images,idx of entries of a flattened patch)完全匹配

params

使用辅助数组params = [] for i in range(0,patch1.shape[0],1): for j in range(0,patch1.shape[1],1): window1 = np.copy(imga[i:i+N,j:j+N]).flatten() window2 = np.copy(imgb[i:i+N,j:j+N]).flatten() params.append((window1,window2)) ,这是我的解决方案:

params
,

首先,分析您的代码以识别瓶颈。您可以使用https://mg.pov.lt/profilehooks/。我认为瓶颈在于修补程序的创建,因为您正在为流程创建修补程序的副本。通过仅传递补丁程序的索引,您可以使用更少的内存:

params = []
for i in range(0,1):
        start,end = (i,i+N),(j,j+N)
        params.append((start,end))

然后,假设imgaimgb是全局的,则可以通过cauchy_schwartz函数创建补丁,如下所示:

@njit
def cauchy_schwartz(start,end):

    a,b = start; c,d = end
    window1 = np.copy(imga[a:b,c:d]).flatten()
    window2 = np.copy(imgb[a:b,c:d]).flatten()

    # process patches window1 and window2

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。