微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何有效地过滤一个 numpy 数组,使其包含第二个数组中的行

如何解决如何有效地过滤一个 numpy 数组,使其包含第二个数组中的行

例如。实际上,a 有 100 多万行。 b 是固定大小。

a = array([[0,0],[1,1,1],[0,0]])

b = array([[0,1]])

我正在这样做:

matches=[]
for n,a_row in enumerate(a):
    for b_row in b:
        if np.all(a_row ==b_row ):
            matches.append(n)

a[matches]

似乎应该有更好的方法......

解决方法

您可以在这里非常有效地使用 np.packbits。它会将具有八列或更少列的任何数组转换为单列 uint8

ai = np.packbits(a,axis=-1)
bi = np.packbits(a,axis=-1)

对于最多为 64 的任意数量的列,您可以使用 np.min_scalar_type 使用适当的整数类型:

t = np.min_scalar_type(2**a.shape[-1] - 1)
ai = np.concatenate((ai,np.zeros((a.shape[0],t.itemsize % ai.shape[-1]),np.uint8)),axis=-1)
bi = np.concatenate((bi,t.itemsize % bi.shape[-1]),axis=-1)
ai = ai.view(t)
bi = bi.view(t)

您可以简单地将数组与 np.isin(或 np.in1d)进行比较:

mask = np.isin(ai,bi).ravel()

您现在可以像以前一样直接索引 a

a[mask,:]

你原来的例子变成了单行:

a[np.in1d(np.packbits(a,axis=-1),np.packbits(b,axis=-1))]

对于足够大的 a,您可能希望使用基于排序或 set 的方法进一步加快速度。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。