Reducing loops with NumPy in a Modified Gram-Schmidt implementation
We are trying to implement the Modified Gram-Schmidt algorithm we were given.
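The algorithm figure itself is not reproduced here. For reference, here is a sketch of a standard statement of Modified Gram-Schmidt for $A \in \mathbb{R}^{N \times N}$ with columns $a_1, \dots, a_N$. The numbering of the original figure is not known, so it is an assumption that the "lines 5-7" mentioned next are the inner loop over $j$.

```latex
\begin{aligned}
&u_j \leftarrow a_j, \qquad j = 1, \dots, N\\
&\textbf{for } i = 1, \dots, N:\\
&\quad r_{ii} \leftarrow \lVert u_i \rVert_2, \qquad q_i \leftarrow u_i / r_{ii}\\
&\quad \textbf{for } j = i+1, \dots, N:\\
&\qquad r_{ij} \leftarrow q_i^{\mathsf T} u_j\\
&\qquad u_j \leftarrow u_j - r_{ij}\, q_i
\end{aligned}
```

The projections $r_{ij}$ are taken against the already-updated $u_j$ rather than the original $a_j$. This is what distinguishes the modified variant from classical Gram-Schmidt.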
We first implemented lines 5-7 of the algorithm (the inner loop over j) as follows:
for j in range(i+1,N):
    R[i,j] = np.matmul(Q[:,i].transpose(),U[:,j])
    u = U[:,j] - R[i,j] * Q[:,i]
    U[:,j] = u
To reduce the running time, we tried replacing the loop with matrix operations like this:
# we changed the inner loop to matrix operations in order to improve running time
R[i,i + 1:] = np.matmul(Q[:,i],U[:,i + 1:])
U[:,i + 1:] = U[:,i + 1:] - R[i,i + 1:] * np.transpose(np.tile(Q[:,i],(N - i - 1,1)))
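The `np.tile`/`np.transpose` construction in the second attempt builds an explicit N×(N-i-1) copy of `Q[:,i]` only so that each column can be scaled by the matching entry of `R[i,i+1:]`. That matrix is simply the outer product of `Q[:,i]` and `R[i,i+1:]`. The sketch below checks on random data that `np.outer`, or broadcasting a column vector, gives the same update without the tiled copy. `N`, `i` and the random matrices are illustrative assumptions, not values from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
N, i = 6, 2                      # illustrative sizes, not from the question
Q = rng.standard_normal((N, N))
R = rng.standard_normal((N, N))

# Update term as written in the second attempt: tile q_i into rows, then transpose
tiled = R[i, i + 1:] * np.transpose(np.tile(Q[:, i], (N - i - 1, 1)))

# The same N x (N-i-1) matrix as an outer product q_i r^T ...
outer = np.outer(Q[:, i], R[i, i + 1:])

# ... or by broadcasting a column (N, 1) against a row of shape (N-i-1,)
broadcast = Q[:, i, None] * R[i, i + 1:]

print(np.allclose(tiled, outer), np.allclose(tiled, broadcast))
```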
The results are different, but very similar. Is there a problem with our second version?
Thanks!
Edit: the complete functions are:
import numpy as np

def gram_schmidt2(A):
    """
    decomposes a matrix A ∈ R^(N×N) into a product A = QR of an
    orthogonal matrix Q (i.e. Q^T Q = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)
    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N,N),dtype=np.float64)
    R = np.zeros((N,N),dtype=np.float64)
    for i in range(N):
        R[i,i] = np.linalg.norm(U[:,i])
        # Handling division by zero as was advised in the forum. The original called a
        # zero_devision_error() helper that is not shown; raising keeps the code runnable.
        if R[i,i] == 0:
            raise ZeroDivisionError(f"column {i} is linearly dependent on earlier columns")
        Q[:,i] = np.divide(U[:,i],R[i,i])
        # inner loop: remove the Q[:,i] component from every remaining column U[:,j]
        for j in range(i+1,N):
            R[i,j] = np.matmul(Q[:,i].transpose(),U[:,j])
            u = U[:,j] - R[i,j] * Q[:,i]
            U[:,j] = u
    return Q,R
and:
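def gram_schmidt1(A):
    """
    decomposes a matrix A ∈ R^(N×N) into a product A = QR of an
    orthogonal matrix Q (i.e. Q^T Q = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)
    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N,N),dtype=np.float64)
    R = np.zeros((N,N),dtype=np.float64)
    for i in range(N):
        R[i,i] = np.linalg.norm(U[:,i])
        if R[i,i] == 0:
            raise ZeroDivisionError(f"column {i} is linearly dependent on earlier columns")
        Q[:,i] = np.divide(U[:,i],R[i,i])
        # we changed the inner loop to matrix operations in order to improve running time
        R[i,i + 1:] = np.matmul(Q[:,i],U[:,i + 1:])
        U[:,i + 1:] = U[:,i + 1:] - R[i,i + 1:] * np.transpose(np.tile(Q[:,i],(N - i - 1,1)))
    return Q,R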
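Besides comparing the two functions with each other, each result can be checked against the properties the docstring promises: Q^T Q = I, R upper triangular, and QR = A. Below is a minimal sketch. The helper name `qr_residuals` is hypothetical, and NumPy's own `np.linalg.qr` serves only as a stand-in for either function. For an ill-conditioned input, the orthogonality error of Modified Gram-Schmidt is known to grow with the condition number of A, so a large first value does not by itself indicate a bug.

```python
import numpy as np

def qr_residuals(A, Q, R):
    """Hypothetical helper: measure how well (Q, R) satisfies the QR properties."""
    n = Q.shape[1]
    ortho = np.linalg.norm(Q.T @ Q - np.eye(n))             # ||Q^T Q - I||_F
    recon = np.linalg.norm(Q @ R - A) / np.linalg.norm(A)   # relative ||QR - A||_F
    lower = np.abs(np.tril(R, k=-1)).max()                  # largest entry below the diagonal
    return ortho, recon, lower

rng = np.random.default_rng(1)
A = rng.standard_normal((7, 7))
Q, R = np.linalg.qr(A)   # stand-in: substitute gram_schmidt1(A) or gram_schmidt2(A)
print(qr_residuals(A, Q, R))
```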
When we run the functions on the matrix:
[[ 1.00000000e+00 -1.98592571e-02 -1.00365698e-04 -1.45204974e-03
-9.95711793e-01 -1.77405377e-04 -7.68526195e-03]
[-1.98592571e-02 1.00000000e+00 -1.77809186e-02 -1.55937174e-01
-9.80881385e-03 -2.05317715e-02 -2.01456899e-01]
[-1.00365698e-04 -1.77809186e-02 1.00000000e+00 -1.87979660e-01
-5.12368040e-05 -8.35323206e-01 -4.59007949e-05]
[-1.45204974e-03 -1.55937174e-01 -1.87979660e-01 1.00000000e+00
-8.69848133e-04 -3.64095785e-01 -5.55408776e-04]
[-9.95711793e-01 -9.80881385e-03 -5.12368040e-05 -8.69848133e-04
1.00000000e+00 -9.54867422e-05 -5.92716161e-03]
[-1.77405377e-04 -2.05317715e-02 -8.35323206e-01 -3.64095785e-01
-9.54867422e-05 1.00000000e+00 -5.55505343e-05]
[-7.68526195e-03 -2.01456899e-01 -4.59007949e-05 -5.55408776e-04
-5.92716161e-03 -5.55505343e-05 1.00000000e+00]]
we get different outputs:
For gram_schmidt1:
Q:
[[ 7.34036501e-01 -8.55006295e-04 -8.15634583e-03 -9.24967764e-02
-4.91879501e-02 -4.90769704e-01 1.58268518e-01]
[-2.78569770e-04 7.14001661e-01 -2.70586659e-03 -2.70735367e-02
5.78840577e-01 2.37376069e-01 1.97835647e-02]
[-2.48309244e-03 -2.34709092e-03 7.38351181e-01 2.63187853e-01
-3.35473487e-01 3.38823696e-01 3.36320600e-01]
[-4.27658449e-03 -2.12584453e-03 -6.70730760e-01 3.82666405e-01
-3.44451231e-01 3.46085878e-01 -7.71559024e-01]
[-6.53970073e-04 -7.00117873e-01 -2.68125144e-03 -2.31536583e-02
5.94568750e-01 2.38329853e-01 -2.76969906e-01]
[-9.26674350e-02 -5.07961588e-03 -6.97972068e-02 -8.79879575e-01
-2.78679804e-01 2.78781202e-01 0.00000000e+00]
[-6.72739327e-01 1.73894101e-04 2.25707383e-03 1.69052581e-02
-1.26723666e-02 -5.77668322e-01 -4.35238424e-01]]
R:
[[ 1.36233007e+00 1.11436069e-03 1.04418015e-02 1.27072186e-02
1.10993692e-03 -7.82681536e-02 -1.33081669e+00]
[ 0.00000000e+00 1.40055740e+00 5.29057231e-04 1.44628716e-03
-1.40014587e+00 3.57535802e-04 2.25417515e-03]
[ 0.00000000e+00 0.00000000e+00 1.35440586e+00 -1.33059602e+00
6.67148806e-04 -3.51561140e-02 2.23809829e-02]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.81147599e-01
1.33951520e-02 -9.55057795e-01 2.36910667e-01]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
3.37143743e-02 -1.97436093e-01 7.90539705e-02]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 3.40545951e-01 -1.75971454e-01]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 3.50740324e-16]]
For gram_schmidt2:
Q:
[[ 7.34036501e-01 -8.55006295e-04 -8.15634583e-03 -9.24967764e-02
-4.91879501e-02 -4.90769704e-01 4.55677949e-01]
[-2.78569770e-04 7.14001661e-01 -2.70586659e-03 -2.70735367e-02
5.78840577e-01 2.37376069e-01 -1.89865812e-01]
[-2.48309244e-03 -2.34709092e-03 7.38351181e-01 2.63187853e-01
-3.35473487e-01 3.38823696e-01 9.49329061e-02]
[-4.27658449e-03 -2.12584453e-03 -6.70730760e-01 3.82666405e-01
-3.44451231e-01 3.46085878e-01 -4.36691368e-01]
[-6.53970073e-04 -7.00117873e-01 -2.68125144e-03 -2.31536583e-02
5.94568750e-01 2.38329853e-01 -1.13919487e-01]
[-9.26674350e-02 -5.07961588e-03 -6.97972068e-02 -8.79879575e-01
-2.78679804e-01 2.78781202e-01 -1.51892650e-01]
[-6.72739327e-01 1.73894101e-04 2.25707383e-03 1.69052581e-02
-1.26723666e-02 -5.77668322e-01 -7.21490087e-01]]
R:
[[ 1.36233007e+00 1.11436069e-03 1.04418015e-02 1.27072186e-02
1.10993692e-03 -7.82681536e-02 -1.33081669e+00]
[ 0.00000000e+00 1.40055740e+00 5.29057231e-04 1.44628716e-03
-1.40014587e+00 3.57535802e-04 2.25417515e-03]
[ 0.00000000e+00 0.00000000e+00 1.35440586e+00 -1.33059602e+00
6.67148806e-04 -3.51561140e-02 2.23809829e-02]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.81147599e-01
1.33951520e-02 -9.55057795e-01 2.36910667e-01]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
3.37143743e-02 -1.97436093e-01 7.90539705e-02]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 3.40545951e-01 -1.75971454e-01]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 3.65463051e-16]]
Solution
The following code does what you want more efficiently:
Q_i = Q[:,i].reshape(1,-1)
R[i,i+1:] = np.matmul(Q_i,U[:,i+1:])
U[:,i+1:] -= np.multiply(R[i,i+1:],Q_i.T)
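The first line is only for convenience, to make the code more readable.
Apart from the last line, everything is the same as in your original proposal. The last line performs an element-wise multiplication with broadcasting: R[i,i+1:] has shape (N-i-1,) and Q_i.T has shape (N,1), so the product is the N×(N-i-1) matrix of all the updates R[i,j]*Q[:,i], which is exactly what your inner loop computes one column at a time.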
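For completeness, here is a sketch of the whole function with the three lines above placed in the outer loop. The name `gram_schmidt_vec` is hypothetical. Raising `ZeroDivisionError` stands in for the `zero_devision_error` helper from the question, which is not shown.

```python
import numpy as np

def gram_schmidt_vec(A):
    """Modified Gram-Schmidt QR with the inner j-loop replaced by array operations."""
    N = A.shape[0]
    U = np.array(A, dtype=np.float64)   # float working copy, orthogonalized in place
    Q = np.zeros((N, N), dtype=np.float64)
    R = np.zeros((N, N), dtype=np.float64)
    for i in range(N):
        R[i, i] = np.linalg.norm(U[:, i])
        if R[i, i] == 0:
            raise ZeroDivisionError(f"column {i} is linearly dependent on earlier columns")
        Q[:, i] = U[:, i] / R[i, i]
        Q_i = Q[:, i].reshape(1, -1)                       # q_i as a 1 x N row vector
        R[i, i + 1:] = np.matmul(Q_i, U[:, i + 1:])        # all r_ij for j > i at once
        U[:, i + 1:] -= np.multiply(R[i, i + 1:], Q_i.T)   # subtract r_ij * q_i from each u_j
    return Q, R
```

Broadcasting `R[i, i+1:]` against the `(N, 1)` column `Q_i.T` produces the full update in one step, so no tiled copy of `q_i` is ever allocated.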
About the difference in results:
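Your code is fine; both versions compute the same thing. When working with floating-point numbers, you should not test for equality with A == B. Instead, I suggest checking how much the two arrays differ.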
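Here is a sketch of such a comparison on made-up arrays. `np.allclose` tests element-wise closeness with relative and absolute tolerances. The maximum absolute difference is usually a more informative summary than the mean, because positive and negative differences cannot cancel out in it.

```python
import numpy as np

Q1 = np.array([[1.0, 0.5], [0.25, 2.0]])
Q2 = Q1 + np.array([[1e-3, -1e-3], [0.0, 0.0]])   # made-up perturbation for illustration

diff = Q1 - Q2
print((Q1 == Q2).all())        # exact equality: too strict for floating point
print(diff.mean())             # the +1e-3 and -1e-3 errors nearly cancel
print(np.abs(diff).max())      # worst-case absolute difference, about 1e-3
print(np.allclose(Q1, Q2))     # tolerance-based test (defaults rtol=1e-05, atol=1e-08)
```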
For your two functions in particular, running
Q1,R1 = gram_schmidt2(A)
Q2,R2 = gram_schmidt1(A)
(Q1 - Q2).mean()
(R1 - R2).mean()
gives, respectively:
-5.4997372770547595e-09 and -5.2465803662044656e-18
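which are already very close to 0. A difference of 1e-18 is far below float64 precision for values of order 1 (the machine epsilon of np.float64 is about 2.2e-16), so you are fine.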
You can convince yourself of this by computing 3*0.1 - 0.3, which gives about 5.6e-17 instead of 0.
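For instance, the following compares that difference with the float64 machine epsilon reported by `np.finfo`, which is the gap between 1.0 and the next representable double:

```python
import numpy as np

d = 3 * 0.1 - 0.3                    # would be exactly 0 in real arithmetic
eps = np.finfo(np.float64).eps       # gap between 1.0 and the next float64, i.e. 2**-52

print(d)             # 5.551115123125783e-17
print(d == eps / 4)  # True: the error is 2**-54, one unit in the last place of 0.3
```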
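The error in Q is larger because each column of Q is obtained by dividing by R[i,i]. When that value is small (as is sometimes the case here), the rounding error is amplified.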
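The question's own outputs show this amplification. `R[6,6]` is about 3.5e-16 in both results, which is at the level of rounding error, so the last column of the input is numerically a linear combination of the others. The last column of Q is then rounding noise divided by a tiny number, and it is the only column of Q that differs noticeably between the two versions. The sketch below checks how close to singular the matrix printed in the question is. The printout is rounded to 9 significant digits, so that matrix can only be approximately singular, and the `1e-6` / `1e6` thresholds in the checks are loose assumptions rather than measured values.

```python
import numpy as np

# Matrix as printed in the question (rounded to 9 significant digits)
A = np.array([
    [ 1.00000000e+00, -1.98592571e-02, -1.00365698e-04, -1.45204974e-03, -9.95711793e-01, -1.77405377e-04, -7.68526195e-03],
    [-1.98592571e-02,  1.00000000e+00, -1.77809186e-02, -1.55937174e-01, -9.80881385e-03, -2.05317715e-02, -2.01456899e-01],
    [-1.00365698e-04, -1.77809186e-02,  1.00000000e+00, -1.87979660e-01, -5.12368040e-05, -8.35323206e-01, -4.59007949e-05],
    [-1.45204974e-03, -1.55937174e-01, -1.87979660e-01,  1.00000000e+00, -8.69848133e-04, -3.64095785e-01, -5.55408776e-04],
    [-9.95711793e-01, -9.80881385e-03, -5.12368040e-05, -8.69848133e-04,  1.00000000e+00, -9.54867422e-05, -5.92716161e-03],
    [-1.77405377e-04, -2.05317715e-02, -8.35323206e-01, -3.64095785e-01, -9.54867422e-05,  1.00000000e+00, -5.55505343e-05],
    [-7.68526195e-03, -2.01456899e-01, -4.59007949e-05, -5.55408776e-04, -5.92716161e-03, -5.55505343e-05,  1.00000000e+00],
])

s = np.linalg.svd(A, compute_uv=False)   # singular values in descending order
R = np.linalg.qr(A)[1]                   # R factor from NumPy's Householder QR

print("largest / smallest singular value:", s[0], s[-1])
print("condition number:", s[0] / s[-1])
print("|R[-1, -1]|:", abs(R[-1, -1]))
```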
About the runtime:
Running both versions of your code, I got similar run times: 243 µs ± 25.5 µs with the loop and 241 µs ± 6.82 µs with your second version, while the code given here takes 152 µs ± 1.49 µs.
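Timings like these depend on the machine, the NumPy build and the matrix size, so it is worth measuring locally. Below is a minimal sketch using the standard-library `timeit` module. It times only the column update from one step of the outer loop, written three ways. The function names, `N = 200` and `i = 0` are arbitrary assumptions, and the absolute numbers will not match the ones quoted above.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
N, i = 200, 0                          # arbitrary size; i = 0 gives the widest update
Q = rng.standard_normal((N, N))
R = rng.standard_normal((N, N))
U = rng.standard_normal((N, N))

def update_loop():
    """Column-by-column update, as in the original inner loop."""
    out = U[:, i + 1:].copy()
    for k in range(N - i - 1):
        out[:, k] -= R[i, i + 1 + k] * Q[:, i]
    return out

def update_tile():
    """Update with a tiled copy of q_i, as in the second attempt."""
    return U[:, i + 1:] - R[i, i + 1:] * np.transpose(np.tile(Q[:, i], (N - i - 1, 1)))

def update_broadcast():
    """Update via broadcasting, as in the last line of the answer's code."""
    return U[:, i + 1:] - np.multiply(R[i, i + 1:], Q[:, i].reshape(-1, 1))

for f in (update_loop, update_tile, update_broadcast):
    t = min(timeit.repeat(f, number=100, repeat=3)) / 100
    print(f"{f.__name__}: {t * 1e6:.1f} µs per call")
```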
I suggest using Numba, an excellent speed optimizer: it JIT-compiles Python functions to machine code (via LLVM) and can speed up many loop-heavy Python programs by 50-200x.
To install numba, just run once: python -m pip install numba
Below is your algorithm adapted to numba. The main change is just the @numba.njit decorator placed on the line before the function definition.
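In numba code you can write regular Python loops and any math computations, even without NumPy, and the resulting code is usually very fast, in many cases even faster than the equivalent NumPy code.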
I used your gram_schmidt2() function as the base and only replaced np.matmul() with np.dot(), because Numba seems to support only np.dot() for this.
import numpy as np, numba

@numba.njit(cache = True,fastmath = True,parallel = True)
def gram_schmidt2(A):
    """
    decomposes a matrix A ∈ R^(N×N) into a product A = QR of an
    orthogonal matrix Q (i.e. Q^T Q = I) and an upper triangular matrix R (i.e. entries below
    the main diagonal are zero)
    :return: Q,R
    """
    N = np.shape(A)[0]
    U = A.copy()
    Q = np.zeros((N,N),dtype=np.float64)
    R = np.zeros((N,N),dtype=np.float64)
    for i in range(N):
        R[i,i] = np.linalg.norm(U[:,i])
        # Handling division by zero as was advised in the forum
        if R[i,i] == 0:
            assert False # replaces zero_devision_error(), which cannot be called from njit code
        Q[:,i] = np.divide(U[:,i],R[i,i])
        # inner loop over the remaining columns; Numba compiles it to machine code
        for j in range(i+1,N):
            R[i,j] = np.dot(Q[:,i].transpose(),U[:,j])
            u = U[:,j] - R[i,j] * Q[:,i]
            U[:,j] = u
    return Q,R
a = np.array(
    [[ 1.00000000e+00,-1.98592571e-02,-1.00365698e-04,-1.45204974e-03,-9.95711793e-01,-1.77405377e-04,-7.68526195e-03],
     [-1.98592571e-02, 1.00000000e+00,-1.77809186e-02,-1.55937174e-01,-9.80881385e-03,-2.05317715e-02,-2.01456899e-01],
     [-1.00365698e-04,-1.77809186e-02, 1.00000000e+00,-1.87979660e-01,-5.12368040e-05,-8.35323206e-01,-4.59007949e-05],
     [-1.45204974e-03,-1.55937174e-01,-1.87979660e-01, 1.00000000e+00,-8.69848133e-04,-3.64095785e-01,-5.55408776e-04],
     [-9.95711793e-01,-9.80881385e-03,-5.12368040e-05,-8.69848133e-04, 1.00000000e+00,-9.54867422e-05,-5.92716161e-03],
     [-1.77405377e-04,-2.05317715e-02,-8.35323206e-01,-3.64095785e-01,-9.54867422e-05, 1.00000000e+00,-5.55505343e-05],
     [-7.68526195e-03,-2.01456899e-01,-4.59007949e-05,-5.55408776e-04,-5.92716161e-03,-5.55505343e-05, 1.00000000e+00]],
    dtype = np.float64)
print(gram_schmidt2(a))
Output:
(array([[ 7.08543467e-01,-5.53704898e-03,-2.70026740e-04,-3.47742384e-03, 1.84840892e-01,-5.24814365e-01,-4.33966083e-01],
       [-1.40711469e-02, 9.68398634e-01,-2.12833250e-02, 1.19174521e-01,-1.98433167e-01,-3.04695775e-02,-8.39439437e-02],
       [-7.11134597e-05,-1.72252300e-02, 7.59699130e-01,-1.47406821e-01,-1.01157914e-01, 3.77137817e-01,-4.98362473e-01],
       [-1.02884036e-03,-1.51071666e-01,-1.41567550e-01, 9.02766638e-01,-8.55711320e-02, 2.12039165e-01,-2.99775521e-01],
       [-7.05505086e-01,-2.31427937e-02, 3.84334272e-04,-6.68149305e-03, 1.96907249e-01,-5.24473268e-01,-4.33402818e-01],
       [-1.25699421e-04,-1.98909561e-02,-6.34318769e-01,-3.82156774e-01,-9.76029595e-02, 4.04531367e-01,-5.27283410e-01],
       [-5.44534215e-03,-1.95250685e-01, 1.53606576e-03,-5.45941927e-02,-9.27687435e-01,-3.12618155e-01,-2.30333938e-02]]),
 array([[ 1.41134602e+00,-1.99608442e-02, 4.42769473e-04, 8.12375351e-04,-1.41083897e+00, 5.39174765e-04,-3.87373035e-03],
       [ 0.00000000e+00, 1.03234256e+00, 1.05802339e-02,-2.91464191e-01,-2.58368570e-02, 2.96333339e-02,-3.90075744e-01],
       [ 0.00000000e+00, 0.00000000e+00, 1.31655051e+00,-5.01046784e-02, 9.97649491e-04,-1.21693202e+00, 5.90252943e-03],
       [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.05107524e+00,-4.80557952e-03,-5.90160540e-01,-7.90098043e-02],
       [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.03928769e-02, 2.21268065e-02,-8.90241765e-01],
       [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.30829767e-02,-2.99495426e-01],
       [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.31764881e-10]]))
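The remark above, that plain loops are fine under Numba, can be taken further: the inner products can also be written as explicit scalar loops, with no NumPy calls beyond array allocation and `np.sqrt`. In this sketch, Numba's `njit` decorator is used when Numba is installed. Otherwise, a small stand-in decorator returns the function unchanged so the code still runs, just slowly. That stand-in is an illustrative assumption, not part of Numba. The name `mgs_scalar` is hypothetical, and the input is assumed to be a square float64 array.

```python
import numpy as np

try:
    from numba import njit
except ImportError:              # Numba not installed: run the function uncompiled
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]       # used as a bare @njit
        return lambda f: f       # used as @njit(...)

@njit
def mgs_scalar(A):
    """Modified Gram-Schmidt written with scalar loops only (square float64 input assumed)."""
    n = A.shape[0]
    U = A.copy()
    Q = np.zeros((n, n))
    R = np.zeros((n, n))
    for i in range(n):
        s = 0.0
        for k in range(n):                 # squared norm of column i
            s += U[k, i] * U[k, i]
        R[i, i] = np.sqrt(s)
        for k in range(n):                 # normalize column i into Q
            Q[k, i] = U[k, i] / R[i, i]
        for j in range(i + 1, n):
            r = 0.0
            for k in range(n):             # r_ij = q_i . u_j
                r += Q[k, i] * U[k, j]
            R[i, j] = r
            for k in range(n):             # u_j -= r_ij * q_i
                U[k, j] -= r * Q[k, i]
    return Q, R
```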