如何解决具有最大计数阈值的 numpy 的“nunique”有哪些快速实现?
我确定它在其他某个域中有一个名称(可能大约不同?)。
假设您想计算 numpy 数组中不同元素的数量,但您只关心低于某个阈值和高于阈值的数字,您只需返回它具有超过 thresh 唯一条目。这对于高数量数组特别有用,因为您不在乎有 10000 个条目,只是可能有 10 个以上的条目。
在编译语言中,这很容易实现。但是有哪些快速实现暴露给 Python 的?
天真的人可能会像这样尝试 numba:
@numba.jit(nopython=True)
def nunique_max_thresh(x,thresh=10):
seen = set()
for i in range(len(x)):
seen.add(x[i])
if len(seen) > thresh:
return thresh
return len(seen)
但不支持 set
用法。
Cython 是一个选项,但我想知道这是否已经在某个库或 Python 的其他地方完成。看起来瓶颈会做这种事情,但实际上并没有。
https://bottleneck.readthedocs.io/en/latest/reference.html
例如,考虑这些类型的数组:
import string
import numpy as np
np.random.seed(0)
a = np.random.choice(list(string.ascii_letters),1e7)
b = np.ones(int(1e7))
而您只想知道这个数组是否有 10 个或更多的唯一值。不要使用这些长度为 1 的字符串这一事实。
作为参考,这个运行。但可能不是最优的。
import numpy as np
cimport numpy as np
def nunique_truncated(np.ndarray x_in,np.int thresh=10):
seen = set()
for i in range(x_in.shape[0]):
seen.add(x_in[i])
if len(seen) >= thresh:
return thresh
解决方法
正如@hpaulj 建议的那样,您可以只使用 numba
而没有 集合或字典,这应该是合理的,因为用例专门针对较短的列表。显然,某些机制会因缓慢的包含查找而受到影响。
import numba
@numba.jit(nopython=True)
def nunique_truncated_numba(x_in,thresh=10):
seen = list()
for i,x in enumerate(x_in):
if x not in seen:
seen.append(x)
if len(seen) > thresh:
return len(seen)
return len(seen)
而且困难的情况确实是当您没有达到阈值时(您正在使用 python 进行矢量化扫描)。
In [6]: %timeit cud.nunique_truncated(b)
116 µs ± 304 ns per loop (mean ± std. dev. of 7 runs,10000 loops each)
In [7]: %timeit len(np.unique(b))
1.26 ms ± 2.64 µs per loop (mean ± std. dev. of 7 runs,1000 loops each)
如果有人有其他建议和技巧会感兴趣。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。