如何解决使用tensorflow 2.x
给出类A
,[A() for _ in range(5)]
的实例的列表,我想随机选择其中一个(请参见以下代码作为示例)
class A:
def __init__(self,a):
self.a = a
def __call__(self):
return self.a
def f():
a_list = [A(i) for i in range(5)]
a = a_list[random.randint(0,5)]()
return a
f()
是否有一种方法可以用f
装饰@tf.function
而又不更改f
的功能,也不需要调用a_list
中的所有项目?
请注意,用f
直接修饰@tf.function
而不更改上面的代码是不可行的,因为它将始终返回相同的结果。另外,我知道可以通过首先调用a_list
中的所有元素,然后使用tf.gather_nd
对其进行索引来实现。但是,如果调用类型为A
的对象涉及到深度神经网络,则会产生大量开销。
解决方法
此刻我正在做同样的事情。这是到目前为止我得到的。如果有人知道更好的方法,我也会有兴趣听的。当我在昂贵的电话上运行它时,它比计算并返回所有值的速度要适当地快。
@tf.function
def f2():
a_list = [A(i) for i in range(5)]
idx = tf.cast(tf.random.uniform(shape=[],maxval=4),tf.int32)
return tf.switch_case(idx,a_list)
为了进行速度比较,我制作了A昂贵矩阵代数的调用方法。然后考虑一个替代函数,它调用每个函数:
@tf.function
def f3():
a_list = [A(i) for i in range(40)]
results = [a() for a in a_list]
return results
以40个元素运行f2:0.42643秒
使用40个元素运行f3:14.9153秒
因此,对于仅选择一个分支的预期40倍加速,这似乎是正确的。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。