微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么我在 WiSe 数据集中使用 DeepLab v3+ 得到完全零预测,即使损失不断减少?

如何解决为什么我在 WiSe 数据集中使用 DeepLab v3+ 得到完全零预测,即使损失不断减少?

我正在尝试在 WiSe 数据集 (https://github.com/tensorflow/models/research/deeplab/) 上训练 DeepLab v3+ 模型 (https://cvhci.anthropomatik.kit.edu/~mhaurile/wise/)。我修改了提供的脚本中的参数并开始运行 train.py 脚本,但即使损失不断减少(从第 10 步的大约 2.7 到第 100 步的大约 1.9),我在导出的检查点所做的预测。即使在每张火车图像上,我也得到了全零预测。
数据集信息(我已经根据我的需要处理了数据集):
火车图片:1222
Val 图像:100
图片总数:1322
总班级:9(含背景)
类:['background','TitleSlide','PresTitle','ImageCaption','Image','Code','Enumeration','Tables','Paragraph'] \

我在 datasets/data_generator.py添加了以下代码

_WISE_SEG_informatION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1222,'trainval': 1322,'val': 100,},num_classes=10,# 8 foreground + 1 background + 1 ignore
    ignore_label=255,)

_DATASETS_informatION = {
    'cityscapes': _CITYSCAPES_informatION,'pascal_voc_seg': _PASCAL_VOC_SEG_informatION,'ade20k': _ADE20K_informatION,'wise_seg': _WISE_SEG_informatION,}

请注意,在我的数据集中,实际上没有图像具有任何标签为 255 的像素。每个标签都在 [0,8] 范围内。我也尝试将 num_classes 设置为 9,但没有成功。
我的目录结构如下:

deeplab
├── datasets
│   ├── wise_seg
│   │   ├── exp
│   │   │   └── train_on_train_set
│   │   │       ├── eval
│   │   │       ├── export
│   │   │       ├── train
│   │   │       └── vis
│   │   ├── init_models
│   │   │   └── xception
|   |   |       ├── model.ckpt.data-00000-of-00001
|   |   |       └── model.ckpt.index
│   │   ├── tfrecord
│   │   └── WiSe
│   │       ├── Annotations
│   │       ├── imagesets
│   │       │   └── Segmentation
|   |       |       ├── train.txt
|   |       |       ├── trainval.txt
|   |       |       └── val.txt
│   │       ├── JPEGImages
│   │       ├── SegmentationClass
│   │       └── SegmentationClassRaw
│   └── __pycache__
|------ Other stuff

我用来运行训练的命令:

python ./train.py \
  --logtostderr \
  --train_split="train" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size="513,513" \
  --train_batch_size=16 \
  --training_number_of_steps=30000 \
  --fine_tune_batch_norm=true \
  --tf_initial_checkpoint="./datasets/wise_seg/init_models/xception/model.ckpt" \
  --train_logdir="./datasets/wise_seg/train" \
  --dataset="wise_seg" \
  --initialize_last_layer=false \
  --last_layers_contain_logits_only=false \
  --dataset_dir="./datasets/wise_seg/tfrecord"

请注意,我已经设置了 initialize_last_layer = Falselast_layers_contain_logits_only = False。我使用 ImageNet 预训练的 Xception-65 模型作为主干网络,我从给定的链接 here(具体来说,我使用了 xception_65_imagenet)下载了该模型。
我还在 utils/train_utils.py 中进行了以下更改:

exclude_list = ['global_step','logits']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

当我执行训练的时候,它能够成功地到达训练部分,现在已经训练到了大约110步。我使用以下命令导出了一个中间检查点:

python ./export_model.py \
  --logtostderr \
  --checkpoint_path="./datasets/wise_seg/exp/train_on_train_set/train/model.ckpt-41" \
  --export_path="./datasets/wise_seg/exp/train_on_train_set/export/frozen_inference_graph-41.pb" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --num_classes=${3} \
  --crop_size=513 \
  --crop_size=513 \
  --inference_scales=1.0

检查点成功导出。然后我尝试使用给定 here 的示例笔记本运行推理。具体来说,当我运行以下部分时,0 会打印在输出中:

graph_path = './datasets/wise_seg/exp/train_on_train_set/export/frozen_inference_graph-41.pb'
MODEL = DeepLabModel(graph_path)
resized_im,seg_map = MODEL.run(Image.open('./datasets/wise_seg/WiSe/JPEGImages/130110-3MQQHISL3D-540_frame11610.jpg'))
print(sum(sum(seg_map)))

对于任何给定的图像都会发生同样的情况。为什么会这样?任何帮助将不胜感激。

解决方法

您应该尝试使用 110 步以上(至少 2000 步以上)进行训练。你的损失应该低于 1.9。还请确保标记的掩码显示 0、1、2、3、4、... 8 的像素值。此外,设置 num_classes = 9 是正确的。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?