如何解决选择JAX矩阵子集的最快方法是什么?
假设我有一个2D矩阵,我想在直方图中绘制其值。为此,我需要做类似的事情:
list_1d = matrix_2d.reshape((-1,)).tolist()
然后使用该列表绘制直方图。到目前为止,一切都很好,只是我要排除的原始矩阵中有项目。为了简单起见,假设我有一个像这样的列表:
exclude = [(2,5),(3,4),(6,1)]
因此,list_1d
应该具有矩阵中的所有项目,而没有exclude
所指向的项目(exclude
的项目是行索引和列索引)。
顺便说一句,matrix_2d
是一个JAX数组,这意味着其内容位于GPU中。
解决方法
执行此操作的一种方法是创建用于选择所需阵列子集的遮罩阵列。掩码索引操作返回所选数据的一维副本:
import jax.numpy as jnp
from jax import random
matrix_2d = random.uniform(random.PRNGKey(0),(10,10))
exclude = [(2,5),(3,4),(6,1)]
ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d,dtype=bool).at[ind].set(False)
list_1d = matrix_2d[mask].tolist()
len(list_1d)
# 97
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。