Skip to content

Commit

Permalink
[Feature] add auto resume (open-mmlab#1172)
Browse files Browse the repository at this point in the history
* [Feature] add auto resume

* Update mmseg/utils/find_latest_checkpoint.py

Co-authored-by: Miao Zheng <[email protected]>

* Update mmseg/utils/find_latest_checkpoint.py

Co-authored-by: Miao Zheng <[email protected]>

* modify docstring

* Update mmseg/utils/find_latest_checkpoint.py

Co-authored-by: Miao Zheng <[email protected]>

* add copyright

Co-authored-by: Miao Zheng <[email protected]>
  • Loading branch information
RockeyCoss and MeowZheng authored Jan 11, 2022
1 parent ae51615 commit 43ad37b
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 2 deletions.
6 changes: 5 additions & 1 deletion mmseg/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger
from mmseg.utils import find_latest_checkpoint, get_root_logger


def init_random_seed(seed=None, device='cuda'):
Expand Down Expand Up @@ -160,6 +160,10 @@ def train_segmentor(model,
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)

if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down
3 changes: 2 additions & 1 deletion mmseg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import find_latest_checkpoint

__all__ = ['get_root_logger', 'collect_env']
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
41 changes: 41 additions & 0 deletions mmseg/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings


def find_latest_checkpoint(path, suffix='pth'):
"""This function is for finding the latest checkpoint.
It will be used when automatically resume, modified from
https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py
Args:
path (str): The path to find checkpoints.
suffix (str): File extension for the checkpoint. Defaults to pth.
Returns:
latest_path(str | None): File path of the latest checkpoint.
"""
if not osp.exists(path):
warnings.warn("The path of the checkpoints doesn't exist.")
return None
if osp.exists(osp.join(path, f'latest.{suffix}')):
return osp.join(path, f'latest.{suffix}')

checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
if len(checkpoints) == 0:
warnings.warn('The are no checkpoints in the path')
return None
latest = -1
latest_path = ''
for checkpoint in checkpoints:
if len(checkpoint) < len(latest_path):
continue
# `count` is iteration number, as checkpoints are saved as
# 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
if count > latest:
latest = count
latest_path = checkpoint
return latest_path
40 changes: 40 additions & 0 deletions tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

from mmseg.utils import find_latest_checkpoint


def test_find_latest_checkpoint():
    """Unit test for ``find_latest_checkpoint``.

    Covers four scenarios: an empty directory, a non-existent path,
    an explicit ``latest.pth``, and selecting the highest-numbered
    ``iter_*.pth`` / ``epoch_*.pth`` checkpoint.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # no checkpoints in the path
        path = tempdir
        latest = find_latest_checkpoint(path)
        assert latest is None

        # the path doesn't exist
        path = osp.join(tempdir, 'none')
        latest = find_latest_checkpoint(path)
        assert latest is None

    # test when latest.pth exists
    with tempfile.TemporaryDirectory() as tempdir:
        with open(osp.join(tempdir, 'latest.pth'), 'w') as f:
            f.write('latest')
        path = tempdir
        latest = find_latest_checkpoint(path)
        assert latest == osp.join(tempdir, 'latest.pth')

    # `iteration` instead of `iter` to avoid shadowing the builtin
    with tempfile.TemporaryDirectory() as tempdir:
        for iteration in range(1600, 160001, 1600):
            with open(osp.join(tempdir, f'iter_{iteration}.pth'), 'w') as f:
                f.write(f'iter_{iteration}.pth')
        latest = find_latest_checkpoint(tempdir)
        assert latest == osp.join(tempdir, 'iter_160000.pth')

    with tempfile.TemporaryDirectory() as tempdir:
        for epoch in range(1, 21):
            with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f:
                f.write(f'epoch_{epoch}.pth')
        latest = find_latest_checkpoint(tempdir)
        assert latest == osp.join(tempdir, 'epoch_20.pth')
5 changes: 5 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically.')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
Expand Down Expand Up @@ -118,6 +122,7 @@ def main():
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
cfg.auto_resume = args.auto_resume

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
Expand Down

0 comments on commit 43ad37b

Please sign in to comment.