Skip to content

Commit

Permalink
Support mmseg:dev-1.x (#790)
Browse files Browse the repository at this point in the history
* support pspnet + ort

* add rewriting for adapt_avg_pool

* test pspnet

* resize seg_pred to original image shape

* run with test.py

* keep as original

* fix ut of segmentation

* update var name

* fix export to torchscript

* sync with mmseg:test-1.x branch

* fix ut

* fix regression test for mmseg

* fix mmseg.ops

* update mmseg yml

* fix mmseg2.0 sdk

* fix adaptive pool

* update rewriting and tests

* fix sdk inputs
  • Loading branch information
RunningLeon authored Sep 14, 2022
1 parent 0aad635 commit 06028d6
Show file tree
Hide file tree
Showing 30 changed files with 647 additions and 1,079 deletions.
4 changes: 3 additions & 1 deletion configs/mmseg/segmentation_sdk_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@

backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
dict(type='LoadAnnotations'),
dict(
type='PackSegInputs', meta_keys=['img_path', 'ori_shape', 'img_shape'])
])
2 changes: 1 addition & 1 deletion mmdeploy/apis/core/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def pop_mp_output(self, call_id: int = None) -> Any:
call_id = self._call_id if call_id is None else call_id
if call_id not in self._mp_dict:
get_root_logger().error(
f'`{self._func_name}` with Call id: {call_id} failed. exit.')
f'`{self._func_name}` with Call id: {call_id} failed.')
exit(1)
ret = self._mp_dict[call_id]
self._mp_dict.pop(call_id)
Expand Down
5 changes: 4 additions & 1 deletion mmdeploy/apis/pytorch2torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def torch2torchscript(img: Any,
task_processor = build_task_processor(model_cfg, deploy_cfg, device)

torch_model = task_processor.build_pytorch_model(model_checkpoint)
_, model_inputs = task_processor.create_input(img, input_shape)
_, model_inputs = task_processor.create_input(
img,
input_shape,
data_preprocessor=getattr(torch_model, 'data_preprocessor', None))
if not isinstance(model_inputs, torch.Tensor):
model_inputs = model_inputs[0]

Expand Down
10 changes: 8 additions & 2 deletions mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,20 @@ def build_pytorch_model(self,
nn.Module: An initialized torch model generated by other OpenMMLab
codebases.
"""
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import MODELS

model = deepcopy(self.model_cfg.model)
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
model.setdefault('data_preprocessor', preprocess_cfg)
model = MODELS.build(model)
if model_checkpoint is not None:
from mmengine.runner.checkpoint import load_checkpoint
load_checkpoint(model, model_checkpoint)

model = revert_sync_batchnorm(model)
model = model.to(self.device)
model.eval()

return model

def build_dataset(self,
Expand Down Expand Up @@ -280,7 +283,10 @@ def visualize(self,
visualizer = self.get_visualizer(window_name, save_dir)

name = osp.splitext(save_name)[0]
image = mmcv.imread(image, channel_order='rgb')
if isinstance(image, str):
image = mmcv.imread(image, channel_order='rgb')
assert isinstance(image, np.ndarray)

visualizer.add_datasample(
name,
image,
Expand Down
3 changes: 1 addition & 2 deletions mmdeploy/codebase/mmseg/deploy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmsegmentation import MMSegmentation
from .segmentation import Segmentation

__all__ = ['MMSegmentation', 'Segmentation']
__all__ = ['Segmentation']
148 changes: 0 additions & 148 deletions mmdeploy/codebase/mmseg/deploy/mmsegmentation.py

This file was deleted.

Loading

0 comments on commit 06028d6

Please sign in to comment.