如何解决关于每次调用jax.jit的函数重新编译
我是 jax
的新手。当我阅读文档时,我对 jit
的缓存行为感到困惑。
在 caching section 中,它说“避免在循环内调用 jax.jit。这样做有效地在每次调用时创建一个新的 f,每次都会编译它而不是重用相同的缓存函数”。但是,运行以下代码只会产生一种打印副作用:
import jax
def unjitted_loop_body(prev_i):
print("tracing...")
return prev_i + 1
def g_inner_jitted_poorly(x,n):
i = 0
while i < n:
# Don't do this!
i = jax.jit(unjitted_loop_body)(i)
return x + i
g_inner_jitted_poorly(10,20)
# output:
WARNING:absl:No GPU/TPU found,falling back to cpu. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
tracing...
Out[1]: DeviceArray(30,dtype=int32)
字符串“tracing...”只打印一次,看来jit
不会再次跟踪函数。
这是故意的吗?感谢您的帮助!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。