如何解决在model.fit的前向计算过程中需要将numpy数组传递给函数,但是tf.Tensor没有`.numpy()`属性
我已经构建了一个自定义 keras 模型,并且在它的前向传递过程中,它使用了另一个库中函数的输出。但是,此函数的参数必须是一个 numpy 数组。在 model.compile()
期间,我可以将 run_eagerly
参数设置为 True,然后我可以使用 EagerTensor 的 .numpy()
方法将输出从前向传递转换为 numpy,但这似乎计算效率不高,因为.numpy()
在我的网络中只需要一次。 如何将张量转换为仅用于一次计算的热切张量?这可能吗?我曾尝试使用 K.get_session()
和 hidden_layer_outputs = sess.run(hidden_layer_outputs)
,但这会引发“无法在 TensorFlow 图形函数中获取会话”错误。下面是一个说明我的问题的例子。
def third_party_library_fxn(np_arr):
"""Uses len to get the first dimension of the array.
:param np_arr: <class 'numpy.ndarray'>
"""
# First part of function is to get the 1st dimension of the array
first_dim = len(np_arr)
# This function does more comptuations on first_dim
# ....
# ....
# return results
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor
from third_party_library import third_party_library_fxn
class CustomModel(tf.keras.Model):
def __init__(self,units,**kwargs):
self.dense = tf.keras.layers.Dense(units=units)
self.lambd = tf.keras.layers.Lambda(third_party_library_fxn)
def call(self,inputs):
hidden_layer_outputs = self.dense(inputs)
# Raises Error: This block executes if the model is NOT running eagerly (i.e.,during model.fit)
if not(isinstance(hidden_layer_outputs,EagerTensor)):
# You cannot calculate `len()` of `tf.Tensor`
outputs_from_third_party_library = self.lambd(hidden_layer_outputs)
# No Error: This block is executed if the model IS running eagerly
else:
outputs_from_third_party_library = self.lambd(hidden_layer_outputs.numpy())
引起错误的编译和拟合:
model = CustomModel(units=arbitrary_number)
# Compilation is NOT eager by default
model.compile(loss=arbitrary_loss,optimizer=arbitrary_optimizer,run_eagerly=False)
# Raises error
model.fit(arbitrary_tf_batch_data,epochs=arbitrary_epochs)
编辑 1:
编辑 2:
我尝试过的一种方法是将第三方库函数包装到一个 tf.keras.layers.Layer
类中,然后设置 dynamic=True
。这不能解决问题,因为 model.fit(...,run_eagerly=False)
失败并出现以下错误。同样的错误建议使用 tf.py_function
,但在之前的努力中,我发现将第三方库函数包装在 tf.py_function
中并没有解决问题(更多关于这方面的内容也尽快)。
ValueError:您的模型包含只能在 Eager Execution 中成功运行的层(使用 dynamic=True
构造的层)。您不能设置 run_eagerly=False
。
最好,
贾里德
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。