微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

选择JAX矩阵子集的最快方法是什么?

如何解决选择JAX矩阵子集的最快方法是什么?

假设我有一个2D矩阵,我想在直方图中绘制其值。为此,我需要做类似的事情:

list_1d = matrix_2d.reshape((-1,)).tolist()

然后使用该列表绘制直方图。到目前为止,一切都很好,只是我要排除的原始矩阵中有项目。为了简单起见,假设我有一个像这样的列表:

exclude = [(2,5),(3,4),(6,1)]

因此,list_1d应该具有矩阵中的所有项目,而没有exclude所指向的项目(exclude的项目是行索引和列索引)。

顺便说一句,matrix_2d一个JAX数组,这意味着其内容位于GPU中。

解决方法

执行此操作的一种方法是创建用于选择所需阵列子集的遮罩阵列。掩码索引操作返回所选数据的一维副本:

import jax.numpy as jnp
from jax import random
matrix_2d = random.uniform(random.PRNGKey(0),(10,10))
exclude = [(2,5),(3,4),(6,1)]

ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d,dtype=bool).at[ind].set(False)

list_1d = matrix_2d[mask].tolist()
len(list_1d)
# 97

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。