如何解决线性空间中的Numpy矩阵乘法最大值
我有两个大数组 A
是 n x k
,B
是 k x m
。
我想计算 np.argmax(A @ B,axis=1)
,即每行的最大列数。
不幸的是,如果我天真地这样做,numpy 将计算整个数组 A @ B
,这需要 n x m
内存——比我拥有的还要多。
应该可以在没有额外内存的情况下做到这一点,只需单独计算每个条目并保持最大值。
可以在 numpy 中做到这一点吗?
解决方法
正如评论中所讨论的,numpy 不能完全按照您的意愿行事。但是,如果您有足够的内存,则可以遍历 A
的行并逐段进行矩阵乘法,在循环的每次迭代中收集结果 1xk
的 argmax。这将是完全使用 numpy 和从头开始滚动您自己的实现之间的一种折衷。类似于以下内容。您可以使用 numba 来提高一点速度。
import numpy as np
from numba import jit
a = np.random.randn(500,200)
b = np.random.randn(200,1000)
def max_by_row(a,b):
out = np.zeros((a.shape[0],),dtype='int64')
for idx in range(a.shape[0]):
out[idx] = np.argmax(a[idx,:] @ b)
return out
@jit
def max_by_row_jit(a,:] @ b)
return out
比较时间:
%timeit np.argmax(a@b,axis=1)
4.6 ms ± 226 µs per loop (mean ± std. dev. of 7 runs,100 loops each)
%timeit max_by_row(a,b)
12.2 ms ± 233 µs per loop (mean ± std. dev. of 7 runs,100 loops each)
%timeit max_by_row_jit(a,b)
8.85 ms ± 135 µs per loop (mean ± std. dev. of 7 runs,100 loops each)
,
我最终使用了 vq
function in scipy
。它需要两个矩阵,A
和 B
,对于 a
的每一行 A
,它告诉您 B
的哪一行更接近。
这为我节省了 20 倍的时间,甚至不考虑内存,就我而言,其中 B
有大约一百万行。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。