如何解决如何有效地过滤一个 numpy 数组,使其包含第二个数组中的行
例如。实际上,a 有 100 多万行。 b 是固定大小。
a = array([[0,0],[1,1,1],[0,0]])
b = array([[0,1]])
我正在这样做:
matches=[]
for n,a_row in enumerate(a):
for b_row in b:
if np.all(a_row ==b_row ):
matches.append(n)
a[matches]
似乎应该有更好的方法......
解决方法
您可以在这里非常有效地使用 np.packbits
。它会将具有八列或更少列的任何数组转换为单列 uint8
:
ai = np.packbits(a,axis=-1)
bi = np.packbits(a,axis=-1)
对于最多为 64 的任意数量的列,您可以使用 np.min_scalar_type
使用适当的整数类型:
t = np.min_scalar_type(2**a.shape[-1] - 1)
ai = np.concatenate((ai,np.zeros((a.shape[0],t.itemsize % ai.shape[-1]),np.uint8)),axis=-1)
bi = np.concatenate((bi,t.itemsize % bi.shape[-1]),axis=-1)
ai = ai.view(t)
bi = bi.view(t)
您可以简单地将数组与 np.isin
(或 np.in1d
)进行比较:
mask = np.isin(ai,bi).ravel()
您现在可以像以前一样直接索引 a
:
a[mask,:]
你原来的例子变成了单行:
a[np.in1d(np.packbits(a,axis=-1),np.packbits(b,axis=-1))]
对于足够大的 a
,您可能希望使用基于排序或 set
的方法进一步加快速度。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。