微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

将 3D Array 过滤为一个 1D Array 的最快方法

如何解决将 3D Array 过滤为一个 1D Array 的最快方法

我正在尝试按条件过滤多个 3D 数组,并将结果用作 scipy.stats.binned_statistic_2d()np.histogram2d() 的输入。

这是一个简单的例子:

import numpy as np
import xarray as xr

array = xr.DataArray(np.random.randn(20,2000,3000))

stacked = array.stack(z=('dim_0','dim_1','dim_2'))
res = stacked.where(stacked>0,drop=True)

stack 步骤的运行速度非常慢:

16.9 s ± 1 s per loop (mean ± std. dev. of 7 runs,1 loop each)

是否可以更快地获取 res 并将 dim_0dim_1dim_2 的值也保存为一维变量?

解决方法

你在这里使用 xarray 究竟是为了什么?

由于 scipy 和 numpy 都适用于裸 numpy 数组,您可以简单地将数据作为 numpy 数组,并使用 ravel 获得相同内存的一维视图。

如果您还不熟悉副本和视图,请参阅 numpy 手册:https://numpy.org/devdocs/reference/arrays.ndarray.html?highlight=view

这是闪电般的速度:

ravelled = array.data.ravel()

只要它是一个一维数组,并且你只需要布尔索引,numpy 就可以了:

result = ravelled[ravelled > 0]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。