微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么在JAX和numpy中此功能较慢?

如何解决为什么在JAX和numpy中此功能较慢?

我具有以下numpy函数,如下所示,我正在尝试使用JAX进行优化,但出于某种原因,它的速度较慢。

有人可以指出我可以做些什么来提高性能吗?我怀疑这与Cg_new的列表理解有关,但是将其分开并不会在JAX中产生任何进一步的性能提升。

import numpy as np 

def testFunction_numpy(C,Mi,C_new,Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]),len(Mi[0])))
    Cg_new = np.zeros((1,len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new,Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))),(Wg_new[:,m])) for m in range(0,len(Mi[0]))] 

    return C_new,Mi_new,Wg_new,Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C,Mi_new)
#1000 loops,best of 3: 1.73 ms per loop

相当于JAX:

import jax.numpy as jnp
import numpy as np
import jax

def testFunction_JAX(C,Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,len(Mi[0])))
    Cg_new = jnp.zeros((1,len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)

    Wg_new = jnp.dot(invertCsensor_new,Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,8)

C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)

jitter = jax.jit(testFunction_JAX) 

%timeit jitter(C,Mi_new)
#1 loop,best of 3: 4.96 ms per loop

解决方法

当JAX jit编译遇到Python控制流(包括列表推导)时,它将有效地拉平循环并逐步执行整个操作序列。这可能会导致jit编译时间变慢以及代码不理想。幸运的是,您的函数中的列表理解很容易用本地numpy广播表示。此外,您还可以进行其他两项改进:

  • 在计算它们之前,无需转发声明Wg_newCg_new
  • 在计算dot(inv(A),B)时,使用np.linalg.solve而不是显式计算逆函数会更加高效和精确。

对numpy和JAX版本进行了这三项改进,结果如下:

def testFunction_numpy_v2(C,Mi,C_new,Mi_new):
    Wg_new = np.linalg.solve(C_new,Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new,Mi_new,Wg_new,Cg_new

@jax.jit
def testFunction_JAX_v2(C,Mi_new):
    Wg_new = jnp.linalg.solve(C_new,Cg_new

%timeit testFunction_numpy_v2(C,Mi_new)
# 1000 loops,best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax,Mi_jax,C_new_jax,Mi_new_jax)
# 1000 loops,best of 3: 1.35 ms per loop

由于改进了实现,这两个函数的速度都比以前快了一点。但是,您会注意到,JAX在这里仍然比numpy慢。这在某种程度上是可以预料的,因为对于这种简单程度的功能,JAX和numpy都有效地生成了在CPU体系结构上执行的相同简短系列的BLAS和LAPACK调用。 numpy的引用实现根本没有太多改进的空间,而且使用如此小的数组,JAX的开销显而易见。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。