diff --git a/mmdet3d/apis/inference.py b/mmdet3d/apis/inference.py index 3af7153daf..b089a6781a 100644 --- a/mmdet3d/apis/inference.py +++ b/mmdet3d/apis/inference.py @@ -57,7 +57,7 @@ def init_model(config, checkpoint=None, device='cuda:0'): config.model.train_cfg = None model = build_model(config.model, test_cfg=config.get('test_cfg')) if checkpoint is not None: - checkpoint = load_checkpoint(model, checkpoint) + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') if 'CLASSES' in checkpoint['meta']: model.CLASSES = checkpoint['meta']['CLASSES'] else: