如何解决在model.fit()中自定义train_step“OperatorNotAllowedInGraphError:不允许迭代`tf.Tensor`:AutoGraph确实转换了这个函数”
我正在尝试编写一个自定义 train_step 以在 tf.keras.Model.fit() 函数中使用。我正在关注tensor flow tutorial。根据我的理解,在 train_step 函数中,输入参数数据应该是我即将在 Model.fit() 函数中传递的训练数据集。我的数据集是 TFRecordDataset。我的数据集给出了三个特定的特征,即图像、标签和框。因此,在 train_step 函数中,我首先尝试从传递的数据参数中获取 img、labels 和 Box 参数。
def train_step(self,data):
print("printing data fed to train_step")
print(data)
img,label,gt_Boxes = data
if self.DEBUG:
if(img == None):
print("img input in train step is none")
with tf.GradientTape() as tape:
rpn_classification,rpn_regression = self(img,training=True)
self.tf_rpn_target_generation_layer(gt_Boxes,rpn_regression)
loss = self.rpn_loss_function(rpn_classification)
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss,trainable_vars)
self.optimizer.apply_gradients(zip(gradients,trainable_vars))
loss_tracker.update_state(loss)
#mae_metric.update_state()
return [loss_tracker]
以上是我用于自定义 train_step 函数的代码。当我运行 fit 时,出现以下错误
OperatorNotAllowedInGraphError:不允许迭代 tf.Tensor
:AutoGraph 确实转换了此函数。这可能表明您正在尝试使用不受支持的功能。
我在训练数据集上使用了随机播放、缓存和重复操作。谁能帮我理解为什么会出现这个错误?
根据我之前的经验,我通常为数据集创建一个迭代器,然后通过 get_next 操作来获取特征。
编辑: 我已经尝试了以下程序,但没有产生任何结果
-
由于发送到 train_step 的数据是一个数据集对象,所以我使用 tf.raw_ops.IteratorGetNext 方法来访问迭代器的元素,它返回一个错误说 “类型错误:‘IteratorGetNext’操作的输入‘迭代器’的类型字符串与预期的资源类型不匹配。”
-
为了修复这个错误,我假设它可能是 tensorflow 返回迭代器图,因此无法访问元素,所以我在 model.compile() 函数中添加了 run_eagerly=True 参数,该函数返回了乱码打印出来同样的错误。
Epoch 1/5
printing data fed to train_step
Tensor("Shape:0",shape=(0,),dtype=int32)
Tensor("IteratorGetNext:0",shape=(),dtype=string)
解决方法
我找到了解决方案。传递给我的 step 函数的数据是一个迭代器,因此我必须使用 tf.raw_ops.IteratorGetNext 方法来访问迭代器的内容。
在执行此操作时,我最初遇到另一个错误,指出迭代器类型与预期的资源类型不匹配,仔细调试后我明白我必须对数据集执行的 read_tfrecords 映射不成功,这导致了数据集仍然包含格式为 tf.string 的未映射 tfrecords,这不是 train_Step 的预期资源类型。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。