微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何使用自定义 COCO 样式数据集重新训练 Torchvision 的关键点 R-CNN?

如何解决如何使用自定义 COCO 样式数据集重新训练 Torchvision 的关键点 R-CNN?

我使用 COCO annotator 创建了一个自定义 COCO 关键点样式数据集,并希望在其上重新训练 Torchvision 的 Keypoint R-CNN。 我正在尝试使用 torchvision 的 CocoDetection 数据集类来加载数据,我不得不重写 _load_image 方法,因为我的数据集有子目录。然后我尝试将数据集包装在数据加载器中并得到以下错误

>>> dl = DataLoader(coco,batch_size=4)
>>> feat,lbl = next(iter(dl))
Traceback (most recent call last):
  File "<stdin>",line 1,in <module>
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/DataLoader.py",line 521,in __next__
    data = self._next_data()
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/DataLoader.py",line 561,in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise stopiteration
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py",line 47,in fetch
    return self.collate_fn(data)
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py",line 84,in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py",in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py",line 74,in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py",in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/sam/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py",line 82,in default_collate
    raise RuntimeError('each element in list of batch should be of equal size')
RuntimeError: each element in list of batch should be of equal size

考虑到 Keypoint R-CNN 需要一个 [通道、高度、宽度] 张量列表,尝试将数据集放入数据加载器是否正确?

此外,当我获得可接受格式的数据时,我无法弄清楚应该如何实际训练模型。我查看了 https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractorhttps://github.com/pytorch/vision/tree/master/references/detection,但仍然有点困惑。我能否获得一些有关如何在具有单个 GPU 的机器上训练模型的指导?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。