diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py index 30f06d46b..9071a9bec 100644 --- a/wenet/utils/checkpoint.py +++ b/wenet/utils/checkpoint.py @@ -25,7 +25,7 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict: logging.info('Checkpoint: loading from checkpoint %s' % path) - checkpoint = torch.load(path, map_location='cpu') + checkpoint = torch.load(path, map_location='cpu', mmap=True) missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) for key in missing_keys: