微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 Detectron2 中使用 DefaultTrainer 保存模型?

如何解决如何在 Detectron2 中使用 DefaultTrainer 保存模型?

如何使用 DefaultTrainer 在 Detectron2 中保存检查点? 这是我的设置:

reduce_mean

我收到错误

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

cfg.DATASETS.TRAIN = (DatasetLabels.TRAIN,)
cfg.DATASETS.TEST = ()
cfg.DataLoader.NUM_WORKERS = 2
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 273  # Number of output classes

cfg.OUTPUT_DIR = "outputs"
os.makedirs(cfg.OUTPUT_DIR,exist_ok=True)
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.soLVER.ims_PER_BATCH = 2
cfg.soLVER.BASE_LR = 0.00025#0.00025  # Learning Rate
cfg.soLVER.MAX_ITER = 10000  # 20000 MAx Iterations
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # Batch Size

trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()


# Save the model
from detectron2.checkpoint import DetectionCheckpointer,Checkpointer
checkpointer = DetectionCheckpointer(trainer,save_dir=cfg.OUTPUT_DIR)
checkpointer.save("mymodel_0")  

文档:https://detectron2.readthedocs.io/en/latest/modules/checkpoint.html

解决方法

checkpointer = DetectionCheckpointer(trainer.model,save_dir=cfg.OUTPUT_DIR)

是要走的路。

或者:

torch.save(trainer.model.state_dict(),os.path.join(cfg.OUTPUT_DIR,"mymodel.pth"))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。