微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 JAX 中使用 VJP 时,有没有办法禁用前向评估?

如何解决在 JAX 中使用 VJP 时,有没有办法禁用前向评估?

我在我的项目中经常使用 VJP。它运行受雅可比计算约束的函数,并与可调用的 vjp 函数一起返回一个 primals_out。 例如,JAX 文档中的自定义 VJP 定义如下所示:

from jax import custom_vjp

@custom_vjp
def f(x,y):
  return jnp.sin(x) * y

def f_fwd(x,y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x,y),(jnp.cos(x),jnp.sin(x),y)

def f_bwd(res,g):
  cos_x,sin_x,y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y,sin_x * g)

f.defvjp(f_fwd,f_bwd)

在这个例子中,我们看到使用 VJP 时需要对前向函数进行评估。使用常规 VJP 而不是自定义定义的 VJP 时也是如此。但是,当函数的评估成本很高并且因为我已经在代码中的某处运行了该函数时,我不希望 VJP 再次评估该函数

那么,有什么方法可以表明在计算其 VJP 时不会评估函数吗?

解决方法

我认为在这种情况下没有任何方法可以明确禁用前向评估,但是如果您将计算包装在 jit 编译中,XLA 编译器将自动执行死代码消除并从计算图。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。