微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在所有维度上进行 Numpy 迭代,但最后一个维度数量未知

如何解决在所有维度上进行 Numpy 迭代,但最后一个维度数量未知

物理背景

我正在开发一个函数,该函数可以为最多四维温度场(时间、经度、纬度、压力作为高度测量)中的每个垂直剖面计算一些指标。我有一个工作函数,它在单个位置获取压力和温度并返回指标(对流层顶信息)。我想用一个函数包装它,该函数将它应用于传递的数据中的每个垂直剖面。

问题的技术描述

我希望我的函数将另一个函数应用于对应于我的 N 维数组中最后一个维度的每个一维数组,其中 N

为什么我提出一个新问题

我知道有几个问题(例如,iterating over some dimensions of a ndarrayIterating over the last dimensions of a numpy arrayIterating over 3D numpy using one dimension as iterator remaining dimensions in the loopIterating over a numpy matrix with unknown dimension)询问如何迭代特定维度 如何迭代具有未知维度的数组。据我所知,这两个问题的结合是新的。例如,使用 numpy.nditer 我还没有找到如何只排除最后一个维度而不管剩余的维度数。

编辑

我试着做一个最小的、可重复的例子:

import numpy as np

def outer_function(array,*args):
    """
    Array can be 1D,2D,3D,or 4D. Regardless the inner_function 
    should be applied to all 1D arrays spanned by the last axis
    """
    # Unpythonic if-else solution
    if array.ndim == 1:
        return inner_function(array)
    elif array.ndim == 2:
        return [inner_function(array[i,:]) for i in range(array.shape[0])]
    elif array.ndim == 3:
        return [[inner_function(array[i,j,:]) for i in range(array.shape[0])] for j in range(array.shape[1])]
    elif array.ndim == 4:
        return [[[inner_function(array[i,k,:]) for i in range(array.shape[0])] for j in range(array.shape[1])] for k in range(array.shape[2])]
    else:
        return -1

def inner_function(array_1d):
    return np.interp(2,np.arange(array_1d.shape[0]),array_1d),np.sum(array_1d)

请假设实际的inner_function 不能修改为应用于多维,而只能应用于一维数组。

编辑结束

如果它对我拥有/想要拥有的代码结构有帮助:

def tropopause_ds(ds):
    """
    wraps around tropopause profile calculation. The vertical coordinate has to be the last one.
    """
    
    t = ds.t.values # numpy ndarray
    p_profile = ds.plev.values # 1d numpy ndarray

    len_t = ds.time.size
    len_lon = ds.lon.size
    len_lat = ds.lat.size
    nlevs = ds.plev.size

    ttp = np.empty([len_t,len_lon,len_lat])
    ptp = np.empty([len_t,len_lat])
    ztp = np.empty([len_t,len_lat])
    dztp = np.empty([len_t,len_lat,nlevs])

    # Approach 1: use numpy.ndindex - doesn't work in a list comprehension,slow
    for idx in np.ndindex(*t.shape[:-1]):
        ttp[idx],ptp[idx],ztp[idx],dztp[idx] = tropopause_profile(t[idx],p_profile)

    # Approach 2: use nested list comprehensions - doesn't work for different number of dimensions
    ttp,ptp,ztp,dztp = [[[tropopause_profile(t[i,:],p_profile) for k in range(len_lat)]
                            for j in range(len_lon)] for i in range(len_t)]

    return ttp,dztp

内部函数的结构如下:

def tropopause_profile(t_profile,p_profile):
    if tropopause found:
        return ttp,dztp
    return np.nan,np.nan,np.nan

我已经尝试了几个选项。计时案例中的测试数据具有形状 (2,360,180,105):

  • xarray's apply_ufunc 似乎将整个数组传递给函数。然而,我的内部函数基于获取一维数组,并且很难重新编程以处理多维数据
  • 嵌套列表推导工作并且似乎相当快,但是如果一维(例如时间)只有一个值(定时:8.53 s ±每个循环 11.9 毫秒(平均值 ± 标准差。7 次运行,每次 1 次循环)
  • 使用 numpy's nditer 在标准的 for 循环中工作,该循环使用列表理解来加速。但是,使用这种方法,该函数不会返回 4 个 ndarray,而是一个包含每个索引的四个返回值作为列表元素的列表。 (定时与列表理解:每个循环 1 分钟 4 秒 ± 740 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 次))

解决这个问题的一个丑陋的方法是检查我的数据有多少维,然后对正确数量的列表推导进行 if else 选择,但我希望 python 有更流畅的方法解决这个问题。如果有帮助,可以轻松更改尺寸的顺序。我在 2 核、10 GB 内存的 jupyterhub 服务器上运行代码

解决方法

我已经多次使用@hpaulj 的重塑方法。这意味着循环可以将整个数组迭代 1d 个切片。

简化了函数和数据以进行测试。

import numpy as np

arr = np.arange( 2*3*3*2*6 ).reshape( 2,3,2,6 )

def inner_function(array_1d):
    return np.array( [ array_1d.sum(),array_1d.mean() ])
    # return np.array( [np.interp(2,np.arange(array_1d.shape[0]),array_1d),np.sum(array_1d) ])

def outer_function( arr,*args ):
    res_shape = list( arr.shape )
    res_shape[ -1 ] = 2

    result = np.zeros( tuple( res_shape ) )  # result has the same shape as arr for n-1 dimensions,then two

    # Reshape arr and result to be 2D arrays.  These are views into arr and result
    work = arr.reshape( -1,arr.shape[-1] )
    res = result.reshape( -1,result.shape[-1] )

    for ix,w1d in enumerate( work ):  # Loop through all 1D 
        res[ix] = inner_function( w1d )
    return result 

outer_function( arr )

结果是

array([[[[[  15.,2.5],[  51.,8.5]],[[  87.,14.5],[ 123.,20.5]],...

         [[1167.,194.5],[1203.,200.5]],[[1239.,206.5],[1275.,212.5]]]]])

我相信这可以进一步优化,并考虑到应用程序所需的实际功能。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。