diff --git a/mmpretrain/apis/model.py b/mmpretrain/apis/model.py index eba475e7f7..3f05e1074c 100644 --- a/mmpretrain/apis/model.py +++ b/mmpretrain/apis/model.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import fnmatch +import inspect import os.path as osp import re import warnings @@ -207,8 +208,12 @@ def get_model(model: Union[str, Config], dataset_meta = {'classes': checkpoint['meta']['CLASSES']} if len(dataset_meta) == 0 and 'test_dataloader' in config: - from mmpretrain.registry import DATASETS - dataset_class = DATASETS.get(config.test_dataloader.dataset.type) + # compatible with new config + if inspect.isclass(config.test_dataloader.dataset.type): + dataset_class = config.test_dataloader.dataset.type + else: + from mmpretrain.registry import DATASETS + dataset_class = DATASETS.get(config.test_dataloader.dataset.type) dataset_meta = getattr(dataset_class, 'METAINFO', {}) if device_map is not None: