微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 jit 编译的 jax 代码中执行非整数索引算法?

如何解决如何在 jit 编译的 jax 代码中执行非整数索引算法?

如果我们对数组索引执行非整数计算(然后转换为 int() ),似乎我们仍然无法将结果用作 jit 编译的 jax 代码中的有效索引。我们如何解决这个问题?

以下是一个最小的例子。具体问题:是否可以在不向 fun() 传递额外参数的情况下使命令 jnp.diag_indices(d) 起作用

在 Jupiter 单元中运行:

import jax.numpy as jnp
from jax import jit

@jit
def fun(t):
    d = jnp.sqrt(t.size**2)
    d = jnp.array(d,int)
    
    jnp.diag_indices(t.size)   # this line works
    jnp.diag_indices(d)        # this line breaks. Comment it out to see that d and t.size have the same dtype=int32 

    return t.size,d
    
fun(jnp.array([1,2]))    

解决方法

问题不在于 d 的类型,而是 d 是 jax 操作的结果,因此在 JIT 上下文中跟踪这一事实。在 JAX 中,数组的形状和大小不能依赖于跟踪的数量,这就是您的代码导致错误的原因。

为了解决这个问题,一个有用的模式是使用 np 操作而不是 jnp 操作来确保 d 是静态的并且不被跟踪:

import jax.numpy as jnp
from jax import jit

@jit
def fun(t):
    d = np.sqrt(t.size**2)
    d = np.array(d,int)
    
    jnp.diag_indices(t.size)
    jnp.diag_indices(d)

    return t.size,d
    
print(fun(jnp.array([1,2])))
# (DeviceArray(2,dtype=int32),DeviceArray(2,dtype=int32))

有关跟踪、静态值和类似主题的简要背景信息,How To Think In JAX 文档页面可能会有所帮助。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。