diff --git a/tools/train.py b/tools/train.py index d14a3d6fba1..41ea55d3b4c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -3,6 +3,7 @@ import copy import os import os.path as osp +import string import time import warnings @@ -98,6 +99,17 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + # update data root according to environment variable + if os.environ.get('MMDET_DATASETS', None) is not None: + def update_data_root(cfg, str_o, str_n): + for k, v in cfg.items(): + if isinstance(v, mmcv.ConfigDict): + update_data_root(cfg[k], str_o, str_n) + if isinstance(v, str): + cfg[k] = v.replace(str_o, str_n) + update_data_root(cfg, cfg.data_root, os.environ['MMDET_DATASETS']) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) @@ -108,10 +120,6 @@ def main(): if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True - # update data root according to environment variable - if os.environ.get('MMDET_DATASETS', None) is not None: - cfg.data_root = os.environ['MMDET_DATASETS'] - # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None