如何解决在 TF
我有一个模型,我想计算它的梯度 w.r.t.输入。计算需要内存,因此我想将其拆分为批处理。
因为我关心计算时间,所以我想将所有内容都包含在 tf.function
中。
这是我会做的一个例子:
import tensorflow as tf
import tensorflow_probability as tfp
def model(sample):
# Just a trivial example. This function in reality takes more arguments and creates
# a large computational graph
return tf.reduce_logsumexp(tfp.distributions.normal(0.,1.).log_prob(sample),axis=1)
input_ = tf.random.uniform(maxval=1.,shape=(100,10000000))
compiled_model = tf.function(model)
def get_batches(vars_,batch_size=10):
current_beginning = 0
all_elems = vars_[0].shape[0]
while current_beginning < all_elems:
yield tf.Variable(vars_[current_beginning:current_beginning+batch_size])
current_beginning += batch_size
res = []
for batch in get_batches(input_,batch_size=1):
with tf.GradientTape() as tape_logprob:
tape_logprob.watch(batch)
log_prob = compiled_model(batch)
res.append(tape_logprob.gradient(log_prob,batch))
如果您运行此代码,您会发现它会在 XLA 编译期间导致回溯,并严重影响性能:
WARNING:tensorflow:5 次调用
我不明白为什么回溯发生在这里。遵循警告中提到的几点:
1)。我没有在循环中定义 tf.function
(虽然我在循环中运行它)。
2)。输入张量的形状总是相同的,因此我相信编译应该只发生一次。
3)。我不使用普通的 Python 对象。
我在这里遗漏了什么细微差别?如何使这个例子工作?
在进行实验时,我注意到我可以通过将 log_prob = compiled_model(batch)
包装成一个简单的 tf.map_fn
来消除警告消息,但与非批处理版本相比,我仍然观察到性能下降很大计算。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。