
[Enhancement] Refactor high-level APIs #1410

Merged 87 commits on Nov 23, 2022

Commits (87)
eca56c6
add high-level inference api and run conditional model good.
liuwenran Nov 2, 2022
8ec6c53
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 2, 2022
b0bff2c
[high-level api] run unconditional model.
liuwenran Nov 2, 2022
a2eb0aa
Merge branch 'dev-1.x' of https://github.com/liuwenran/mmediting into…
liuwenran Nov 2, 2022
6d2593d
[high-level api] add matting inferencer.
liuwenran Nov 3, 2022
163c28d
[high-level api] add inpainting inferencer.
liuwenran Nov 3, 2022
63e2930
[high-level api] add translation inferencer.
liuwenran Nov 3, 2022
8f30c54
[high-level api] add restoration inferencer.
liuwenran Nov 3, 2022
d6de502
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 4, 2022
8a83973
[high-level api] add inference functions
liuwenran Nov 4, 2022
0fdb238
Merge branch 'dev-1.x' into dev-1.x
liuwenran Nov 7, 2022
e33c2ca
[high-level api] add video interpolation inference
liuwenran Nov 7, 2022
d51f6e2
Merge branch 'dev-1.x' of https://github.com/liuwenran/mmediting into…
liuwenran Nov 7, 2022
c47c262
[high-level api] add video restoration inferencer
liuwenran Nov 7, 2022
064dd5b
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 8, 2022
ac6973a
[high-level api] pass linter check
liuwenran Nov 8, 2022
5f130c5
[hight-level api] append linter check
liuwenran Nov 8, 2022
8cee84a
[high-level api] delete old interface py code
liuwenran Nov 8, 2022
09f846f
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 8, 2022
66bb49b
[high-level api] remove unused code.
liuwenran Nov 8, 2022
4657903
[high-level api] add comments for inferences.
liuwenran Nov 8, 2022
c181555
[high-level api] delete unused parameters and add extra parameters.
liuwenran Nov 9, 2022
5a621b2
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 9, 2022
f57d6b0
Merge branch 'dev-1.x' of https://github.com/liuwenran/mmediting into…
liuwenran Nov 9, 2022
a13dfad
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 9, 2022
1aec4aa
[high-level api] add inference tutorial.
liuwenran Nov 9, 2022
0658dd9
[high-level api] add unit test for inference_functions
liuwenran Nov 10, 2022
19c1d32
[high-level api] fix path error.
liuwenran Nov 10, 2022
02a9873
[high-level api] delete old unit test file.
liuwenran Nov 10, 2022
5b8f00e
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 10, 2022
dafeccc
[high-level api] add ut for edit and conditional inferncer.
liuwenran Nov 10, 2022
4ff4992
[high-level api] fix edit ut bug and add ut of base_inference, inpain…
liuwenran Nov 10, 2022
95f7fa4
[high-level api] fix edit.py ut and add two uts.
liuwenran Nov 10, 2022
43ba5d6
[high-level api] add result_out_dir to UTs and add two new uts.
liuwenran Nov 10, 2022
ca345e8
[high-level api] fix unconditional ut out of ram and add two uts.
liuwenran Nov 10, 2022
33b9433
[high-level api] try to satisfy ut code coverage.
liuwenran Nov 10, 2022
c160d7d
[high-level api] use pytest.raises to catch ut error.
liuwenran Nov 10, 2022
e9dcb82
[high-level api] add more test case in inference_functions
liuwenran Nov 10, 2022
12c31d2
[high-level api] add video restoration test case.
liuwenran Nov 10, 2022
9864642
[high-level api] video interpolation support dir input output and add…
liuwenran Nov 11, 2022
6ee380c
[high-level api] video restoration support input_dir and add uts.
liuwenran Nov 11, 2022
d57ae4b
[high-level api] add more uts for inference funcs.
liuwenran Nov 11, 2022
a9bf85a
[high-level api] fix bug in inference_functions
liuwenran Nov 11, 2022
6501045
[high-level api] add more uts.
liuwenran Nov 11, 2022
ea3f934
[high-level api] make colorization inference be tested.
liuwenran Nov 11, 2022
f74eaea
[high-level api] roll back last commit.
liuwenran Nov 11, 2022
842fc48
[high-level api] delete unused code.
liuwenran Nov 11, 2022
05b3259
[high-level api] remove default value.
liuwenran Nov 15, 2022
b94a67e
[high-level api] rename version to setting.
liuwenran Nov 15, 2022
6b5f86c
[high-level api] add examples in edit.py and remove duplicated funcs.
liuwenran Nov 15, 2022
a63b392
Merge branch 'dev-1.x' into dev-1.x
liuwenran Nov 16, 2022
cdcbc23
[high-level api] add log for functions not used.
liuwenran Nov 16, 2022
a8b3621
Merge branch 'dev-1.x' of https://github.com/liuwenran/mmediting into…
liuwenran Nov 16, 2022
69e1378
[high-level api] load mean std from cfg and use basedataelement
liuwenran Nov 16, 2022
da2557b
[high-level api] do unittest with cu102 version.
liuwenran Nov 16, 2022
374583d
[high-level api] add more uts.
liuwenran Nov 16, 2022
6015489
[high-level api] revert change in da2557b
liuwenran Nov 16, 2022
67c8172
[high-level api] add uts.
liuwenran Nov 16, 2022
7dc7207
[high-level api] replace ckpt to http url./
liuwenran Nov 16, 2022
31f9359
Merge branch 'dev-1.x' into dev-1.x
liuwenran Nov 16, 2022
4940944
[high-level api] read config from metafile and download ckpt automati…
liuwenran Nov 17, 2022
25fe31e
Merge branch 'dev-1.x' of https://github.com/liuwenran/mmediting into…
liuwenran Nov 17, 2022
0606266
[high-level api] make default setting all to 0.
liuwenran Nov 17, 2022
517144f
[high-level api] add content in ipynb
liuwenran Nov 17, 2022
c05d212
[high-level api] add ut for test_edit modification.
liuwenran Nov 18, 2022
5a16055
[high-level api] fix task in scripts and metafiles
liuwenran Nov 18, 2022
b6a9307
[high-level api] read task name from metafile
liuwenran Nov 18, 2022
4f952b7
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
liuwenran Nov 18, 2022
f7d8edd
[high-level api] reproduce nafnet metafile
liuwenran Nov 18, 2022
6ff01cb
[high-level api] put good example to inference ipynb
liuwenran Nov 20, 2022
23b6c0d
[high-level api] revert inpainting demo to global local.
liuwenran Nov 20, 2022
91d7137
[high-level api] add model setting and extra parameters in ipynb
liuwenran Nov 21, 2022
5dcc15a
[high-level api] add README.md and more instructions in ipynb.
liuwenran Nov 21, 2022
b77d794
[high-level api] append to last commit.
liuwenran Nov 21, 2022
e883473
[high-level api] fix typo.
liuwenran Nov 21, 2022
f15e2e2
[high-level api] add ut to refresh building.
liuwenran Nov 21, 2022
ae1d938
Merge branch 'dev-1.x' into dev-1.x
liuwenran Nov 22, 2022
6c40d79
[high-level api] fix readme review comments.
liuwenran Nov 22, 2022
49e13e7
[high-level api] refresh readme.
liuwenran Nov 23, 2022
945ac86
[high-level api] refresh demo readme and fix typo
liuwenran Nov 23, 2022
a933c3d
Merge branch 'dev-1.x' into dev-1.x
liuwenran Nov 23, 2022
face3a8
[high-level api] resolve review comment
Nov 23, 2022
df3cddc
Merge branch 'dev-1.x' of https://github.com/liuwenran/mmediting into…
Nov 23, 2022
5e383be
[high-level api] fix type.
Nov 23, 2022
9c922c9
[high-level api] revert change.
Nov 23, 2022
a022dbd
[high-level api] misc change.
Nov 23, 2022
1a4a0cd
[high-level api] fix type again.
Nov 23, 2022
82 changes: 82 additions & 0 deletions demo/mmediting_inference_demo.py
@@ -0,0 +1,82 @@
from argparse import ArgumentParser

from mmedit.edit import MMEdit

# resources/input/matting/beach_fg.png


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        '--img',
        type=str,
        default='resources/input/restoration/0901x2.png',
        help='Input image file.')
    parser.add_argument(
        '--label',
        type=int,
        default=1,
        help='Input label for conditional models.')
    parser.add_argument(
        '--trimap',
        type=str,
        default='resources/input/matting/beach_trimap.png',
        help='Input trimap file for matting models.')
    parser.add_argument(
        '--mask',
        type=str,
        default='resources/input/inpainting/mask_2_resized.png',
        help='Path to input mask file for inpainting models.')
    parser.add_argument(
        '--img-out-dir',
        type=str,
        default='resources/demo_results/inferencer_samples_apis.png',
        help='Output path of the result image.')
    parser.add_argument(
        '--model-name',
        type=str,
        default='esrgan',
        help='Name of the pretrained editing algorithm.')
    parser.add_argument(
        '--model-version',
        type=str,
        default='a',
        help='Version (setting) of the pretrained editing algorithm.')
    parser.add_argument(
        '--model-config',
        type=str,
        default=None,
        help='Path to the custom config file of the selected editing model.')
    parser.add_argument(
        '--model-ckpt',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected editing '
        'model.')
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        help='Device used for inference.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Whether to print the results.')
    parser.add_argument(
        '--pred-out-file',
        type=str,
        default='',
        help='File to save the inference results.')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    editor = MMEdit(**vars(args))
    editor.infer(**vars(args))


if __name__ == '__main__':
    main()
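The demo above forwards the entire parsed-argument namespace to both the `MMEdit` constructor and `infer()` via `**vars(args)`, letting each side pick the keyword arguments it cares about and swallow the rest. A minimal, self-contained sketch of that pattern, using a hypothetical `FakeEditor` stand-in (not the real `MMEdit` API):

```python
from argparse import ArgumentParser


class FakeEditor:
    """Hypothetical stand-in for MMEdit: each method takes the keywords it
    needs and absorbs the rest via **kwargs, mirroring the demo's use of
    **vars(args) for both construction and inference."""

    def __init__(self, model_name=None, device='cuda', **kwargs):
        self.model_name = model_name
        self.device = device

    def infer(self, img=None, **kwargs):
        return {'model': self.model_name, 'img': img}


def parse_args(argv=None):
    # argparse converts '--model-name' to the attribute 'model_name'.
    parser = ArgumentParser()
    parser.add_argument('--img', type=str, default='demo.png')
    parser.add_argument('--model-name', type=str, default='esrgan')
    parser.add_argument('--device', type=str, default='cpu')
    return parser.parse_args(argv)


args = parse_args([])
editor = FakeEditor(**vars(args))
result = editor.infer(**vars(args))
```

The design trade-off: passing every flag everywhere keeps the demo short, but any typo in a flag name is silently ignored by `**kwargs` instead of raising an error.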
1 change: 1 addition & 0 deletions mmedit/apis/__init__.py
@@ -9,6 +9,7 @@
from .restoration_video_inference import restoration_video_inference
from .translation_inference import sample_img2img_model
from .video_interpolation_inference import video_interpolation_inference
from .inferencers import *

__all__ = [
'init_model', 'delete_cfg', 'set_random_seed', 'matting_inference',
6 changes: 6 additions & 0 deletions mmedit/apis/inferencers/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmedit_inferencer import MMEditInferencer

__all__ = [
'MMEditInferencer'
]
262 changes: 262 additions & 0 deletions mmedit/apis/inferencers/base_inferencer.py
@@ -0,0 +1,262 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData

from mmedit.registry import MODELS, VISUALIZERS
from mmedit.utils import ConfigType

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict]]


class BaseInferencer:
"""Base inferencer.

Args:
model (str or ConfigType): Model config or the path to it.
ckpt (str, optional): Path to the checkpoint.
device (str, optional): Device to run inference. If None, the best
device will be automatically used.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
pred_out_file (str): File to save the inference results. If left
empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the result.
Defaults to False.
"""

def __init__(self,
config: Union[ConfigType, str],
ckpt: Optional[str],
device: Optional[str] = None,
**kwargs) -> None:
# Load config to cfg
if isinstance(config, str):
cfg = Config.fromfile(config)
elif isinstance(config, ConfigType):
cfg = config
else:
raise TypeError('config must be a filename or any ConfigType '
f'object, but got {type(config)}')
self.cfg = cfg
if cfg.model.get('pretrained'):
cfg.model.pretrained = None

if device is None:
device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
self._init_model(cfg, ckpt, device)
self._init_visualizer(cfg)

# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0

def _init_model(self, cfg: Union[ConfigType, str], ckpt: Optional[str],
device: str) -> None:
"""Initialize the model with the given config and checkpoint on the
specific device."""
model = MODELS.build(cfg.model)
if ckpt is not None:
load_checkpoint(model, ckpt, map_location='cpu')
model.cfg = cfg
model.to(device)
model.eval()
self.model = model

def _init_pipeline(self, cfg: ConfigType) -> None:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline

self.file_pipeline = Compose(pipeline_cfg)

def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
"""Returns the index of the transform in a pipeline.

If the transform is not found, returns -1.
"""
for i, transform in enumerate(pipeline_cfg):
if transform['type'] == name:
return i
return -1

def _init_visualizer(self, cfg: ConfigType) -> None:
"""Initialize visualizers."""
# TODO: We don't export images via backends since the interface
# of the visualizer will have to be refactored.
self.visualizer = None
if 'visualizer' in cfg:
ts = str(datetime.timestamp(datetime.now()))
cfg.visualizer['name'] = f'inferencer{ts}'
self.visualizer = VISUALIZERS.build(cfg.visualizer)

def preprocess(self, inputs: InputsType) -> Dict:
"""Process the inputs into a model-feedable format."""
self._init_pipeline(self.cfg)

results = []
for single_input in inputs:
if isinstance(single_input, str):
if osp.isdir(single_input):
raise ValueError('Feeding a directory is not supported')
else:
data_ = dict(img_path=single_input)
results.append(self.file_pipeline(data_))
elif isinstance(single_input, np.ndarray):
data_ = dict(img=single_input)
# NOTE: assumes a pipeline that accepts in-memory arrays;
# self.ndarray_pipeline is expected to be set up by subclasses.
results.append(self.ndarray_pipeline(data_))
else:
raise ValueError(
f'Unsupported input type: {type(single_input)}')

return self._collate(results)

def _collate(self, results: List[Dict]) -> Dict:
"""Collate the results from different images."""
results = {key: [d[key] for d in results] for key in results[0]}
return results

def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model."""
with torch.no_grad():
return self.model.test_step(inputs)

def visualize(self,
inputs: InputsType,
preds: PredType,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> List[np.ndarray]:
"""Visualize predictions.

Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
"""
if not show and img_out_dir == '':
return None

if self.visualizer is None:
raise ValueError('Visualization needs the "visualizer" term '
'defined in the config, but got None.')

results = []

for single_input, pred in zip(inputs, preds):
if isinstance(single_input, str):
img = mmcv.imread(single_input)
img = img[:, :, ::-1]
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type: '
f'{type(single_input)}')

out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None

self.visualizer.add_datasample(
img_name,
img,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
out_file=out_file,
)
results.append(img)
self.num_visualized_imgs += 1

return results

def postprocess(
self,
preds: PredType,
imgs: Optional[List[np.ndarray]] = None,
is_batch: bool = False,
print_result: bool = False,
pred_out_file: str = '',
get_datasample: bool = False,
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Postprocess predictions.

Args:
preds (List[Dict]): Predictions of the model.
imgs (Optional[np.ndarray]): Visualized predictions.
is_batch (bool): Whether the inputs are in a batch.
Defaults to False.
print_result (bool): Whether to print the result.
Defaults to False.
pred_out_file (str): Output file name to store predictions
without images. Supported file formats are "json", "yaml/yml"
and "pickle/pkl". Defaults to ''.
get_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used.

Returns:
TODO
"""

results = preds
if not get_datasample:
results = []
for pred in preds:
result = self._pred2dict(pred)
results.append(result)
if not is_batch:
results = results[0]
if print_result:
print(results)
# Add img to the results after printing
if pred_out_file != '':
mmcv.dump(results, pred_out_file)
if imgs is None:
return results
return results, imgs

def _pred2dict(self, data_sample: torch.Tensor) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.

Args:
data_sample (torch.Tensor): The data sample to be converted.

Returns:
dict: The output dictionary.
"""
result = {}
result['infer_res'] = data_sample
return result
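The `_collate`, `postprocess`, and `_pred2dict` steps above are plain data reshaping, so their behavior can be sketched without the mmcv/torch dependencies. A minimal standalone version, with `fake_preds` as a made-up stand-in for model predictions:

```python
def collate(results):
    # Turn a list of per-image dicts into a dict of lists, keyed by the
    # fields of the first result (same shape as BaseInferencer._collate).
    return {key: [d[key] for d in results] for key in results[0]}


def pred2dict(data_sample):
    # Wrap a raw prediction in a json-serializable dict, as _pred2dict does.
    return {'infer_res': data_sample}


def postprocess(preds, is_batch=False):
    # Convert every prediction, then unwrap the list for single inputs.
    results = [pred2dict(p) for p in preds]
    if not is_batch:
        results = results[0]
    return results


fake_preds = [[0.1, 0.9], [0.7, 0.3]]
batched = collate([{'img': 'a.png'}, {'img': 'b.png'}])
single = postprocess(fake_preds)            # unwraps to the first result
both = postprocess(fake_preds, is_batch=True)
```

Note how `is_batch=False` silently drops all but the first prediction; callers feeding multiple inputs must remember to pass `is_batch=True` or results are lost.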