如何解决Detectron2 中物体检测的输入图像类型
我正在使用 Detectron2 来训练用于对象检测的 Faster R-CNN 模型,我想训练模型动物园给出的模型,输入范围为 [0 1] 而不是 [0 255],所以我使用了颜色变换调用我的函数 scale_transform
def scale_transform(img):
return img/255.
此函数正在接收一个 numpy 数组并返回它的缩放比例。但是,在火车时间出现此错误
RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
有人知道我该如何解决这个问题吗?或另一种缩放detectron2图像的方法?
谢谢
解决方法
我认为这里的相关词是类型。
也许确保输入被定义为浮点数。尽管它在正确的范围 (0-1) 内,但它可能会发现数据类型不正确,因此在那里绊倒了。
以下可能对它 -
def scale_transform(img):
img = img/255
img = img.astype(np.float32)
return img
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。