如何解决将一段代码与 jax 跟踪隔离
提前为这个问题的含糊程度致歉(不幸的是,我对 jax 跟踪的工作原理知之甚少,无法更准确地表达它),但是:有没有办法将函数或代码块与 jax 跟踪完全隔离?
def f(x,y):
z = h(y)
return g(x,z)
本质上,我想调用 g(x,z)
,并在执行任何 jax 转换时将 z
视为常量。但是,设置参数z
非常笨拙,因此使用辅助函数h
将更易于指定的输入y
转换为g
所需的格式.我希望 jax 将 h
视为不可追踪的黑匣子,因此对特定 jit(lambda x: f(x,y0))
执行 y0
与第一次计算 z0 = h(y0)
相同使用 numpy
,然后执行 jit(lambda x: g(x,z0))
(与 grad
或任何其他函数转换类似)。
在我的代码中,我已经编写了 h
只使用标准的 numpy
(我认为这可能会导致黑盒行为),但是 jit(lambda x: f(x,y0))
的编译时间是明显长于 jit(lambda x: g(x,z0))
的 z0 = h(y0)
编译时间。我有一种感觉,编译时间可能与 jax 跟踪 h
中的许多循环有关,但我不确定。
一些附加说明:
想法?
为清楚起见编辑添加:我知道可能有办法解决这个问题,例如f
是顶级函数。在这种情况下,让用户首先调用 h
来“预编译”对 g
的 jax 友好输入,然后自由地执行他们想要的任何 jax 转换并不是什么大问题lambda x: g(x,z0)
。但是,我想象的情况是,我们有许多要链接在一起的函数,它们具有与 f
相同的结构,其中存在一些对 jax 不友好的输入/计算,但这些输入将始终被处理作为计算的 jax 部分的常量。原则上,我们总是可以提取这些预先计算来设置 jax 的东西,但是如果我们有一个将相互调用的此类函数的重要集合,这似乎很困难。
是否有某种方法可以控制跟踪 f
的方式,以便在跟踪时它知道只评估 z=h(y)
(而不是跟踪 h
)然后继续跟踪 {{1} }?
解决方法
f_jitted = jax.jit(f,static_argnums=1)
static_argnums 参数可能有帮助
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
您可以使用诸如 static_argnums
之类的转换参数来代替 jit
来避免跟踪转换函数的特定参数,但代价是需要更多的重新编译。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。