[Feature] Add BaseInferencer to MMEngine #874

Merged: 7 commits merged into main from inference on Jan 16, 2023
Conversation

HAOCHENYE
Collaborator

Thanks for your contribution; we appreciate it a lot. Following the instructions below will make your pull request healthier and help it get feedback more easily. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Add BaseInferencer to provide an easy and clean interface for running inference on single or multiple images.

How to use

class XXXInferencer(BaseInferencer):
    ...

inferencer = XXXInferencer(model='path-to-config', weights='path-to-checkpoint', device='cuda')
results = inferencer('img or filename or directory')
predictions = results['predictions']
visualization = results['visualization']

How to build an inferencer

Based on BaseInferencer, the __init__ of subclasses should accept at least 3 arguments (a short usage sketch follows this list):

  • model: model can be one of several types:
    • a model name defined in a metafile, such as this. OpenMMLab repos can also add an Alias to each model defined in their metafiles, and the config can then be inferred from the alias
    • a path to the config file
    • a Config, ConfigDict or dict instance.
  • weights: path to the checkpoint to load. If it is not specified, model must be a model name defined in the metafile with default weights.
  • device: the device to run inference on.
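
A minimal sketch of the three accepted forms of model, using the XXXInferencer from the example above (the model name and file paths below are placeholders, not names from this PR):

from mmengine.config import Config

# 1. A model name defined in a metafile (uses the default weights from the metafile).
inferencer = XXXInferencer(model='some-model-name', device='cuda')

# 2. A path to a config file plus a checkpoint.
inferencer = XXXInferencer(model='path/to/config.py', weights='path/to/checkpoint.pth')

# 3. A Config/ConfigDict/dict instance plus a checkpoint.
cfg = Config.fromfile('path/to/config.py')
inferencer = XXXInferencer(model=cfg, weights='path/to/checkpoint.pth')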

The standard workflow of inferencer

BaseInferencer implements the standard inference workflow in __call__:

  1. Preprocess the data and return iterable chunked data (a dataloader).
  2. Traverse the chunked data and run inference on each batch.
  3. Visualize the predictions and return the visualization.
  4. Postprocess the predictions into the target format, and return the postprocessed results and the visualization in a dict.

If downstream repos want to customize the workflow, they can override the __call__ method.
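
A simplified sketch of the default workflow implemented by __call__ (kwargs dispatch and other details are omitted, and exact signatures may differ):

def __call__(self, inputs, return_datasamples=False, batch_size=1, **kwargs):
    # 1. Preprocess: build an iterable of chunked (batched) data.
    dataloader = self.preprocess(inputs, batch_size=batch_size)
    # 2. Traverse the chunks and run inference on each batch.
    preds = []
    for batch in dataloader:
        preds.extend(self.forward(batch))
    # 3. Visualize the predictions.
    visualization = self.visualize(inputs, preds)
    # 4. Postprocess into the target format and pack the results into a dict.
    return self.postprocess(preds, visualization, return_datasamples)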

preprocess data

prepare the pipeline (abstract method)

Subclasses should override _init_pipeline to customize the pipeline. The returned pipeline will be used to process each single input.
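
For example, a hedged sketch of an _init_pipeline implementation that builds the test-time pipeline from the config (reading it from cfg.test_dataloader.dataset.pipeline is a typical location, not something mandated by this PR):

from mmengine.dataset import Compose

def _init_pipeline(self, cfg):
    # (Method of a BaseInferencer subclass.)
    # Build a callable that turns a single raw input (e.g. an image path or
    # np.ndarray) into the dict expected by the model.
    return Compose(cfg.test_dataloader.dataset.pipeline)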

prepare the collate_fn

BaseInferencer provides a common way to get the collate_fn from the cfg. For more customized usage, you can override this method to return the target collate_fn.
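
A minimal sketch of that fallback behaviour (the lookup of a configured collate_fn is simplified away here):

from mmengine.dataset import pseudo_collate

def _init_collate(self, cfg):
    # (Method of a BaseInferencer subclass.)
    # Use the collate_fn configured for the test dataloader if there is one;
    # otherwise fall back to pseudo_collate, which simply groups samples into
    # lists without stacking them.
    if cfg.get('test_dataloader', {}).get('collate_fn') is None:
        return pseudo_collate
    # Resolving a configured collate_fn (e.g. from a registry) is omitted
    # in this sketch.
    ...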

prepare the chunked data

Subclasses can override this step to build custom chunked data. BaseInferencer provides a common way to build the chunked data in _get_chunk_data.

preprocess uses the prepared pipeline and collate_fn to return the chunked data, each item of which can be passed directly to model.test_step.
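
Put together, the default preprocess is roughly equivalent to the following sketch:

def preprocess(self, inputs, batch_size=1, **kwargs):
    # (Method of a BaseInferencer subclass.)
    # Apply the per-sample pipeline lazily, group the processed samples into
    # chunks of `batch_size`, then collate each chunk into a batch that can
    # be fed directly to model.test_step.
    chunked_data = self._get_chunk_data(map(self.pipeline, inputs), batch_size)
    yield from map(self.collate_fn, chunked_data)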

forward

Runs inference on the chunked data. BaseInferencer calls model.test_step in forward by default.
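
In other words, a sketch of what the default forward boils down to:

import torch

def forward(self, inputs, **kwargs):
    # (Method of a BaseInferencer subclass.)
    # Run one inference step on a collated batch without tracking gradients.
    with torch.no_grad():
        return self.model.test_step(inputs)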

visualize (abstract method)

Subclasses should implement visualize to visualize the predictions and return the visualization results.

postprocess (abstract method)

Subclasses should implement postprocess to get the result in the target format (DataSample or dict) and the visualization result.
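
A hedged sketch of a postprocess implementation (the pred_to_dict helper is hypothetical and stands in for whatever per-repo conversion is needed):

def postprocess(self, preds, visualization, return_datasample=False, **kwargs):
    # (Method of a BaseInferencer subclass.)
    # Optionally convert each DataSample into a plain dict, then pack the
    # predictions and visualization into the dict returned by __call__.
    if not return_datasample:
        preds = [self.pred_to_dict(p) for p in preds]  # hypothetical helper
    return dict(predictions=preds, visualization=visualization)

A complete toy example putting all of these pieces together: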

from typing import Any, Callable, List, Optional, Union

import numpy as np

from mmengine.infer.infer import (BaseInferencer, ConfigType, InputsType,
                                  ModelType)
from mmengine.visualization import Visualizer


class ToyInferencer(BaseInferencer):
    # kwargs defined here should be able to be dispatched to the
    # corresponding methods.
    preprocess_kwargs: set = {'preprocess_arg'}
    forward_kwargs: set = {'forward_arg'}
    visualize_kwargs: set = {'vis_arg'}
    postprocess_kwargs: set = {'pos_arg'}

    def __init__(
        self,
        model: Union[ModelType, str],
        weights: Optional[str] = None,
        device: Optional[str] = None,
    ) -> None:
        super().__init__(model, weights, device)

    def _init_pipeline(self, cfg: ConfigType) -> Callable:
        ...
        # Return a pipeline that handles various types of input data, such as
        # str or np.ndarray. It is an abstract method in BaseInferencer, and
        # should be implemented in subclasses.

        # The pipeline will be used to process a single input. It will be used
        # in preprocess like this:

        # dataset = map(self.pipeline, dataset)

    def _init_collate(self, cfg: ConfigType) -> Callable:
        # Get the collate_fn from cfg.
        # BaseInferencer will try to get collate_fn from
        # `cfg.test_dataloader.collate_fn`, if it is None, it will return the
        # pseudo collate_fn.
        return super()._init_collate(cfg)

    def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
        # Return a visualizer
        return super()._init_visualizer(cfg)

    def __call__(self,
                 inputs: InputsType,
                 return_datasamples: bool = False,
                 batch_size: int = 1,
                 **kwargs) -> dict:
        # BaseInferencer.__call__ implements the standard inference workflow.
        return super().__call__(inputs, return_datasamples, batch_size,
                                **kwargs)

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        # Implement your own preprocess logic here. Preprocess should return an
        # iterable object, of which each item will be used as the input of
        # model.test_step.

        # BaseInferencer.preprocess will return an iterable dataloader, which
        # will be used like this in __call__:
        # dataloader = self.preprocess(inputs, batch_size, **kwargs)
        # for batch in dataloader:
        #    preds = self.forward(batch, **kwargs)
        return super().preprocess(inputs, batch_size, **kwargs)

    def forward(self, batch, forward_arg=None, **kwargs):
        # Implement your own forward logic here. Forward should return the
        # prediction of the model.
        return super().forward(batch, **kwargs)

    def visualize(self,
                  inputs: InputsType,
                  preds: Any,
                  show: bool = False,
                  **kwargs) -> List[np.ndarray]:
        # Visualize the predictions and return the visualization results.
        return super().visualize(inputs, preds, show, **kwargs)

    def postprocess(self,
                    preds: Any,
                    visualization: List[np.ndarray],
                    return_datasample=False,
                    **kwargs) -> dict:
        ...
        # Subclasses should implement this method to do the postprocessing,
        # such as converting datasamples to simple dicts or dumping the
        # results. It must return a dict with the keys `predictions` and
        # `visualization`. See more information in the docstring.
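
A hedged usage sketch of the toy inferencer above, showing how extra keyword arguments are dispatched to the stage whose *_kwargs set declares them (paths and argument values are placeholders, and the corresponding methods are assumed to accept these keywords):

inferencer = ToyInferencer(
    model='path/to/config.py', weights='path/to/checkpoint.pth', device='cuda')

# preprocess_arg is routed to preprocess, forward_arg to forward,
# vis_arg to visualize and pos_arg to postprocess by BaseInferencer.__call__.
results = inferencer(
    'demo.jpg',
    batch_size=2,
    forward_arg=None,
    vis_arg=None,
    pos_arg=None)
print(results['predictions'])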

The coverage rate of the unit tests:

[screenshot of the coverage report]

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@HAOCHENYE changed the title from Inference to [Feature] Add BaseInferencer to MMEngine on Jan 13, 2023
@HAOCHENYE added this to the 0.5.0 milestone on Jan 13, 2023
gaotongxiao previously approved these changes on Jan 16, 2023
HAOCHENYE and others added 5 commits January 16, 2023 15:31
* Update BaseInferencer

* Fix ci

* Fix CI and rename iferencer to infer

* Fix CI

* Add renamed file

* Add test file

* Adjust interface sequence

* refine preprocess

* Update unit test

Update unit test

* Update unit test

* Fix unit test

* Fix as comment

* Minor refine

* Fix docstring and support load image from different backend

* Support load collate_fn from downstream repos, refine dispatch

* Minor refine

* Fix lint

* refine grammar

* Remove FileClient

* Refine docstring

* add rich

* Add list_models

* Add list_models

* Remove backend args

* Minor refine
* Add preprocess inputs

* Add type hint

* update api/infer in index.rst

* rename preprocess_inputs to _inputs_to_list

* Fix doc format

* Update infer.py

Co-authored-by: Zaida Zhou <[email protected]>
* first commit

* [Enhance] Support build model from weight

* minor refine

* Fix type hint

* refine comments

* Update docstring

* refine as comment

* Add  method

* Refine docstring

* Fix as comment

* refine comments

* Refine warning message

* Fix unit test and refine comments
@zhouzaida merged commit 2d8f2be into main on Jan 16, 2023
@zhouzaida deleted the inference branch on January 16, 2023 08:01
@zhouzaida restored the inference branch on January 16, 2023 08:05
zhouzaida added a commit that referenced this pull request on Jan 16, 2023
@zhouzaida deleted the inference branch on January 16, 2023 08:31