微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

加速 Python 中的嵌套 for 循环

如何解决加速 Python 中的嵌套 for 循环

我有一个用 python 编写的嵌套循环系统,如下所示:

for yt in range(dims[1]):
  for xt in range(dims[2]):
    for yp in range(dims[1]):
       for xp in range(dims[2]):
           corr[yt,xt,yp,xp] = sp.spearmanr(prec_tar[:,yt,xt],prec_pre[:,xp],axis=0)[0] 
           corr2[yt,prec_pre2[:,axis=0)[0]
           corr3[yt,prec_pre3[:,axis=0)[0]

其中 dims 的形状为 (1710,69,21) 并且 corr、corr2 和 corr3 都是 xarray Dataarray,其中包含形状为 (69,21,21) 的空 NumPy 数组。

现在,问题是这个脚本需要永远完成(~ 6+ 小时)。我不确定嵌套循环设置是否导致了它,或者 sp.spearmanr 是否是罪魁祸首(或者两者都有)。我正在寻找使这个运行更快的方法,具体来说,我想知道是否可以利用并行处理。也欢迎其他提示。提前致谢!

编辑:我还应该补充一点,prec_tar、prec_pre、prec_pre2 和 prec_pre3 都具有与 dims 相同的形状(即 (1710,21))。

解决方法

您可以使用以下代码段使您的代码并行。

import time
import itertools
import multiprocessing

yt = range(2)
xt = range(2)
yp = range(2)
xp = range(2)

param_list = list(itertools.product(yt,xt,yp,xp))

def task(args):
    print(args)
    # task
    time.sleep(1)
    return args

pool = multiprocessing.Pool()

response = pool.map(task,param_list)
print(response)
,

您可以在矢量化而不是循环代码时加快速度。

尝试使用矢量化和并行化 spearmanr 函数的 xski​​llscore。 https://xskillscore.readthedocs.io/en/stable/api/xskillscore.spearman_r.html#xskillscore.spearman_r

,

这是基于@aaron.spring 建议的针对此问题的有效解决方案。我希望有一天这对某人有所帮助。

# Problem at hand: Very slow.
t1 = time.time()
for i in range(dims[1]):   #dims = ((1000,4,5))
    for j in range(dims[2]):
        for x in range(dims[1]):
            for y in range(dims[2]):
                acorrb[i,j,x,y] = spearmanr(a[:,i,j],b[:,y],dim='time')
t2 = time.time()
print(t2-t1)  # 0.3600752353668213

# Faster solution based on xarray's vectorized indexing and using  xskillscore.spearman_r instead of spearmanr from scipy.stats. 

ind_i = xr.DataArray(range(dims[1]),dims=['i'])
ind_j = xr.DataArray(range(dims[2]),dims=['j'])
ind_x = xr.DataArray(range(dims[1]),dims=['x'])
ind_y = xr.DataArray(range(dims[2]),dims=['y'])

t3 = time.time()
acorrb2[ind_i,ind_j,ind_x,ind_y]=spearmanr(a[:,ind_i,ind_j],ind_y],dim='time')
t4 = time.time()
print(t4-t3) #0.07205533981323242

快 5 倍以上。

print((acorrb.values==acorrb2.values).all()) #True

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。