微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

JAX 仅在 jit

如何解决JAX 仅在 jit

我正在使用 JAX,我想执行类似的操作

@jax.jit
def fun(x,index):
    x[:index] = other_fun(x[:index])
    return x

这不能在 jit 下执行。有没有办法用 jax.opsjax.lax 做到这一点? 我想过使用 jax.ops.index_update(x,idx,y),但我无法找到一种计算 y方法,而不会再次遇到同样的问题。

解决方法

您的实施似乎存在两个问题。首先,切片产生动态形状的数组(不允许在即时代码中)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能改变)。

您可以通过组合 static_argnumsjax.lax.dynamic_update_slice 来克服这两个问题。下面是一个例子:

def other_fun(x):
    return x + 1

@jax.partial(jax.jit,static_argnums=(1,))
def fun(x,index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x,update,(0,))

x = jnp.arange(5)
print(fun(x,3))  # prints [1 2 3 3 4]

本质上,上面的例子使用 static_argnums 来指示函数应该为不同的 index 值重新编译,jax.lax.dynamic_update_slice 创建一个 x 的副本,并在 { {1}}。

,

@rvinas 的 previous answer 使用 dynamic_slice 如果您的索引是静态的,则效果很好,但您也可以使用 jnp.where 使用动态索引来完成此操作。例如:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x,index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask,other_fun(x),x)

x = jnp.arange(5)
print(fun(x,3))
# [1 2 3 3 4]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。