Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Resume from the latest checkpoint automatically. #245

Merged
merged 4 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mmselfsup/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from mmselfsup.core import (DistOptimizerHook, GradAccumFp16OptimizerHook,
build_optimizer)
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.utils import get_root_logger, multi_gpu_test, single_gpu_test
from mmselfsup.utils import (find_latest_checkpoint, get_root_logger,
multi_gpu_test, single_gpu_test)


def init_random_seed(seed=None, device='cuda'):
Expand Down Expand Up @@ -192,6 +193,12 @@ def train_model(model,
eval_hook(val_dataloader, test_fn=eval_fn, **eval_cfg),
priority='LOW')

resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down
5 changes: 3 additions & 2 deletions mmselfsup/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from .extractor import Extractor
from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
from .logger import get_root_logger
from .misc import find_latest_checkpoint
from .setup_env import setup_multi_processes
from .test_helper import multi_gpu_test, single_gpu_test

__all__ = [
'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
'distributed_sinkhorn', 'Extractor', 'concat_all_gather', 'gather_tensors',
'gather_tensors_batch', 'get_root_logger', 'multi_gpu_test',
'single_gpu_test', 'setup_multi_processes'
'gather_tensors_batch', 'get_root_logger', 'find_latest_checkpoint',
'multi_gpu_test', 'single_gpu_test', 'setup_multi_processes'
]
36 changes: 36 additions & 0 deletions mmselfsup/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings

import mmcv
import numpy as np

Expand All @@ -15,3 +19,35 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
img, mean, std, to_bgr=to_rgb).astype(np.uint8)
imgs.append(np.ascontiguousarray(img))
return imgs


def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    ``latest.{suffix}`` is returned when it exists. Otherwise every file
    matching ``*.{suffix}`` is ranked by the integer that follows the last
    underscore of its file name (``epoch_12.pth`` -> 12,
    ``iter_4000.pth`` -> 4000) and the file with the largest number is
    returned. Files whose name does not end in such a number (for example
    ``best_model.pth``) are ignored rather than aborting the search.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension. Defaults to 'pth'.

    Returns:
        str | None: File path of the latest checkpoint, or None when the
        directory does not exist or holds no usable checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
               /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    latest_file = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_file):
        return latest_file

    # Escape the directory part so that glob metacharacters in the work dir
    # name (e.g. "run[1]") are matched literally.
    checkpoints = glob.glob(osp.join(glob.escape(path), f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        index = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(index)
        except ValueError:
            # Not an "<prefix>_<number>.<suffix>" file; it cannot be ranked.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint

    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path.')
    return latest_path
43 changes: 42 additions & 1 deletion tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import pytest
import torch

from mmselfsup.utils.misc import tensor2imgs
from mmselfsup.utils.misc import find_latest_checkpoint, tensor2imgs


def test_tensor2imgs():
Expand All @@ -12,3 +15,41 @@ def test_tensor2imgs():
fake_imgs = tensor2imgs(fake_tensor)
assert len(fake_imgs) == 3
assert fake_imgs[0].shape == (16, 16, 3)


def test_find_latest_checkpoint():
    """Check ``find_latest_checkpoint`` on empty, missing and filled dirs."""

    def _make_checkpoint(directory, filename):
        # Each dummy checkpoint holds its own name without the extension.
        with open(osp.join(directory, filename), 'w') as f:
            f.write(osp.splitext(filename)[0])

    with tempfile.TemporaryDirectory() as tmpdir:
        # There are no checkpoints in the path.
        assert find_latest_checkpoint(tmpdir) is None

        # The path does not exist.
        missing_dir = osp.join(tmpdir, 'none')
        assert find_latest_checkpoint(missing_dir) is None

    # (files created in a fresh directory, file expected to be picked)
    scenarios = (
        (('latest.pth', ), 'latest.pth'),
        (('iter_4000.pth', 'iter_8000.pth'), 'iter_8000.pth'),
        (('epoch_1.pth', 'epoch_2.pth'), 'epoch_2.pth'),
    )
    for filenames, expected in scenarios:
        with tempfile.TemporaryDirectory() as tmpdir:
            for filename in filenames:
                _make_checkpoint(tmpdir, filename)
            found = find_latest_checkpoint(tmpdir)
            assert found == osp.join(tmpdir, expected)
5 changes: 5 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def parse_args():
parser.add_argument('--work_dir', help='the dir to save logs and models')
parser.add_argument(
'--resume_from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
Expand Down Expand Up @@ -100,6 +104,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
Expand Down