python – 访问TensorFlow Dataset API中排队项的数量

我正在将TensorFlow代码从旧队列接口更改为新的Dataset API.使用旧接口,我可以通过访问图中的原始计数器来监视实际填充的队列大小,例如,如下:

queue = tf.train.shuffle_batch(...,  name="training_batch_queue")
queue_size_op = "training_batch_queue/random_shuffle_queue_Size:0"
queue_size = session.run(queue_size_op)

但是,使用新的数据集API,我似乎无法在图表中找到与队列/数据集相关的任何变量,因此我的旧代码不再起作用.有没有办法使用新的Dataset API获取队列中的项目数(例如在tf.Dataset.prefetch或tf.Dataset.shuffle队列中)?

对我来说监控队列中的项目数非常重要,因为这告诉我很多关于队列中预处理的行为,包括预处理或其余部分(例如神经网络)是否是速度瓶颈.

解决方法:

作为解决方法,可以保留一个计数器来指示队列中有多少项.以下是如何定义计数器:

 queue_size = tf.get_variable("queue_size", initializer=0,
                              trainable=False, use_resource=True)

然后,当预处理数据时(例如在dataset.map函数中),我们可以递增该计数器:

 def pre_processing():
    data_size = ... # compute this (could be just '1')
    queue_size_op = tf.assign_add(queue_size, data_size)  # adding items
    with tf.control_dependencies([queue_size_op]):
        # do the actual pre-processing here

然后,我们可以在每次使用一批数据运行模型时递减计数器:

 def model():
    queue_size_op = tf.assign_add(queue_size, -batch_size)  # removing items
    with tf.control_dependencies([queue_size_op]):
        # define the actual model here

现在,我们需要做的就是在训练循环中运行queue_size tensor以找出当前队列大小是什么,即此时队列中的项目数:

 current_queue_size = session.run(queue_size)

与旧的方式(在数据集API之前)相比,它有点不那么优雅了,但它确实可以解决问题.

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。

相关推荐