微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 Tensorflow 中按张量形状过滤数据集

如何解决如何在 Tensorflow 中按张量形状过滤数据集

我从 tfds.load 加载了一个数据集,并想丢弃某些干扰正确训练/对我没有用的图像(例如,太小)。

似乎在任何地方都没有关于这个特定问题的信息,所以我选择了似乎最合适的方法,即数据集上的 .filter(predicate)。不幸的是,谓词的输入具有不确定的形状(None、None、3),并且正如预期的那样引发了一个错误,即 'int' 无法与 'nonetype' 进行比较。

甚至可以在 tensorflow 中解决这个问题还是我不​​应该浪费我的时间?

代码

ds_train = tfds.load('name')
ds_train = ds_train.map(lambda ds: ds['image'])
ds_train = ds_train.filter(lambda image: image.shape[0] >= 256)

解决方法

使用 tf.data.Dataset 编写代码时,应使用 tf.shape(tensor) 而不是 tensor.shape,因为 tf.data.Dataset 在图形模式下工作。

引用 tf.shape 的文档:

tf.shape 和 Tensor.shape 在 Eager 模式下应该是相同的。在 tf.function 或 compat.v1 上下文中,在执行之前并非所有维度都是已知的。因此,在为图形模式定义自定义层和模型时,更喜欢动态 tf.shape(x) 而不是静态 x.shape。

ds_train = ds_train.filter(lambda image: tf.shape(image)[0] >= 256)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。