微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在model.fit()中自定义train_step“OperatorNotAllowedInGraphError:不允许迭代`tf.Tensor`:AutoGraph确实转换了这个函数”

如何解决在model.fit()中自定义train_step“OperatorNotAllowedInGraphError:不允许迭代`tf.Tensor`:AutoGraph确实转换了这个函数”

我正在尝试编写一个自定义 train_step 以在 tf.keras.Model.fit() 函数中使用。我正在关注tensor flow tutorial。根据我的理解,在 train_step 函数中,输入参数数据应该是我即将在 Model.fit() 函数中传递的训练数据集。我的数据集是 TFRecordDataset。我的数据集给出了三个特定的特征,即图像、标签和框。因此,在 train_step 函数中,我首先尝试从传递的数据参数中获取 img、labels 和 Box 参数。

def train_step(self,data):
        print("printing data fed to train_step")
        print(data)
        img,label,gt_Boxes = data
        if self.DEBUG:
            if(img == None):
                print("img input in train step is none")
        with tf.GradientTape() as tape:
            rpn_classification,rpn_regression = self(img,training=True)
            self.tf_rpn_target_generation_layer(gt_Boxes,rpn_regression)
            loss = self.rpn_loss_function(rpn_classification)
        
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss,trainable_vars)

        self.optimizer.apply_gradients(zip(gradients,trainable_vars))

        loss_tracker.update_state(loss)
        #mae_metric.update_state()
        return [loss_tracker]

以上是我用于自定义 train_step 函数代码。当我运行 fit 时,出现以下错误 OperatorNotAllowedInGraphError:不允许迭代 tf.Tensor:AutoGraph 确实转换了此函数。这可能表明您正在尝试使用不受支持功能

我在训练数据集上使用了随机播放、缓存和重复操作。谁能帮我理解为什么会出现这个错误

根据我之前的经验,我通常为数据集创建一个迭代器,然后通过 get_next 操作来获取特征。

编辑: 我已经尝试了以下程序,但没有产生任何结果

  1. 由于发送到 train_step 的数据是一个数据集对象,所以我使用 tf.raw_ops.IteratorGetNext 方法来访问迭代器的元素,它返回一个错误说 “类型错误:‘IteratorGetNext’操作的输入‘迭代器’的类型字符串与预期的资源类型不匹配。”

  2. 为了修复这个错误,我假设它可能是 tensorflow 返回迭代器图,因此无法访问元素,所以我在 model.compile() 函数添加了 run_eagerly=True 参数,该函数返回了乱码打印出来同样的错误

Epoch 1/5
printing data fed to train_step
Tensor("Shape:0",shape=(0,),dtype=int32)
Tensor("IteratorGetNext:0",shape=(),dtype=string)

解决方法

我找到了解决方案。传递给我的 step 函数的数据是一个迭代器,因此我必须使用 tf.raw_ops.IteratorGetNext 方法来访问迭代器的内容。

在执行此操作时,我最初遇到另一个错误,指出迭代器类型与预期的资源类型不匹配,仔细调试后我明白我必须对数据集执行的 read_tfrecords 映射不成功,这导致了数据集仍然包含格式为 tf.string 的未映射 tfrecords,这不是 train_Step 的预期资源类型。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。