如何解决在numba jitted函数中使用ndarray子类“__matmul__”方法
我编写了一个名为 ndarray
的 BandedMatrix
子类。它用专用的 BLAS 例程 matmul
替换了标准的 Numpy dgbmv
方法。这部分工作正常。
但是,我正在努力让 Numba jitted 函数使用子类 __matmul__
方法而不是股票 Numpy。
一个最小的例子是:
from numpy import ndarray,asarray,ones
from scipy.linalg.blas import dgbmv
from numba import njit
class BandedMatrix(ndarray):
def __new__(cls,bands,kl,ku,m=None):
obj = asarray(bands).view(cls)
obj.kl = kl
obj.ku = ku
obj.n = bands.shape[1]
if m:
obj.m = m
else:
obj.m = bands.shape[1]
return obj
def __array_finalize__(self,obj):
if obj is None:
return
def __matmul__(self,x):
return dgbmv(self.m,self.n,self.kl,self.ku,1,self,x)
def __rmatmul__(self,x.T,trans=1)
@property
def T(self):
return BandedMatrix(self.view(ndarray).transpose(),m=self.n)
@property
def shape(self):
return (self.m,self.n)
@njit()
def matvec(A,x):
return A @ x
if __name__ == "__main__":
A = BandedMatrix(ones([3,10]),1)
x = ones(10)
y1 = matvec(A,x)
y2 = A @ x
我知道 Numba 无法在这里加速我的计算。我只是试图让这个例子尽可能短 - 因此无用的 matvec 函数。我尝试将我的子类插入到 this Numba example 中描述的函数中,但没有让它工作。
感谢任何帮助,我也乐于寻求解决方法。
我的主要目标是在能够使用 @
语法的同时对带状矩阵进行更快的矩阵乘法。此外,我希望让它与 Numba 一起使用,而无需重写 BLAS 函数。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。