微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 numpy 从每个最里面的维度中选择一个元素

如何解决使用 numpy 从每个最里面的维度中选择一个元素

我有一个三维 numpy 源数组和一个二维 numpy 索引数组。

例如:

src = np.array([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])
idx = np.array([[0,1],[1,2]])

我想得到一个二维数组,其中每个元素代表该位置最内维的索引值:

array([[1,5],[8,12]])

我如何用 numpy 做到这一点?

解决方法

你可以试试np.take,这里是documentation

但是,您应该在展平所有元素后计算数组的索引。例如你应该使用

src = np.array([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])
idx = np.array([[0,4],[7,11]])

# Wanted result
res = np.take(src,idx)

其中 src 被视为 [1,3,4,6,7,9,10,12]

你也可以试试np.take_along_axis,这里是documentation

使用此方法需要您的 srcidx 处于同一维度,因此,您应该先解压src挤压 res

# Unsqueezed the last dim
idx = np.expand_dims(idx,axis=-1)

# Squeeze the last dim
res = np.take_along_axis(src,idx,axis=2).squeeze(-1)
,

您可以使用 np.choose 方法稍加改造:

np.choose(idx.reshape((1,2)),src.transpose()).reshape((2,2))

>>>> array([[ 1,8],[ 5,12]])
,

直接索引:

src[np.arange(2)[:,None],np.arange(2),idx]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。