微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

创建一个 3D 零张量,在 numpy/jax 中的每个切片上随机放置一个“1”

如何解决创建一个 3D 零张量,在 numpy/jax 中的每个切片上随机放置一个“1”

我需要创建一个像这样的 3D 张量 (5,3,2)

array([[[0,0],[0,1],0]],[[1,[[0,[1,0]]])

应该在每个切片中随机放置一个一个”(如果您认为张量是一条面包)。这可以使用循环来完成,但我想矢量化这部分。

解决方法

尝试生成一个随机数组,然后找到max

a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
,

最直接的方法可能是创建一个零数组,并将随机索引设置为 1。在 NumPy 中,它可能如下所示:

import numpy as np

K,M,N = 5,2
i = np.random.randint(0,K)
j = np.random.randint(0,N,K)
x = np.zeros((K,N))
x[np.arange(K),i,j] = 1

在 JAX 中,它可能看起来像这样:

import jax.numpy as jnp
from jax import random

K,2
key1,key2 = random.split(random.PRNGKey(0))
i = random.randint(key1,(K,),M)
j = random.randint(key2,N)
x = jnp.zeros((K,N)).at[jnp.arange(K),j].set(1)

同时保证每个切片有一个 1 的更简洁的选项是使用具有适当构造范围的随机整数的广播相等性:

r = random.randint(random.PRNGKey(0),1,1),M * N)
x = (r == jnp.arange(M * N).reshape(M,N)).astype(int)
,

您可以创建一个零数组,其中每个子数组的第一个元素为 1,然后在最后两个轴上permute

x = np.zeros((5,2)); x[:,0] = 1

rng = np.random.default_rng()
x = rng.permuted(rng.permuted(x,axis=-1),axis=-2)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。