微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

警告:调用迭代器没有完全读取正在缓存的数据集为了避免数据集的意外截断

如何解决警告:调用迭代器没有完全读取正在缓存的数据集为了避免数据集的意外截断

这是在我使用tf.data.Dataset时发生的:

调用迭代器没有完全读取正在缓存的数据集。为了避免数据集的意外截断,数据集的部分缓存内容将被丢弃。如果您的输入管道类似于dataset.cache().take(k).repeat(),则可能会发生这种情况。您应该改dataset.take(k).cache().repeat()

根据其他问题,例如this one,它与方法序列中cache()的位置有关,但我不知道具体要做什么。

以下是重现警告的方法

import tensorflow_datasets as tfds

ds = tfds.load('iris',split='train')

ds = ds.take(100)

for elem in ds:
    pass

无论我做什么,无论我在哪里使用cache(),似乎都会弹出警告。

解决方法

我尝试在Google colab上运行您的代码,它成功运行而没有发出任何警告,我正在使用Tensorflow 2.3

但是,您可以在使用cache时遵循这种常规方法。

如果数据集足够小以适合内存,则可以显着提高速度 使用数据集的cache()方法将其内容缓存到 RAM。通常,您应该在加载并预处理 数据,但在shufflingrepeatingbatchingprefetching之前。这条路, 每个实例只会被读取和预处理一次(而不是每个时期一次)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。