如何解决在 Python 中可以更快地进行这种修复插值吗?
有一个用 Matlab (Inpaintn) 编写的修复函数,使用离散余弦变换来填充多维数据集中的缺失值,根据这篇论文 Garcia et. al. (2012)。我尝试将这段代码 (inpaintn.m) 移植到 Python 中,如下所示,
import numpy as np
from scipy.ndimage import distance_transform_edt
from scipy.fft import idctn,dctn
from tqdm import tqdm
def fill_nd(data,invalid=None):
if invalid is None: invalid = np.isnan(data)
ind = distance_transform_edt(invalid,return_distances=False,return_indices=True)
return data[tuple(ind)]
def InitialGuess(y,I):
z = fill_nd(y)
s0 = 3
return z,s0
def idctnn(y):
return idctn(y,norm='ortho')
def dctnn(y):
return dctn(y,norm='ortho')
def inpaint(xx,y0=[],n=100,m=2,verbose=False):
x = xx.copy() #as it changes x itself,so copying it to another variable.
sizx = np.shape(x)
d = np.ndim(x)
Lambda = np.zeros(sizx,dtype='float')
for i in range(0,d):
siz0 = np.ones(d,dtype='int')
siz0[i] = sizx[i]
Lambda = Lambda + np.cos(np.pi * np.reshape(np.arange(1,sizx[i] + 0.1) - 1,siz0) / sizx[i])
Lambda = 2 * (d - Lambda)
# Initial condition
W = np.isfinite(x)
if len(y0) == len(x):
y = y0
s0 = 3 # note: s = 10 ^ s0
else:
if np.any(~W):
if verbose: print('Initial Guess as Nearest Neighbors')
y,s0 = InitialGuess(x,np.isfinite(x).astype('bool'))
else:
y = x
s0 = 3
# return x
x[~W] = 0.
# Smoothness parameters: from high to negligible
s = np.logspace(s0,-6,n)
RF = 2. # Relaxation Factor
Lambda = Lambda ** m
if verbose: print('Inpainting .......')
for i in tqdm(range(n)):
Gamma = 1. / (1 + s[i] * Lambda)
y = RF * idctnn(Gamma * dctnn((W * (x - y)) + y)) + (1 - RF) * y
y[W] = x[W]
return y
代码运行良好,但我一直在努力寻找使代码运行得更快的方法,尤其是因为我的数据集很大。使用这种类型的插值的好处是我可以提供整个 3D 数据集(带有时间和网格坐标)来填充缺失值,而不是为每个时间坐标都做。
这是一个使用 python 的示例数据集
import numpy as np
#A 3D dataset with dimensions (time,latitude,longitude)
X = np.random.randn(1000,180,360)
# Randomly choosing indices to insert 64800 NaN values (say).
#NaNs can also be present as blocks in the data,not randomly dispersed as below.
index_nan = np.random.choice(X.size,64800,replace=False)
#Inserting NaNs.
X.ravel()[index_nan] = np.nan
我尝试了一些方法,但都没有成功,
- 使用 Numba
jit 装饰器使它变慢,即使使用 parallel/fastmath/vectorize,nopython=True
之类的选项也是如此。
- 使用 Cython
我尝试对这些函数中使用的所有变量进行排版,但它仍然比原生 python 实现慢。而且,在我的机器上编译 Cython 代码很麻烦。
- 使用 Numpy 矢量化
我已经用 scipy
函数替换了离散余弦变换函数及其逆函数,但我似乎无法想到将内部 for 循环向量化以使其更快的方法,以及它是否可能。
我试过分析我的代码,瓶颈似乎在使用 scipy
的离散余弦变换中。还有其他瓶颈,但对我来说没有意义。我还附上了一张用于分析的图像。
如果有可行的方法来加速这段代码,那真的会很有帮助。我在 Python 方面并不是很先进,但是我可以从中学到很多东西,尤其是我的问题的可行性。
解决方法
该算法适用于一个相当大的数组(不适合 CPU 缓存),这部分解释了为什么它有点慢。此外,众所周知,DCT/IDCT 是昂贵的操作。话虽如此,您可以通过使用 Numba 的 JIT 和 scipy 函数的 workers=-1
选项来并行化算法。此外,您可以通过就地工作来避免创建许多昂贵的临时数组。这是未经测试的结果代码:
# In-place computation
def idctnn(y):
return idctn(y,norm='ortho',workers=-1,overwrite_x=True)
# In-place computation
def dctnn(y):
return dctn(y,overwrite_x=True)
# In-place computation (writes in `Transformed`)
@nb.njit('void(float64[:,:,::1],float64[:,float64)',parallel=True)
def ComputeGammaTransform(Transformed,Lambda,sVal):
for i in nb.prange(Transformed.shape[0]):
for j in range(Transformed.shape[1]):
for k in range(Transformed.shape[2]):
Transformed[i,j,k] /= (1. + sVal * Lambda[i,k])
# Out-of-place computation (writes in `out`)
@nb.njit('void(float64[:,boolean[:,::1])',parallel=True)
def ComputeDctInput(out,x,y,W):
for i in nb.prange(out.shape[0]):
for j in range(out.shape[1]):
for k in range(out.shape[2]):
out[i,k] = W[i,k] * (x[i,k] - y[i,k]) + y[i,k]
# In-place computation (writes in `y`)
@nb.njit('void(float64[:,parallel=True)
def ComputeDctOutput(dctResult,RF):
for i in nb.prange(y.shape[0]):
for j in range(y.shape[1]):
for k in range(y.shape[2]):
y[i,k] = RF * dctResult[i,k] + (1.0 - RF) * y[i,k]
def ComputeSteps(Lambda,W,s,RF):
dctData = np.empty(Lambda.shape,dtype=Lambda.dtype)
for i in tqdm(range(s.shape[0])):
ComputeDctInput(dctData,W)
dctnn(dctData)
ComputeGammaTransform(dctData,s[i])
idctnn(dctData)
ComputeDctOutput(dctData,RF)
此代码在我的机器上快 5 倍。您可以使用简单精度而不是双精度进一步加快速度。这使得最终代码比我机器上的原始代码快 7.5 倍。
我或许可以通过基于 GPU 的计算进一步加速代码。困难的部分是在 Python 中找到一个支持正交归一化的 DCT/IDCT 的 GPU 实现。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。