微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 TF

如何解决在 TF

我有一个模型,我想计算它的梯度 w.r.t.输入。计算需要内存,因此我想将其拆分为批处理。

因为我关心计算时间,所以我想将所有内容都包含在 tf.function 中。

这是我会做的一个例子:

import tensorflow as tf
import tensorflow_probability as tfp

def model(sample):
    # Just a trivial example. This function in reality takes more arguments and creates
    # a large computational graph
    return tf.reduce_logsumexp(tfp.distributions.normal(0.,1.).log_prob(sample),axis=1)

input_ = tf.random.uniform(maxval=1.,shape=(100,10000000))

compiled_model = tf.function(model)

def get_batches(vars_,batch_size=10):
    current_beginning = 0
    all_elems = vars_[0].shape[0]
    while current_beginning < all_elems:
        yield tf.Variable(vars_[current_beginning:current_beginning+batch_size])
        current_beginning += batch_size

res = []        

for batch in get_batches(input_,batch_size=1):
    with tf.GradientTape() as tape_logprob:
        tape_logprob.watch(batch)
        log_prob = compiled_model(batch)
    
    res.append(tape_logprob.gradient(log_prob,batch))
    

如果您运行此代码,您会发现它会在 XLA 编译期间导致回溯,并严重影响性能

WARNING:tensorflow:5 次调用 触发了 tf.function 回溯。跟踪是昂贵的,并且过多的跟踪可能是由于 (1) 在循环中重复创建 @tf.function,(2) 传递具有不同形状的张量,(3) 传递 Python 对象而不是张量。对于(1),请在循环之外定义您的@tf.function。对于 (2),@tf.function 有 Experiment_relax_shapes=True 选项,可以放宽参数形状,避免不必要的回溯。对于 (3),请参阅 https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_argshttps://www.tensorflow.org/api_docs/python/tf/function 了解更多详情。

我不明白为什么回溯发生在这里。遵循警告中提到的几点: 1)。我没有在循环中定义 tf.function (虽然我在循环中运行它)。 2)。输入张量的形状总是相同的,因此我相信编译应该只发生一次。 3)。我不使用普通的 Python 对象。

在这里遗漏了什么细微差别?如何使这个例子工作?

在进行实验时,我注意到我可以通过将 log_prob = compiled_model(batch) 包装成一个简单的 tf.map_fn 来消除警告消息,但与非批处理版本相比,我仍然观察到性能下降很大计算。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?