From 09a6cd8b7a401fa9c74507141b5a3ae41c05a939 Mon Sep 17 00:00:00 2001
From: Yining Li
Date: Tue, 22 Feb 2022 00:28:42 +0800
Subject: [PATCH] [Feature] Add base transform interface (#1538)

* Support deepcopy for Config (#1658)
* Support deepcopy for Config
* Iterate the `__dict__` of Config directly.
* Use __new__ to avoid unnecessary initialization.
* Improve according to comments
* [Feature] Add spconv ops from mmdet3d (#1581)
* add ops (spconv) of mmdet3d
* fix typo
* refactor code
* resolve comments in #1452
* fix compile error
* fix bugs
* fix bug
* transform from 'types.h' to 'extension.h'
* fix bug
* transform from 'types.h' to 'extension.h' in parrots
* add extension.h in pybind.cpp
* add unittest
* Recover code
* (1) Remove prettyprint.h (2) Switch `T` to `scalar_t` (3) Remove useless lines (4) Refine example in docstring of sparse_modules.py
* (1) rename from `cu.h` to `cuh` (2) remove useless files (3) move cpu files to `pytorch/cpu`
* reorganize files
* Add docstring for sparse_functional.py
* use dispatcher
* remove template
* use dispatch in cuda ops
* resolve Segmentation fault
* remove useless files
* fix lint
* fix lint
* fix lint
* fix unittest in test_build_layers.py
* add tensorview into include_dirs when compiling
* recover all deleted files
* fix lint and comments
* recover setup.py
* replace tv::GPU as tv::TorchGPU & support device guard
* fix lint

Co-authored-by: hdc
Co-authored-by: grimoire

* Improve the docstring of imfrombytes and fix a deprecation-warning (#1731)
* [Refactor] Refactor the interface for RoIAlignRotated (#1662)
* fix interface for RoIAlignRotated
* Add a unit test for RoIAlignRotated
* Make a unit test for RoIAlignRotated concise
* fix interface for RoIAlignRotated
* Refactor ext_module.nms_rotated
* Lint cpp files
* add transforms
* add invoking time check for cacheable methods
* fix lint
* add unittest
* fix bug in non-strict input mapping
* fix ci
* fix ci
* fix compatibility with python<3.9
* fix typing compatibility
* fix import
* fix typing
* add alternative for nullcontext
* fix import
* fix import
* add docstrings
* add docstrings
* fix callable check
* resolve comments
* fix lint
* enrich unittest cases
* fix lint
* fix unittest

Co-authored-by: Ma Zerun
Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com>
Co-authored-by: hdc
Co-authored-by: grimoire
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Co-authored-by: Hakjin Lee
---
 docs/en/api.rst                              |   5 +
 docs/zh_cn/api.rst                           |   5 +
 mmcv/__init__.py                             |   1 +
 mmcv/transform/__init__.py                   |   5 +
 mmcv/transform/base.py                       |  27 ++
 mmcv/transform/builder.py                    |   4 +
 mmcv/transform/utils.py                      | 162 +++++++
 mmcv/transform/wrappers.py                   | 457 ++++++++++++++++++
 .../test_transform/test_transform_wrapper.py | 370 ++++++++++++++
 9 files changed, 1036 insertions(+)
 create mode 100644 mmcv/transform/__init__.py
 create mode 100644 mmcv/transform/base.py
 create mode 100644 mmcv/transform/builder.py
 create mode 100644 mmcv/transform/utils.py
 create mode 100644 mmcv/transform/wrappers.py
 create mode 100644 tests/test_transform/test_transform_wrapper.py

diff --git a/docs/en/api.rst b/docs/en/api.rst
index 5d3e623037..747aa659aa 100644
--- a/docs/en/api.rst
+++ b/docs/en/api.rst
@@ -47,3 +47,8 @@ ops
 ------
 .. automodule:: mmcv.ops
     :members:
+
+transform
+---------
+.. automodule:: mmcv.transform
+    :members:
diff --git a/docs/zh_cn/api.rst b/docs/zh_cn/api.rst
index 5d3e623037..747aa659aa 100644
--- a/docs/zh_cn/api.rst
+++ b/docs/zh_cn/api.rst
@@ -47,3 +47,8 @@ ops
 ------
 .. automodule:: mmcv.ops
     :members:
+
+transform
+---------
+.. automodule:: mmcv.transform
+    :members:
diff --git a/mmcv/__init__.py b/mmcv/__init__.py
index 14c556acdf..731d4cda17 100644
--- a/mmcv/__init__.py
+++ b/mmcv/__init__.py
@@ -3,6 +3,7 @@
 from .arraymisc import *
 from .fileio import *
 from .image import *
+from .transform import *
 from .utils import *
 from .version import *
 from .video import *
diff --git a/mmcv/transform/__init__.py b/mmcv/transform/__init__.py
new file mode 100644
index 0000000000..feb8e38d29
--- /dev/null
+++ b/mmcv/transform/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import TRANSFORMS
+from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
+
+__all__ = ['TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap']
diff --git a/mmcv/transform/base.py b/mmcv/transform/base.py
new file mode 100644
index 0000000000..67a0ab55c1
--- /dev/null
+++ b/mmcv/transform/base.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Dict
+
+
+class BaseTransform(metaclass=ABCMeta):
+
+    def __call__(self, results: Dict) -> Dict:
+
+        return self.transform(results)
+
+    @abstractmethod
+    def transform(self, results: Dict) -> Dict:
+        """The transform function. All subclasses of BaseTransform should
+        override this method.
+
+        This function takes the result dict as the input, and can add new
+        items to the dict or modify existing items in the dict. The result
+        dict will be returned in the end, which allows concatenating multiple
+        transforms into a pipeline.
+
+        Args:
+            results (dict): The result dict.
+
+        Returns:
+            dict: The result dict.
+        """
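
For orientation, here is a minimal sketch of what a concrete transform built on this interface looks like. The `AddOne` class and the 'value' field are illustrative only and are not part of this patch:

from mmcv.transform import TRANSFORMS
from mmcv.transform.base import BaseTransform


@TRANSFORMS.register_module()
class AddOne(BaseTransform):
    """Hypothetical transform that increments results['value'] by one."""

    def transform(self, results: dict) -> dict:
        results['value'] = results['value'] + 1
        return results


# BaseTransform.__call__ delegates to transform(), so instances are callable
# and can be chained in a pipeline:
assert AddOne()(dict(value=1))['value'] == 2
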
+ """ + + def __init__(self, func): + + # Check `func` is to be bound as an instance method + if not inspect.isfunction(func): + raise TypeError('Unsupport callable to decorate with' + '@cacheable_method.') + func_args = inspect.getfullargspec(func).args + if len(func_args) == 0 or func_args[0] != 'self': + raise TypeError( + '@cacheable_method should only be used to decorate ' + 'instance methods (the first argument is `self`).') + + functools.update_wrapper(self, func) + self.func = func + self.instance_ref = None + + def __set_name__(self, owner, name): + # Maintain a record of decorated methods in the class + if not hasattr(owner, '_cacheable_methods'): + setattr(owner, '_cacheable_methods', []) + owner._cacheable_methods.append(self.__name__) + + def __call__(self, *args, **kwargs): + # Get the transform instance whose method is decorated + # by cacheable_method + instance = self.instance_ref() + name = self.__name__ + + # Check the flag `self._cache_enabled`, which should be + # set by the contextmanagers like `cache_random_parameters` + cache_enabled = getattr(instance, '_cache_enabled', False) + + if cache_enabled: + # Initialize the cache of the transform instances. The flag + # `cache_enabled` is set by contextmanagers like + # `cache_random_params`. + if not hasattr(instance, '_cache'): + setattr(instance, '_cache', {}) + + if name not in instance._cache: + instance._cache[name] = self.func(instance, *args, **kwargs) + # Return the cached value + return instance._cache[name] + else: + # Clear cache + if hasattr(instance, '_cache'): + del instance._cache + # Return function output + return self.func(instance, *args, **kwargs) + + def __get__(self, obj, cls): + self.instance_ref = weakref.ref(obj) + return self + + +@contextmanager +def cache_random_params(transforms: Union[BaseTransform, Iterable]): + """Context-manager that enables the cache of cacheable methods in + transforms. + + In this mode, cacheable methods will cache their return values on the + first invoking, and always return the cached value afterward. This allow + to apply random transforms in a deterministic way. For example, apply same + transforms on multiple examples. See `cacheable_method` for more + information. + + Args: + transforms (BaseTransform|list[BaseTransform]): The transforms to + enable cache. + """ + + # key2method stores the original methods that are replaced by the wrapped + # ones. These methods will be restituted when exiting the context. + key2method = dict() + + # key2counter stores the usage number of each cacheable_method. This is + # used to check that any cacheable_method is invoked once during processing + # on data sample. 
diff --git a/mmcv/transform/wrappers.py b/mmcv/transform/wrappers.py
new file mode 100644
index 0000000000..8ff1af7331
--- /dev/null
+++ b/mmcv/transform/wrappers.py
@@ -0,0 +1,457 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from collections.abc import Sequence
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+
+import mmcv
+from .base import BaseTransform
+from .builder import TRANSFORMS
+from .utils import cache_random_params
+
+# Indicator for required but missing keys in results
+NotInResults = object()
+
+# Import nullcontext if python>=3.7, otherwise use a simple alternative
+# implementation.
+try:
+    from contextlib import nullcontext
+except ImportError:
+    from contextlib import contextmanager
+
+    @contextmanager
+    def nullcontext(resource=None):
+        try:
+            yield resource
+        finally:
+            pass
+
+
+class Compose(BaseTransform):
+    """Compose multiple transforms sequentially.
+
+    Args:
+        transforms (list[dict | callable]): Sequence of transform objects or
+            config dicts to be composed.
+
+    Examples:
+        >>> pipeline = [
+        >>>     dict(type='Compose',
+        >>>         transforms=[
+        >>>             dict(type='LoadImageFromFile'),
+        >>>             dict(type='Normalize')
+        >>>         ]
+        >>>     )
+        >>> ]
+    """
+
+    def __init__(self, transforms: List[Union[Dict, Callable[[Dict], Dict]]]):
+        assert isinstance(transforms, Sequence)
+        self.transforms = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                transform = TRANSFORMS.build(transform)
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError('transform must be callable or a dict, but '
+                                f'got {type(transform)}')
+
+    def __iter__(self):
+        """Allow easy iteration over the transform sequence."""
+        return iter(self.transforms)
+
+    def transform(self, results: Dict) -> Optional[Dict]:
+        """Apply the wrapped transforms sequentially.
+
+        Args:
+            results (dict): The result dict to transform.
+
+        Returns:
+            dict or None: The transformed results, or None if any wrapped
+                transform returns None.
+        """
+        for t in self.transforms:
+            results = t(results)
+            if results is None:
+                return None
+        return results
+
+    def __repr__(self):
+        """Compute the string representation."""
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += f'\n    {t}'
+        format_string += '\n)'
+        return format_string
+
+
+@TRANSFORMS.register_module()
+class Remap(BaseTransform):
+    """A transform wrapper to remap and reorganize the input/output of the
+    wrapped transforms (or sub-pipeline).
+
+    Args:
+        transforms (list[dict | callable]): Sequence of transform objects or
+            config dicts to be wrapped.
+        input_mapping (dict): A dict that defines the input key mapping.
+            The keys correspond to the inner keys (i.e., kwargs of the
+            `transform` method), and should be of string type. The values
+            correspond to the outer keys (i.e., the keys of the
+            data/results), and should be of string, list or dict type.
+            None means no input mapping is applied. Default: None.
+        output_mapping (dict): A dict that defines the output key mapping.
+            The keys and values have the same meanings and rules as in the
+            `input_mapping`. Default: None.
+        inplace (bool): If True, the inverse of the input_mapping will be
+            used as the output_mapping. Note that if inplace is set True,
+            output_mapping should be None and strict should be True.
+            Default: False.
+        strict (bool): If True, the outer keys in the input_mapping must
+            exist in the input data, or an exception will be raised. If
+            False, missing keys will be assigned the special value
+            `NotInResults` during input remapping. Default: True.
+
+    Examples:
+        >>> # Example 1: Remap 'gt_img' to 'img'
+        >>> pipeline = [
+        >>>     # Use Remap to convert the outer (original) field name
+        >>>     # 'gt_img' to the inner (used by wrapped transforms) field
+        >>>     # name 'img'
+        >>>     dict(type='Remap',
+        >>>         input_mapping=dict(img='gt_img'),
+        >>>         # inplace=True means the output key mapping is the
+        >>>         # inverse of the input key mapping, e.g. the inner 'img'
+        >>>         # will be mapped back to the outer 'gt_img'
+        >>>         inplace=True,
+        >>>         transforms=[
+        >>>             # In all transforms' implementations just use 'img'
+        >>>             # as a standard field name
+        >>>             dict(type='Crop', crop_size=(384, 384)),
+        >>>             dict(type='Normalize'),
+        >>>         ])
+        >>> ]
+        >>> # Example 2: Collect and structure multiple items
+        >>> pipeline = [
+        >>>     # The inner field 'imgs' will be a dict with keys 'img_src'
+        >>>     # and 'img_tar', whose values are the outer fields 'img1'
+        >>>     # and 'img2' respectively.
+        >>>     dict(type='Remap',
+        >>>         input_mapping=dict(
+        >>>             imgs=dict(
+        >>>                 img_src='img1',
+        >>>                 img_tar='img2')),
+        >>>         transforms=...)
+        >>> ]
+    """
+
+    def __init__(self,
+                 transforms: List[Union[Dict, Callable[[Dict], Dict]]],
+                 input_mapping: Optional[Dict] = None,
+                 output_mapping: Optional[Dict] = None,
+                 inplace: bool = False,
+                 strict: bool = True):
+
+        self.inplace = inplace
+        self.strict = strict
+        self.input_mapping = input_mapping
+
+        if self.inplace:
+            if not self.strict:
+                raise ValueError('Remap: `strict` must be set True if '
+                                 '`inplace` is set True.')
+
+            if output_mapping is not None:
+                raise ValueError('Remap: `output_mapping` must be None if '
+                                 '`inplace` is set True.')
+            self.output_mapping = input_mapping
+        else:
+            self.output_mapping = output_mapping
+
+        self.transforms = Compose(transforms)
+
+    def __iter__(self):
+        """Allow easy iteration over the transform sequence."""
+        return iter(self.transforms)
+
+    def remap_input(self, data: Dict, input_mapping: Dict) -> Dict[str, Any]:
+        """Remap inputs for the wrapped transforms by gathering and renaming
+        data items according to the input_mapping.
+
+        Args:
+            data (dict): The original input data.
+            input_mapping (dict): The input key mapping. See the document of
+                `mmcv.transform.wrappers.Remap` for details.
+
+        Returns:
+            dict: The input data with remapped keys. This will be the actual
+                input of the wrapped pipeline.
+        """
+
+        def _remap(data, m):
+            if isinstance(m, dict):
+                # m is a dict {inner_key: outer_key, ...}
+                return {k_in: _remap(data, k_out) for k_in, k_out in m.items()}
+            if isinstance(m, (tuple, list)):
+                # m is a list or tuple [outer_key1, outer_key2, ...]
+                # This is the case when we collect items from the original
+                # data to form a list or tuple to feed to the wrapped
+                # transforms.
+                return m.__class__(_remap(data, e) for e in m)
+
+            # m is an outer_key
+            if self.strict:
+                # In strict mode, a missing outer key raises a KeyError
+                return data[m]
+            else:
+                return data.get(m, NotInResults)
+
+        collected = _remap(data, input_mapping)
+        collected = {
+            k: v
+            for k, v in collected.items() if v is not NotInResults
+        }
+
+        # Retain unmapped items
+        inputs = data.copy()
+        inputs.update(collected)
+
+        return inputs
+
+    def remap_output(self, data: Dict, output_mapping: Dict) -> Dict[str, Any]:
+        """Remap outputs from the wrapped transforms by gathering and
+        renaming data items according to the output_mapping.
+
+        Args:
+            data (dict): The output of the wrapped pipeline.
+            output_mapping (dict): The output key mapping. See the document
+                of `mmcv.transform.wrappers.Remap` for details.
+
+        Returns:
+            dict: The output with remapped keys.
+        """
+
+        def _remap(data, m):
+            if isinstance(m, dict):
+                assert isinstance(data, dict)
+                results = {}
+                for k_in, k_out in m.items():
+                    assert k_in in data
+                    results.update(_remap(data[k_in], k_out))
+                return results
+            if isinstance(m, (list, tuple)):
+                assert isinstance(data, (list, tuple))
+                assert len(data) == len(m)
+                results = {}
+                for m_i, d_i in zip(m, data):
+                    results.update(_remap(d_i, m_i))
+                return results
+
+            return {m: data}
+
+        # Note that unmapped items are not retained, which is different from
+        # the behavior in remap_input. This is to avoid original data items
+        # being overwritten by intermediate namesakes.
+        return _remap(data, output_mapping)
+
+    def transform(self, results: Dict) -> Dict:
+
+        # Apply input remapping if an input_mapping is given; otherwise
+        # feed the results to the wrapped transforms directly.
+        if self.input_mapping:
+            inputs = self.remap_input(results, self.input_mapping)
+        else:
+            inputs = results
+        outputs = self.transforms(inputs)
+
+        if self.output_mapping:
+            outputs = self.remap_output(outputs, self.output_mapping)
+
+        results.update(outputs)
+        return results
+
+
+@TRANSFORMS.register_module()
+class ApplyToMultiple(Remap):
+    """A transform wrapper to apply the wrapped transforms to multiple data
+    items. For example, apply Resize to multiple images.
+
+    Args:
+        transforms (list[dict | callable]): Sequence of transform objects or
+            config dicts to be wrapped.
+        input_mapping (dict): A dict that defines the input key mapping.
+            Note that to apply the transforms to multiple data items, the
+            outer keys of the target items should be remapped as a list with
+            the standard inner key (the key required by the wrapped
+            transforms). See the following examples and the document of
+            `mmcv.transform.wrappers.Remap` for details.
+        output_mapping (dict): A dict that defines the output key mapping.
+            The keys and values have the same meanings and rules as in the
+            `input_mapping`. Default: None.
+        inplace (bool): If True, the inverse of the input_mapping will be
+            used as the output_mapping. Note that if inplace is set True,
+            output_mapping should be None and strict should be True.
+            Default: False.
+        strict (bool): If True, the outer keys in the input_mapping must
+            exist in the input data, or an exception will be raised. If
+            False, missing keys will be assigned the special value
+            `NotInResults` during input remapping. Default: True.
+        share_random_params (bool): If True, random transforms
+            (e.g., RandomFlip) will be conducted in a deterministic way and
+            have the same behavior on all data items. For example, either
+            flip both the input image and the ground-truth image, or flip
+            neither. Default: False.
+
+    .. note::
+        To apply the transforms to each element of a list or tuple, instead
+        of to separate data items, you can remap the outer key of the target
+        sequence to the standard inner key. See Example 2.
+
+    Examples:
+        >>> # Example 1:
+        >>> pipeline = [
+        >>>     dict(type='LoadImageFromFile', key='lq'),  # low-quality img
+        >>>     dict(type='LoadImageFromFile', key='gt'),  # ground-truth img
+        >>>     # ApplyToMultiple maps multiple outer fields to the standard
+        >>>     # inner field and processes them with the wrapped transforms
+        >>>     # respectively
+        >>>     dict(type='ApplyToMultiple',
+        >>>         # case 1: from multiple outer fields
+        >>>         input_mapping=dict(img=['lq', 'gt']),
+        >>>         inplace=True,
+        >>>         # share_random_params=True means using identical random
+        >>>         # parameters in every processing
+        >>>         share_random_params=True,
+        >>>         transforms=[
+        >>>             dict(type='Crop', crop_size=(384, 384)),
+        >>>             dict(type='Normalize'),
+        >>>         ])
+        >>> ]
+        >>> # Example 2:
+        >>> pipeline = [
+        >>>     dict(type='LoadImageFromFile', key='lq'),  # low-quality img
+        >>>     dict(type='LoadImageFromFile', key='gt'),  # ground-truth img
+        >>>     # ApplyToMultiple maps multiple outer fields to the standard
+        >>>     # inner field and processes them with the wrapped transforms
+        >>>     # respectively
+        >>>     dict(type='ApplyToMultiple',
+        >>>         # case 2: from one outer field that contains multiple
+        >>>         # data elements (e.g. a list)
+        >>>         input_mapping=dict(img='images'),
+        >>>         inplace=True,
+        >>>         share_random_params=True,
+        >>>         transforms=[
+        >>>             dict(type='Crop', crop_size=(384, 384)),
+        >>>             dict(type='Normalize'),
+        >>>         ])
+        >>> ]
+    """
+
+    def __init__(self,
+                 transforms: List[Union[Dict, Callable[[Dict], Dict]]],
+                 input_mapping: Optional[Dict] = None,
+                 output_mapping: Optional[Dict] = None,
+                 inplace: bool = False,
+                 strict: bool = True,
+                 share_random_params: bool = False):
+        super().__init__(transforms, input_mapping, output_mapping, inplace,
+                         strict)
+
+        self.share_random_params = share_random_params
+
+    def scatter_sequence(self, data: Dict) -> List[Dict]:
+        # Infer the number of scatters from the input
+        seq_len = None
+        key_rep = None
+        for key in self.input_mapping:
+
+            assert isinstance(data[key], Sequence)
+            if seq_len is not None:
+                if len(data[key]) != seq_len:
+                    raise ValueError('Got inconsistent sequence length: '
+                                     f'{seq_len} ({key_rep}) vs. '
+                                     f'{len(data[key])} ({key})')
+            else:
+                seq_len = len(data[key])
+                key_rep = key
+
+        scatters = []
+        for i in range(seq_len):
+            scatter = data.copy()
+            for key in self.input_mapping:
+                scatter[key] = data[key][i]
+            scatters.append(scatter)
+        return scatters
+
+    def transform(self, results: Dict):
+        # Apply input remapping
+        inputs = self.remap_input(results, self.input_mapping)
+
+        # Scatter sequential inputs into a list
+        inputs = self.scatter_sequence(inputs)
+
+        # Control random parameter sharing with a context manager
+        if self.share_random_params:
+            # The context manager :func:`cache_random_params` will let
+            # cacheable methods of the transforms cache their outputs. Thus
+            # the random parameters will only be generated once and shared
+            # by all data items.
+            ctx = cache_random_params
+        else:
+            ctx = nullcontext
+
+        with ctx(self.transforms):
+            outputs = [self.transforms(_input) for _input in inputs]
+
+        # Collate output scatters (list of dicts to dict of lists)
+        outputs = {
+            key: [_output[key] for _output in outputs]
+            for key in outputs[0]
+        }
+
+        # Apply output remapping
+        if self.output_mapping:
+            outputs = self.remap_output(outputs, self.output_mapping)
+
+        results.update(outputs)
+        return results
+
+
+@TRANSFORMS.register_module()
+class RandomChoice(BaseTransform):
+    """Process data with a randomly chosen pipeline from given candidates.
+
+    Args:
+        pipelines (list[list]): A list of pipeline candidates, each is a
+            sequence of transforms.
+        pipeline_probs (list[float], optional): The probabilities associated
+            with each pipeline. The length should be equal to the pipeline
+            number and the probabilities should sum to 1. If not given, a
+            uniform distribution will be assumed.
+
+    Examples:
+        >>> # config
+        >>> pipeline = [
+        >>>     dict(type='RandomChoice',
+        >>>         pipelines=[
+        >>>             [dict(type='RandomHorizontalFlip')],  # subpipeline 1
+        >>>             [dict(type='RandomRotate')],  # subpipeline 2
+        >>>         ]
+        >>>     )
+        >>> ]
+    """
+
+    def __init__(self,
+                 pipelines: List[List[Union[Dict, Callable[[Dict], Dict]]]],
+                 pipeline_probs: Optional[List[float]] = None):
+
+        if pipeline_probs is not None:
+            assert mmcv.is_seq_of(pipeline_probs, float)
+            assert len(pipelines) == len(pipeline_probs), \
+                '`pipelines` and `pipeline_probs` must have the same ' \
+                f'length. Got {len(pipelines)} vs {len(pipeline_probs)}.'
+            assert sum(pipeline_probs) == 1
+
+        self.pipeline_probs = pipeline_probs
+        self.pipelines = [Compose(transforms) for transforms in pipelines]
+
+    def transform(self, results):
+        pipeline = np.random.choice(self.pipelines, p=self.pipeline_probs)
+        return pipeline(results)
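
Since the config-style docstring examples above can be hard to picture end to end, here is a small runnable sketch of Compose, using the hypothetical `AddOne` transform registered in the earlier sketch (not part of this patch):

from mmcv.transform import Compose

# Compose accepts both registry config dicts and bare callables.
pipeline = Compose([
    dict(type='AddOne'),       # built via the TRANSFORMS registry
    lambda results: results,   # any dict -> dict callable also works
])
assert pipeline({'value': 0})['value'] == 1
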
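Likewise, a worked sketch of Remap's input/output key mapping with the hypothetical `AddOne`; it mirrors Case 1 of test_remap below:

from mmcv.transform import Remap

pipeline = Remap(
    transforms=[dict(type='AddOne')],
    input_mapping=dict(value='v_in'),    # inner 'value' reads outer 'v_in'
    output_mapping=dict(value='v_out'))  # inner 'value' writes outer 'v_out'

results = pipeline(dict(v_in=1))
# The wrapped transform saw results['value'] == 1 and incremented it; the
# result is written back under 'v_out' while 'v_in' stays untouched.
assert results['v_in'] == 1 and results['v_out'] == 2
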
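And a sketch of ApplyToMultiple with shared random parameters, reusing the hypothetical `RandomAdd` transform from the utils.py sketch; it mirrors Case 5 of test_apply_to_multiple below:

from mmcv.transform import ApplyToMultiple

pipeline = ApplyToMultiple(
    transforms=[RandomAdd()],
    input_mapping=dict(value='values'),  # each element of 'values' is fed
                                         # to the inner 'value' field
    inplace=True,
    share_random_params=True)

results = pipeline(dict(values=[0.0, 0.0]))
# share_random_params=True caches the random addend, so every element
# receives the same value.
assert results['values'][0] == results['values'][1]
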
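Finally, a sketch of RandomChoice; with pipeline_probs=[1.0, 0.0] the first sub-pipeline is always chosen, which makes the outcome deterministic (again using the hypothetical `AddOne`):

from mmcv.transform import RandomChoice

pipeline = RandomChoice(
    pipelines=[[dict(type='AddOne')],  # chosen with probability 1.0
               [lambda r: r]],         # never chosen here
    pipeline_probs=[1.0, 0.0])

assert pipeline(dict(value=0))['value'] == 1
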
diff --git a/tests/test_transform/test_transform_wrapper.py b/tests/test_transform/test_transform_wrapper.py
new file mode 100644
index 0000000000..31bfcb4c42
--- /dev/null
+++ b/tests/test_transform/test_transform_wrapper.py
@@ -0,0 +1,370 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import pytest
+
+from mmcv.transform.base import BaseTransform
+from mmcv.transform.builder import TRANSFORMS
+from mmcv.transform.utils import cache_random_params, cacheable_method
+from mmcv.transform.wrappers import (ApplyToMultiple, Compose, RandomChoice,
+                                     Remap)
+
+
+@TRANSFORMS.register_module()
+class AddToValue(BaseTransform):
+    """Dummy transform to test transform wrappers."""
+
+    def __init__(self, constant_addend=0, use_random_addend=False) -> None:
+        super().__init__()
+        self.constant_addend = constant_addend
+        self.use_random_addend = use_random_addend
+
+    @cacheable_method
+    def get_random_addend(self):
+        return np.random.rand()
+
+    def transform(self, results):
+        augend = results['value']
+
+        if isinstance(augend, list):
+            warnings.warn('value is a list', UserWarning)
+        if isinstance(augend, dict):
+            warnings.warn('value is a dict', UserWarning)
+
+        def _add_to_value(augend, addend):
+            if isinstance(augend, list):
+                return [_add_to_value(v, addend) for v in augend]
+            if isinstance(augend, dict):
+                return {k: _add_to_value(v, addend) for k, v in augend.items()}
+            return augend + addend
+
+        if self.use_random_addend:
+            addend = self.get_random_addend()
+        else:
+            addend = self.constant_addend
+
+        results['value'] = _add_to_value(results['value'], addend)
+        return results
+
+
+@TRANSFORMS.register_module()
+class SumTwoValues(BaseTransform):
+    """Dummy transform to test transform wrappers."""
+
+    def transform(self, results):
+        if 'num_1' in results and 'num_2' in results:
+            results['sum'] = results['num_1'] + results['num_2']
+        else:
+            results['sum'] = np.nan
+        return results
+
+
+def test_compose():
+
+    # Case 1: build from cfg
+    pipeline = [dict(type='AddToValue')]
+    pipeline = Compose(pipeline)
+    _ = str(pipeline)
+
+    # Case 2: build from transform list
+    pipeline = [AddToValue()]
+    pipeline = Compose(pipeline)
+
+    # Case 3: invalid build arguments
+    pipeline = [[dict(type='AddToValue')]]
+    with pytest.raises(TypeError):
+        pipeline = Compose(pipeline)
+
+    # Case 4: contains a transform with None output
+    class DummyTransform(BaseTransform):
+
+        def transform(self, results):
+            return None
+
+    pipeline = Compose([DummyTransform()])
+    results = pipeline({})
+    assert results is None
+
+
+def test_cache_random_parameters():
+
+    transform = AddToValue(use_random_addend=True)
+
+    # Case 1: cache random parameters
+    assert hasattr(AddToValue, '_cacheable_methods')
+    assert 'get_random_addend' in AddToValue._cacheable_methods
+
+    with cache_random_params(transform):
+        results_1 = transform(dict(value=0))
+        results_2 = transform(dict(value=0))
+    np.testing.assert_equal(results_1['value'], results_2['value'])
+
+    # Case 2: do not cache random parameters
+    results_1 = transform(dict(value=0))
+    results_2 = transform(dict(value=0))
+    with pytest.raises(AssertionError):
+        np.testing.assert_equal(results_1['value'], results_2['value'])
+
+    # Case 3: invalid use of cacheable methods
+    with pytest.raises(RuntimeError):
+        with cache_random_params(transform):
+            _ = transform.get_random_addend()
+
+    # Case 4: apply on nested transforms
+    transform = Compose([AddToValue(use_random_addend=True)])
+    with cache_random_params(transform):
+        results_1 = transform(dict(value=0))
+        results_2 = transform(dict(value=0))
+    np.testing.assert_equal(results_1['value'], results_2['value'])
+
+
+def test_remap():
+
+    # Case 1: simple remap
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=1)],
+        input_mapping=dict(value='v_in'),
+        output_mapping=dict(value='v_out'))
+
+    results = dict(value=0, v_in=1)
+    results = pipeline(results)
+
+    np.testing.assert_equal(results['value'], 0)  # should be unchanged
+    np.testing.assert_equal(results['v_in'], 1)
+    np.testing.assert_equal(results['v_out'], 2)
+
+    # Case 2: collecting list
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=2)],
+        input_mapping=dict(value=['v_in_1', 'v_in_2']),
+        output_mapping=dict(value=['v_out_1', 'v_out_2']))
+    results = dict(value=0, v_in_1=1, v_in_2=2)
+
+    with pytest.warns(UserWarning, match='value is a list'):
+        results = pipeline(results)
+
+    np.testing.assert_equal(results['value'], 0)  # should be unchanged
+    np.testing.assert_equal(results['v_in_1'], 1)
+    np.testing.assert_equal(results['v_in_2'], 2)
+    np.testing.assert_equal(results['v_out_1'], 3)
+    np.testing.assert_equal(results['v_out_2'], 4)
+
+    # Case 3: collecting dict
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=2)],
+        input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
+        output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
+    results = dict(value=0, v_in_1=1, v_in_2=2)
+
+    with pytest.warns(UserWarning, match='value is a dict'):
+        results = pipeline(results)
+
+    np.testing.assert_equal(results['value'], 0)  # should be unchanged
+    np.testing.assert_equal(results['v_in_1'], 1)
+    np.testing.assert_equal(results['v_in_2'], 2)
+    np.testing.assert_equal(results['v_out_1'], 3)
+    np.testing.assert_equal(results['v_out_2'], 4)
+
+    # Case 4: collecting list with inplace mode
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=2)],
+        input_mapping=dict(value=['v_in_1', 'v_in_2']),
+        inplace=True)
+    results = dict(value=0, v_in_1=1, v_in_2=2)
+
+    with pytest.warns(UserWarning, match='value is a list'):
+        results = pipeline(results)
+
+    np.testing.assert_equal(results['value'], 0)
+    np.testing.assert_equal(results['v_in_1'], 3)
+    np.testing.assert_equal(results['v_in_2'], 4)
+
+    # Case 5: collecting dict with inplace mode
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=2)],
+        input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
+        inplace=True)
+    results = dict(value=0, v_in_1=1, v_in_2=2)
+
+    with pytest.warns(UserWarning, match='value is a dict'):
+        results = pipeline(results)
+
+    np.testing.assert_equal(results['value'], 0)
+    np.testing.assert_equal(results['v_in_1'], 3)
+    np.testing.assert_equal(results['v_in_2'], 4)
+
+    # Case 6: nested collection with inplace mode
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=2)],
+        input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
+        inplace=True)
+    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
+
+    with pytest.warns(UserWarning, match='value is a list'):
+        results = pipeline(results)
+
+    np.testing.assert_equal(results['value'], 0)
+    np.testing.assert_equal(results['v1'], 3)
+    np.testing.assert_equal(results['v21'], 4)
+    np.testing.assert_equal(results['v22'], 5)
+    np.testing.assert_equal(results['v3'], 6)
+
+    # Case 7: `strict` must be True if `inplace` is set True
+    with pytest.raises(ValueError):
+        pipeline = Remap(
+            transforms=[AddToValue(constant_addend=2)],
+            input_mapping=dict(value=['v_in_1', 'v_in_2']),
+            inplace=True,
+            strict=False)
+
+    # Case 8: `output_mapping` must be None if `inplace` is set True
+    with pytest.raises(ValueError):
+        pipeline = Remap(
+            transforms=[AddToValue(constant_addend=1)],
+            input_mapping=dict(value='v_in'),
+            output_mapping=dict(value='v_out'),
+            inplace=True)
+
+    # Case 9: non-strict input mapping
+    pipeline = Remap(
+        transforms=[SumTwoValues()],
+        input_mapping=dict(num_1='a', num_2='b'),
+        strict=False)
+
+    results = pipeline(dict(a=1, b=2))
+    np.testing.assert_equal(results['sum'], 3)
+
+    results = pipeline(dict(a=1))
+    assert np.isnan(results['sum'])
+
+    # Test basic functions
+    pipeline = Remap(
+        transforms=[AddToValue(constant_addend=1)],
+        input_mapping=dict(value='v_in'),
+        output_mapping=dict(value='v_out'))
+
+    # __iter__
+    for _ in pipeline:
+        pass
+
+    # __repr__
+    _ = str(pipeline)
+
+
+def test_apply_to_multiple():
+
+    # Case 1: apply to a list in results
+    pipeline = ApplyToMultiple(
+        transforms=[AddToValue(constant_addend=1)],
+        input_mapping=dict(value='values'),
+        inplace=True)
+    results = dict(values=[1, 2])
+
+    results = pipeline(results)
+
+    np.testing.assert_equal(results['values'], [2, 3])
+
+    # Case 2: apply to multiple keys
+    pipeline = ApplyToMultiple(
+        transforms=[AddToValue(constant_addend=1)],
+        input_mapping=dict(value=['v_1', 'v_2']),
+        inplace=True)
+    results = dict(v_1=1, v_2=2)
+
+    results = pipeline(results)
+
+    np.testing.assert_equal(results['v_1'], 2)
+    np.testing.assert_equal(results['v_2'], 3)
+
+    # Case 3: apply to multiple groups of keys
+    pipeline = ApplyToMultiple(
+        transforms=[SumTwoValues()],
+        input_mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
+        output_mapping=dict(sum=['a', 'b']))
+
+    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
+    results = pipeline(results)
+
+    np.testing.assert_equal(results['a'], 3)
+    np.testing.assert_equal(results['b'], 7)
+
+    # Case 4: inconsistent sequence length
+    with pytest.raises(ValueError):
+        pipeline = ApplyToMultiple(
+            transforms=[SumTwoValues()],
+            input_mapping=dict(num_1='list_1', num_2='list_2'))
+
+        results = dict(list_1=[1, 2], list_2=[1, 2, 3])
+        _ = pipeline(results)
+
+    # Case 5: share random parameters
+    pipeline = ApplyToMultiple(
+        transforms=[AddToValue(use_random_addend=True)],
+        input_mapping=dict(value='values'),
+        inplace=True,
+        share_random_params=True,
+    )
+
+    results = dict(values=[0, 0])
+    results = pipeline(results)
+
+    np.testing.assert_equal(results['values'][0], results['values'][1])
+
+    # Test repr
+    _ = str(pipeline)
+
+
+def test_randomchoice():
+
+    # Case 1: given probability
+    pipeline = RandomChoice(
+        pipelines=[[AddToValue(constant_addend=1.0)],
+                   [AddToValue(constant_addend=2.0)]],
+        pipeline_probs=[1.0, 0.0])
+
+    results = pipeline(dict(value=1))
+    np.testing.assert_equal(results['value'], 2.0)
+
+    # Case 2: default probability
+    pipeline = RandomChoice(pipelines=[[AddToValue(
+        constant_addend=1.0)], [AddToValue(constant_addend=2.0)]])
+
+    _ = pipeline(dict(value=1))
+
+
+def test_utils():
+    # Test cacheable_method: normal case
+    class DummyTransform(BaseTransform):
+
+        @cacheable_method
+        def func(self):
+            return np.random.rand()
+
+        def transform(self, results):
+            _ = self.func()
+            return results
+
+    transform = DummyTransform()
+    _ = transform({})
+    with cache_random_params(transform):
+        _ = transform({})
+
+    # Test cacheable_method: invalid function type
+    with pytest.raises(TypeError):
+
+        class DummyTransform():
+
+            @cacheable_method
+            @staticmethod
+            def func():
+                return np.random.rand()
+
+    # Test cacheable_method: invalid function argument list
+    with pytest.raises(TypeError):
+
+        class DummyTransform():
+
+            @cacheable_method
+            def func(cls):
+                return np.random.rand()