如何解决JAX 仅在 jit
我正在使用 JAX,我想执行类似的操作
@jax.jit
def fun(x,index):
x[:index] = other_fun(x[:index])
return x
这不能在 jit
下执行。有没有办法用 jax.ops
或 jax.lax
做到这一点?
我想过使用 jax.ops.index_update(x,idx,y)
,但我无法找到一种计算 y
的方法,而不会再次遇到同样的问题。
解决方法
您的实施似乎存在两个问题。首先,切片产生动态形状的数组(不允许在即时代码中)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能改变)。
您可以通过组合 static_argnums
和 jax.lax.dynamic_update_slice
来克服这两个问题。下面是一个例子:
def other_fun(x):
return x + 1
@jax.partial(jax.jit,static_argnums=(1,))
def fun(x,index):
update = other_fun(x[:index])
return jax.lax.dynamic_update_slice(x,update,(0,))
x = jnp.arange(5)
print(fun(x,3)) # prints [1 2 3 3 4]
本质上,上面的例子使用 static_argnums
来指示函数应该为不同的 index
值重新编译,jax.lax.dynamic_update_slice
创建一个 x
的副本,并在 { {1}}。
@rvinas 的 previous answer 使用 dynamic_slice
如果您的索引是静态的,则效果很好,但您也可以使用 jnp.where
使用动态索引来完成此操作。例如:
import jax
import jax.numpy as jnp
def other_fun(x):
return x + 1
@jax.jit
def fun(x,index):
mask = jnp.arange(x.shape[0]) < index
return jnp.where(mask,other_fun(x),x)
x = jnp.arange(5)
print(fun(x,3))
# [1 2 3 3 4]
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。