forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] add auto resume (open-mmlab#1172)
* [Feature] add auto resume * Update mmseg/utils/find_latest_checkpoint.py Co-authored-by: Miao Zheng <[email protected]> * Update mmseg/utils/find_latest_checkpoint.py Co-authored-by: Miao Zheng <[email protected]> * modify docstring * Update mmseg/utils/find_latest_checkpoint.py Co-authored-by: Miao Zheng <[email protected]> * add copyright Co-authored-by: Miao Zheng <[email protected]>
- Loading branch information
1 parent
ae51615
commit 43ad37b
Showing
5 changed files
with
93 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .collect_env import collect_env | ||
from .logger import get_root_logger | ||
from .misc import find_latest_checkpoint | ||
|
||
__all__ = ['get_root_logger', 'collect_env'] | ||
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import glob | ||
import os.path as osp | ||
import warnings | ||
|
||
|
||
def find_latest_checkpoint(path, suffix='pth'): | ||
"""This function is for finding the latest checkpoint. | ||
It will be used when automatically resume, modified from | ||
https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py | ||
Args: | ||
path (str): The path to find checkpoints. | ||
suffix (str): File extension for the checkpoint. Defaults to pth. | ||
Returns: | ||
latest_path(str | None): File path of the latest checkpoint. | ||
""" | ||
if not osp.exists(path): | ||
warnings.warn("The path of the checkpoints doesn't exist.") | ||
return None | ||
if osp.exists(osp.join(path, f'latest.{suffix}')): | ||
return osp.join(path, f'latest.{suffix}') | ||
|
||
checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) | ||
if len(checkpoints) == 0: | ||
warnings.warn('The are no checkpoints in the path') | ||
return None | ||
latest = -1 | ||
latest_path = '' | ||
for checkpoint in checkpoints: | ||
if len(checkpoint) < len(latest_path): | ||
continue | ||
# `count` is iteration number, as checkpoints are saved as | ||
# 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number. | ||
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) | ||
if count > latest: | ||
latest = count | ||
latest_path = checkpoint | ||
return latest_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os.path as osp | ||
import tempfile | ||
|
||
from mmseg.utils import find_latest_checkpoint | ||
|
||
|
||
def test_find_latest_checkpoint(): | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
# no checkpoints in the path | ||
path = tempdir | ||
latest = find_latest_checkpoint(path) | ||
assert latest is None | ||
|
||
# The path doesn't exist | ||
path = osp.join(tempdir, 'none') | ||
latest = find_latest_checkpoint(path) | ||
assert latest is None | ||
|
||
# test when latest.pth exists | ||
with tempfile.TemporaryDirectory() as tempdir: | ||
with open(osp.join(tempdir, 'latest.pth'), 'w') as f: | ||
f.write('latest') | ||
path = tempdir | ||
latest = find_latest_checkpoint(path) | ||
assert latest == osp.join(tempdir, 'latest.pth') | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
for iter in range(1600, 160001, 1600): | ||
with open(osp.join(tempdir, f'iter_{iter}.pth'), 'w') as f: | ||
f.write(f'iter_{iter}.pth') | ||
latest = find_latest_checkpoint(tempdir) | ||
assert latest == osp.join(tempdir, 'iter_160000.pth') | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
for epoch in range(1, 21): | ||
with open(osp.join(tempdir, f'epoch_{epoch}.pth'), 'w') as f: | ||
f.write(f'epoch_{epoch}.pth') | ||
latest = find_latest_checkpoint(tempdir) | ||
assert latest == osp.join(tempdir, 'epoch_20.pth') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters