Skip to content

Commit

Permalink
Support open-mmlab#6915: seperate function in tools/utils.py, support…
Browse files Browse the repository at this point in the history
… test.py and browse_dataset.py
  • Loading branch information
CCODING04 committed Mar 16, 2022
1 parent 1b9c0d1 commit 3dd0bef
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
6 changes: 6 additions & 0 deletions tools/misc/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mmdet.core.visualization import imshow_det_bboxes
from mmdet.datasets.builder import build_dataset

from tools.utils import update_data_root

def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
Expand Down Expand Up @@ -55,6 +56,11 @@ def skip_pipeline_steps(config):
]

cfg = Config.fromfile(config_path)

# update data root according to MMDET_DATASETS
if os.environ.get('MMDET_DATASETS', None) is not None:
update_data_root(cfg, os.environ['MMDET_DATASETS'])

if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
train_data_cfg = cfg.data.train
Expand Down
6 changes: 6 additions & 0 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mmdet.models import build_detector
from mmdet.utils import setup_multi_processes

from tools.utils import update_data_root

def parse_args():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -133,6 +134,11 @@ def main():
raise ValueError('The output file must be a pkl file.')

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
if os.environ.get('MMDET_DATASETS', None) is not None:
update_data_root(cfg, os.environ['MMDET_DATASETS'])

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down
12 changes: 4 additions & 8 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger, setup_multi_processes

from tools.utils import update_data_root

def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
Expand Down Expand Up @@ -99,15 +100,10 @@ def main():

cfg = Config.fromfile(args.config)

# update data root according to environment variable
# update data root according to MMDET_DATASETS
if os.environ.get('MMDET_DATASETS', None) is not None:
def update_data_root(cfg, str_o, str_n):
for k, v in cfg.items():
if isinstance(v, mmcv.ConfigDict):
update_data_root(cfg[k], str_o, str_n)
if isinstance(v, str):
cfg[k] = v.replace(str_o, str_n)
update_data_root(cfg, cfg.data_root, os.environ['MMDET_DATASETS'])
update_data_root(cfg, os.environ['MMDET_DATASETS'])


if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
Expand Down
28 changes: 28 additions & 0 deletions tools/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from mmcv.utils import print_log

def update_data_root(cfg, data_root_n, logger=None):
"""updata data root
Args:
cfg (mmcv.Config): model config
data_root_n (str): new data root
logger (logging.Logger | str | None): the way to print msg
"""
assert isinstance(cfg, mmcv.Config), \
f"cfg got wrong type: {type(cfg)}, expected mmcv.Config"

def update(cfg, str_o, str_n):
for k, v in cfg.items():
if isinstance(v, mmcv.ConfigDict):
update(cfg[k], str_o, str_n)
if isinstance(v, str) and str_o in v:
cfg[k] = v.replace(str_o, str_n)

update(cfg.data, cfg.data_root, data_root_n)
cfg.data_root = data_root_n
print_log(
f"Set data root to {data_root_n} according to MMDET_DATASETS",
logger=logger)

0 comments on commit 3dd0bef

Please sign in to comment.