使用 Keras API,如何批量导入给定批次中每个 ID 恰好有 K 个实例的图像?

如何解决使用 Keras API,如何批量导入给定批次中每个 ID 恰好有 K 个实例的图像?

我正在尝试实现批量硬三元组损失,如 https://arxiv.org/pdf/2004.06271.pdf 的第 3.2 节所示。

我需要导入我的图像,以便每个批次在特定批次中都有恰好 K 个每个 ID 的实例。因此,每批必须是K的倍数

我的图像目录太大而无法放入内存,因此我使用 ImageDataGenerator.flow_from_directory() 导入图像,但我看不到此函数的任何参数以允许我需要的功能。>

如何使用 Keras 实现这种批处理行为?

解决方法

从 Tensorflow 2.4 开始,我看不到使用 ImageDataGenerator 执行此操作的标准方法。

所以我认为您需要根据 tensorflow.keras.utils.Sequence 类编写自己的内容,因此您可以自己定义批处理内容。

参考文献:
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

,

您可以尝试以受控方式将多个数据流合并在一起。

假设您有 K 个 tf.data.Dataset 实例(不管您如何实例化它们)负责提供特定 ID 的训练实例,您可以将它们连接起来以在小批量中均匀分布:>

ds1 = ...  # Training instances with ID == 1
ds2 = ...  # Training instances with ID == 2
...
dsK = ... # Training instances with ID == K



train_dataset = tf.data.Dataset.zip((ds1,ds2,...,dsK)).flat_map(concat_datasets).batch(batch_size=N * K)

其中 concat_datasets 是合并函数:

def concat_datasets(*datasets):
    ds = tf.data.Dataset.from_tensors(datasets[0])
    for i in range(1,len(datasets)):
        ds = ds.concatenate(tf.data.Dataset.from_tensors(datasets[i]))
    return ds

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?