微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 Keras 中为 TF 数据集中的单热编码标签指定类或样本权重

如何解决在 Keras 中为 TF 数据集中的单热编码标签指定类或样本权重

我正在尝试在不平衡的训练集上训练图像分类器。为了应对类不平衡,我想对类或单个样本进行加权。加权类似乎不起作用。不知何故,对于我的设置,我无法找到指定样本权重的方法。您可以在下面阅读我如何加载和编码训练数据以及我尝试过的两种方法

训练数据加载和编码

我的训练数据存储在一个目录结构中,其中每个图像都放在与其类别对应的子文件夹中(我总共有 32 个类别)。由于训练数据太大,一次加载到内存中,我使用 image_dataset_from_directory 并通过它描述 TF 数据集 中的数据:

train_ds = keras.preprocessing.image_dataset_from_directory (training_data_dir,batch_size=batch_size,image_size=img_size,label_mode='categorical')

我使用 label_mode 'categorical',以便将标签描述为one-hot 编码向量

然后我预取数据:

train_ds = train_ds.prefetch(buffer_size=buffer_size)

方法一:指定类权重

在这方法中,我尝试通过 fit 的 class_weight 参数指定类的类权重:

model.fit(
    train_ds,epochs=epochs,callbacks=callbacks,validation_data=val_ds,class_weight=class_weights
)

对于每个类,我们计算权重,权重与该类的训练样本数成反比。这是按如下方式完成的(这是在上述 train_ds.prefetch() 调用之前完成的):

class_num_training_samples = {}
for f in train_ds.file_paths:
    class_name = f.split('/')[-2]
    if class_name in class_num_training_samples:
        class_num_training_samples[class_name] += 1
    else:
        class_num_training_samples[class_name] = 1
max_class_samples = max(class_num_training_samples.values())
class_weights = {}
for i in range(0,len(train_ds.class_names)):
    class_weights[i] = max_class_samples/class_num_training_samples[train_ds.class_names[i]]

我不确定这个解决方案是否有效,因为 keras 文档没有指定 class_weights 字典的键,以防标签是单热编码的。 我尝试以这种方式训练网络,但发现权重对生成的网络没有真正的影响:当我查看每个单独类的预测类的分布时,我可以识别整个训练集的分布,其中对于每个类别,最有可能预测优势类别。 在没有指定任何类权重的情况下运行相同的训练会导致类似的结果。 所以我怀疑权重似乎对我的情况没有影响。

这是因为指定类权重对单热编码标签不起作用,还是因为我可能做错了其他事情(在我没有在这里显示代码中)?

方法二:指定样本权重

为了提出不同的(在我看来不太优雅)解决方案,我想通过 fit 方法的 sample_weight 参数指定单个样本权重。但是从 documentation 我发现:

[...] 当 x 是数据集、生成器或 keras.utils.Sequence 实例时,不支持此参数,而是提供 sample_weights 作为 x 的第三个元素。

在我的设置中确实是这种情况,其中 train_ds 是一个数据集。现在我真的很难找到可以从中得出如何修改 train_ds 的文档,以便它具有带有权重的第三个元素。我认为使用数据集的 map 方法可能很有用,但我想出的解决方案显然无效:

train_ds = train_ds.map(lambda img,label: (img,label,class_weights[np.argmax(label)]))

有没有人有可以与 image_dataset_from_directory 加载的数据集结合使用的解决方案?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。