diff --git a/configs/pointpillars/metafile.yml b/configs/pointpillars/metafile.yml
index d076b2f741..ac2bfab281 100644
--- a/configs/pointpillars/metafile.yml
+++ b/configs/pointpillars/metafile.yml
@@ -29,6 +29,7 @@ Models:
       Weights: https://download.openmmlab.com/mmdetection3d/v1.0.0_models/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car_20220331_134606-d42d15ed.pth

   - Name: pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class
+    Alias: pointpillars_kitti-3class
     In Collection: PointPillars
     Config: configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py
     Metadata:
@@ -164,8 +165,9 @@ Models:
       Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-car_20200901_204315-302fc3e7.pth

   - Name: hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class
+    Alias: pointpillars_waymod5-3class
     In Collection: PointPillars
-    Config: configs/pointpillars/hv_pointpillars_secfpn_sbn_2x16_2x_waymoD5-3d-3class.py
+    Config: configs/pointpillars/pointpillars_hv_secfpn_sbn-all_16xb2-2x_waymoD5-3d-3class.py
     Metadata:
       Training Data: Waymo
       Training Memory (GB): 8.12
diff --git a/mmdet3d/apis/__init__.py b/mmdet3d/apis/__init__.py
index d1b52e4c7c..a5526d9759 100644
--- a/mmdet3d/apis/__init__.py
+++ b/mmdet3d/apis/__init__.py
@@ -3,10 +3,12 @@
                         inference_mono_3d_detector,
                         inference_multi_modality_detector,
                         inference_segmentor, init_model)
-from .inferencers import BaseDet3DInferencer, MonoDet3DInferencer
+from .inferencers import (BaseDet3DInferencer, LidarDet3DInferencer,
+                          MonoDet3DInferencer)

 __all__ = [
     'inference_detector', 'init_model', 'inference_mono_3d_detector',
     'convert_SyncBN', 'inference_multi_modality_detector',
-    'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer'
+    'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer',
+    'LidarDet3DInferencer'
 ]
diff --git a/mmdet3d/apis/inferencers/__init__.py b/mmdet3d/apis/inferencers/__init__.py
index 0aaf0b2984..26dbe9cb0b 100644
--- a/mmdet3d/apis/inferencers/__init__.py
+++ b/mmdet3d/apis/inferencers/__init__.py
@@ -1,5 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_det3d_inferencer import BaseDet3DInferencer
+from .lidar_det3d_inferencer import LidarDet3DInferencer
 from .mono_det3d_inferencer import MonoDet3DInferencer

-__all__ = ['BaseDet3DInferencer', 'MonoDet3DInferencer']
+__all__ = [
+    'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer'
+]
diff --git a/mmdet3d/apis/inferencers/base_det3d_inferencer.py b/mmdet3d/apis/inferencers/base_det3d_inferencer.py
index db7b4ae041..e16531c576 100644
--- a/mmdet3d/apis/inferencers/base_det3d_inferencer.py
+++ b/mmdet3d/apis/inferencers/base_det3d_inferencer.py
@@ -4,6 +4,8 @@
 import mmengine
 import numpy as np
 import torch.nn as nn
+from mmengine.fileio import (get_file_backend, isdir, join_path,
+                             list_dir_or_file)
 from mmengine.infer.infer import BaseInferencer, ModelType
 from mmengine.runner import load_checkpoint
 from mmengine.structures import InstanceData
@@ -110,6 +112,51 @@ def _init_model(
         model.eval()
         return model

+    def _inputs_to_list(
+            self,
+            inputs: Union[dict, list],
+            modality_key: Union[str, List[str]] = 'points') -> list:
+        """Preprocess the inputs to a list.
+
+        Preprocess inputs to a list according to its type:
+
+        - list or tuple: return inputs
+        - dict: the value with key 'points' or 'img' is
+            - Directory path: return all files in the directory
+            - other cases: return a list containing the string. The string
+              could be a path to file, a url or other types of string
+              according to the task.
+
+        Args:
+            inputs (Union[dict, list]): Inputs for the inferencer.
+            modality_key (Union[str, List[str]], optional): The key of the
+                modality. Defaults to 'points'.
+
+        Returns:
+            list: List of input for the :meth:`preprocess`.
+        """
+        if isinstance(modality_key, str):
+            modality_key = [modality_key]
+        assert set(modality_key).issubset({'points', 'img'})
+
+        for key in modality_key:
+            if isinstance(inputs, dict) and isinstance(inputs[key], str):
+                img = inputs[key]
+                backend = get_file_backend(img)
+                if hasattr(backend, 'isdir') and isdir(img):
+                    # Backends like HttpsBackend do not implement `isdir`, so
+                    # only those backends that implement `isdir` could accept
+                    # the inputs as a directory
+                    filename_list = list_dir_or_file(img, list_dir=False)
+                    inputs = [{
+                        key: join_path(img, filename)
+                    } for filename in filename_list]
+
+        if not isinstance(inputs, (list, tuple)):
+            inputs = [inputs]
+
+        return list(inputs)
+
     def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
         """Returns the index of the transform in a pipeline.
@@ -240,7 +287,7 @@ def pred2dict(self, data_sample: InstanceData) -> Dict:
         """
         pred_instances = data_sample.pred_instances_3d.numpy()
         result = {
-            'bboxes_3d': pred_instances.bboxes_3d.tensor.numpy().tolist(),
+            'bboxes_3d': pred_instances.bboxes_3d.tensor.cpu().tolist(),
             'labels_3d': pred_instances.labels_3d.tolist(),
             'scores_3d': pred_instances.scores_3d.tolist()
         }
diff --git a/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py b/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py
new file mode 100644
index 0000000000..27488a3e74
--- /dev/null
+++ b/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Dict, List, Optional, Sequence, Union
+
+import mmengine
+import numpy as np
+from mmengine.dataset import Compose
+from mmengine.infer.infer import ModelType
+from mmengine.structures import InstanceData
+
+from mmdet3d.registry import INFERENCERS
+from mmdet3d.utils import ConfigType, register_all_modules
+from .base_det3d_inferencer import BaseDet3DInferencer
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+@INFERENCERS.register_module(name='det3d-lidar')
+@INFERENCERS.register_module()
+class LidarDet3DInferencer(BaseDet3DInferencer):
+    """The inferencer of LiDAR-based detection.
+
+    Args:
+        model (str, optional): Path to the config file or the model name
+            defined in metafile. For example, it could be
+            "pointpillars_kitti-3class" or
+            "configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py".  # noqa: E501
+            If model is not specified, user must provide the
+            `weights` saved by MMEngine which contains the config string.
+            Defaults to None.
+        weights (str, optional): Path to the checkpoint. If it is not
+            specified and model is a model name of metafile, the weights
+            will be loaded from metafile. Defaults to None.
+        device (str, optional): Device to run inference. If None, the
+            available device will be automatically used. Defaults to None.
+        scope (str, optional): The scope of the registry. Defaults to
+            'mmdet3d'.
+        palette (str, optional): The palette of visualization. Defaults to
+            'none'.
+    """
+
+    preprocess_kwargs: set = set()
+    forward_kwargs: set = set()
+    visualize_kwargs: set = {
+        'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
+        'img_out_dir'
+    }
+    postprocess_kwargs: set = {
+        'print_result', 'pred_out_file', 'return_datasample'
+    }
+
+    def __init__(self,
+                 model: Union[ModelType, str, None] = None,
+                 weights: Optional[str] = None,
+                 device: Optional[str] = None,
+                 scope: Optional[str] = 'mmdet3d',
+                 palette: str = 'none') -> None:
+        # A per-instance counter of processed frames, used to name the
+        # visualization results of ndarray inputs
+        self.num_visualized_frames = 0
+        self.palette = palette
+        register_all_modules()
+        super().__init__(
+            model=model, weights=weights, device=device, scope=scope)
+
+    def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
+        """Preprocess the inputs to a list.
+
+        Preprocess inputs to a list according to its type:
+
+        - list or tuple: return inputs
+        - dict: the value with key 'points' is
+            - Directory path: return all files in the directory
+            - other cases: return a list containing the string. The string
+              could be a path to file, a url or other types of string
+              according to the task.
+
+        Args:
+            inputs (Union[dict, list]): Inputs for the inferencer.
+
+        Returns:
+            list: List of input for the :meth:`preprocess`.
+        """
+        return super()._inputs_to_list(inputs, modality_key='points')
+
+    def _init_pipeline(self, cfg: ConfigType) -> Compose:
+        """Initialize the test pipeline."""
+        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+
+        load_point_idx = self._get_transform_idx(pipeline_cfg,
+                                                 'LoadPointsFromFile')
+        if load_point_idx == -1:
+            raise ValueError(
+                'LoadPointsFromFile is not found in the test pipeline')
+
+        load_cfg = pipeline_cfg[load_point_idx]
+        self.coord_type = load_cfg['coord_type']
+        self.load_dim = load_cfg['load_dim']
+        self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
+            load_cfg['use_dim'], int) else load_cfg['use_dim']
+
+        pipeline_cfg[load_point_idx]['type'] = 'LidarDet3DInferencerLoader'
+        return Compose(pipeline_cfg)
+
+    def visualize(self,
+                  inputs: InputsType,
+                  preds: PredType,
+                  return_vis: bool = False,
+                  show: bool = False,
+                  wait_time: int = 0,
+                  draw_pred: bool = True,
+                  pred_score_thr: float = 0.3,
+                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
+        """Visualize predictions.
+
+        Args:
+            inputs (InputsType): Inputs for the inferencer.
+            preds (PredType): Predictions of the model.
+            return_vis (bool): Whether to return the visualization result.
+                Defaults to False.
+            show (bool): Whether to display the point cloud in a popup
+                window. Defaults to False.
+            wait_time (int): The interval of show (s). Defaults to 0.
+            draw_pred (bool): Whether to draw predicted bounding boxes.
+                Defaults to True.
+            pred_score_thr (float): Minimum score of bboxes to draw.
+                Defaults to 0.3.
+            img_out_dir (str): Output directory of visualization results.
+                If left as empty, no file will be saved. Defaults to ''.
+
+        Returns:
+            List[np.ndarray] or None: Returns visualization results only if
+            applicable.
+        """
+        if self.visualizer is None or (not show and img_out_dir == ''
+                                       and not return_vis):
+            return None
+
+        results = []
+
+        for single_input, pred in zip(inputs, preds):
+            single_input = single_input['points']
+            if isinstance(single_input, str):
+                pts_bytes = mmengine.fileio.get(single_input)
+                points = np.frombuffer(pts_bytes, dtype=np.float32)
+                points = points.reshape(-1, self.load_dim)
+                points = points[:, self.use_dim]
+                pc_name = osp.basename(single_input).split('.bin')[0]
+                pc_name = f'{pc_name}.png'
+            elif isinstance(single_input, np.ndarray):
+                points = single_input.copy()
+                pc_num = str(self.num_visualized_frames).zfill(8)
+                pc_name = f'pc_{pc_num}.png'
+            else:
+                raise ValueError('Unsupported input type: '
+                                 f'{type(single_input)}')
+
+            o3d_save_path = osp.join(img_out_dir, pc_name) \
+                if img_out_dir != '' else None
+
+            data_input = dict(points=points)
+            self.visualizer.add_datasample(
+                pc_name,
+                data_input,
+                pred,
+                show=show,
+                wait_time=wait_time,
+                draw_gt=False,
+                draw_pred=draw_pred,
+                pred_score_thr=pred_score_thr,
+                o3d_save_path=o3d_save_path,
+                vis_task='lidar_det',
+            )
+            results.append(points)
+            self.num_visualized_frames += 1
+
+        return results
diff --git a/mmdet3d/apis/inferencers/mono_det3d_inferencer.py b/mmdet3d/apis/inferencers/mono_det3d_inferencer.py
index 95259da399..f89289884e 100644
--- a/mmdet3d/apis/inferencers/mono_det3d_inferencer.py
+++ b/mmdet3d/apis/inferencers/mono_det3d_inferencer.py
@@ -6,11 +6,10 @@
 import mmengine
 import numpy as np
 from mmengine.dataset import Compose
-from mmengine.fileio import (get_file_backend, isdir, join_path,
-                             list_dir_or_file)
 from mmengine.infer.infer import ModelType
 from mmengine.structures import InstanceData

+from mmdet3d.registry import INFERENCERS
 from mmdet3d.utils import ConfigType
 from .base_det3d_inferencer import BaseDet3DInferencer
@@ -22,6 +21,8 @@
 ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


+@INFERENCERS.register_module(name='det3d-mono')
+@INFERENCERS.register_module()
 class MonoDet3DInferencer(BaseDet3DInferencer):
     """MMDet3D Monocular 3D object detection inferencer.
@@ -75,7 +76,7 @@ def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
         Preprocess inputs to a list according to its type:

         - list or tuple: return inputs
-        - dict:
+        - dict: the value with key 'img' is
             - Directory path: return all files in the directory
             - other cases: return a list containing the string. The string
               could be a path to file, a url or other types of string
               according to the task.
@@ -87,22 +88,7 @@ def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
         Returns:
             list: List of input for the :meth:`preprocess`.
         """
-
-        if isinstance(inputs, dict) and isinstance(inputs['img'], str):
-            img = inputs['img']
-            backend = get_file_backend(img)
-            if hasattr(backend, 'isdir') and isdir(img):
-                # Backends like HttpsBackend do not implement `isdir`, so only
-                # those backends that implement `isdir` could accept the inputs
-                # as a directory
-                filename_list = list_dir_or_file(img, list_dir=False)
-                img = [join_path(img, filename) for filename in filename_list]
-                inputs['img'] = img
-
-        if not isinstance(inputs, (list, tuple)):
-            inputs = [inputs]
-
-        return list(inputs)
+        return super()._inputs_to_list(inputs, modality_key='img')

     def _init_pipeline(self, cfg: ConfigType) -> Compose:
         """Initialize the test pipeline."""
@@ -113,7 +99,7 @@ def _init_pipeline(self, cfg: ConfigType) -> Compose:
         if load_img_idx == -1:
             raise ValueError(
                 'LoadImageFromFileMono3D is not found in the test pipeline')
-        pipeline_cfg[load_img_idx]['type'] = 'Mono3DInferencerLoader'
+        pipeline_cfg[load_img_idx]['type'] = 'MonoDet3DInferencerLoader'
         return Compose(pipeline_cfg)

     def visualize(self,
@@ -167,7 +153,7 @@ def visualize(self,
                 img_name = f'{img_num}.jpg'
             else:
                 raise ValueError('Unsupported input type: '
-                                 f'{type(single_input)}')
+                                 f"{type(single_input['img'])}")

             out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                 else None
diff --git a/mmdet3d/datasets/transforms/__init__.py b/mmdet3d/datasets/transforms/__init__.py
index 86e79307ae..72d5a8f42b 100644
--- a/mmdet3d/datasets/transforms/__init__.py
+++ b/mmdet3d/datasets/transforms/__init__.py
@@ -4,7 +4,7 @@
-from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
-                      LoadMultiViewImageFromFiles, LoadPointsFromDict,
-                      LoadPointsFromFile, LoadPointsFromMultiSweeps,
-                      Mono3DInferencerLoader, NormalizePointsColor,
-                      PointSegClassMapping)
+from .loading import (LidarDet3DInferencerLoader, LoadAnnotations3D,
+                      LoadImageFromFileMono3D, LoadMultiViewImageFromFiles,
+                      LoadPointsFromDict, LoadPointsFromFile,
+                      LoadPointsFromMultiSweeps, MonoDet3DInferencerLoader,
+                      NormalizePointsColor, PointSegClassMapping)
 from .test_time_aug import MultiScaleFlipAug3D
 # yapf: disable
@@ -29,5 +29,6 @@
     'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter',
     'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
     'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
-    'MultiViewWrapper', 'PhotoMetricDistortion3D', 'Mono3DInferencerLoader'
+    'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
+    'LidarDet3DInferencerLoader'
 ]
diff --git a/mmdet3d/datasets/transforms/loading.py b/mmdet3d/datasets/transforms/loading.py
index a3bcefc5af..74b387dc6f 100644
--- a/mmdet3d/datasets/transforms/loading.py
+++ b/mmdet3d/datasets/transforms/loading.py
@@ -709,7 +709,21 @@ class LoadPointsFromDict(LoadPointsFromFile):
     """Load Points From Dict."""

     def transform(self, results: dict) -> dict:
+        """Convert the type of points from ndarray to the corresponding
+        `point_class`.
+
+        Args:
+            results (dict): Input results dict. The value of key `points`
+                is a numpy array.
+
+        Returns:
+            dict: The processed results.
+        """
         assert 'points' in results
+        points_class = get_points_type(self.coord_type)
+        points = results['points']
+        results['points'] = points_class(
+            points, points_dim=points.shape[-1], attribute_dims=None)
         return results
@@ -1001,15 +1015,71 @@ def __repr__(self) -> str:


 @TRANSFORMS.register_module()
-class Mono3DInferencerLoader(BaseTransform):
+class LidarDet3DInferencerLoader(BaseTransform):
+    """Load point cloud in the Inferencer's pipeline.
+
+    Added keys:
+
+    - points
+    - timestamp
+    - axis_align_matrix
+    - box_type_3d
+    - box_mode_3d
+    """
+
+    def __init__(self, coord_type='LIDAR', **kwargs) -> None:
+        super().__init__()
+        self.from_file = TRANSFORMS.build(
+            dict(type='LoadPointsFromFile', coord_type=coord_type, **kwargs))
+        self.from_ndarray = TRANSFORMS.build(
+            dict(type='LoadPointsFromDict', coord_type=coord_type, **kwargs))
+        self.box_type_3d, self.box_mode_3d = get_box_type(coord_type)
+
+    def transform(self, single_input: dict) -> dict:
+        """Transform function to add point cloud meta information.
+
+        Args:
+            single_input (dict): Single input.
+
+        Returns:
+            dict: The dict contains loaded points and meta information.
+        """
+        assert 'points' in single_input, "key 'points' must be in input dict"
+        if isinstance(single_input['points'], str):
+            inputs = dict(
+                lidar_points=dict(lidar_path=single_input['points']),
+                timestamp=1,
+                # for ScanNet demo we need axis_align_matrix
+                axis_align_matrix=np.eye(4),
+                box_type_3d=self.box_type_3d,
+                box_mode_3d=self.box_mode_3d)
+        elif isinstance(single_input['points'], np.ndarray):
+            inputs = dict(
+                points=single_input['points'],
+                timestamp=1,
+                # for ScanNet demo we need axis_align_matrix
+                axis_align_matrix=np.eye(4),
+                box_type_3d=self.box_type_3d,
+                box_mode_3d=self.box_mode_3d)
+        else:
+            raise ValueError('Unsupported input points type: '
+                             f"{type(single_input['points'])}")
+
+        if 'points' in inputs:
+            return self.from_ndarray(inputs)
+        return self.from_file(inputs)
+
+
+@TRANSFORMS.register_module()
+class MonoDet3DInferencerLoader(BaseTransform):
     """Load an image from ``results['images']['CAMX']['img']``.

     Similar with :obj:`LoadImageFromFileMono3D`, but the image has been
     loaded as :obj:`np.ndarray` in ``results['images']['CAMX']['img']``.

-    Args:
-        to_float32 (bool): Whether to convert the loaded image to a float32
-            numpy array. If set to False, the loaded image is an uint8 array.
-            Defaults to False.
+    Added keys:
+
+    - img
+    - cam2img
+    - box_type_3d
+    - box_mode_3d
+
     """

     def __init__(self, **kwargs) -> None:
@@ -1029,6 +1099,8 @@ def transform(self, single_input: dict) -> dict:
             dict: The dict contains loaded image and meta information.
         """
         box_type_3d, box_mode_3d = get_box_type('camera')
+        assert 'calib' in single_input and 'img' in single_input, \
+            "key 'calib' and 'img' must be in input dict"
         if isinstance(single_input['calib'], str):
             calib_path = single_input['calib']
             with open(calib_path, 'r') as f:
@@ -1039,8 +1111,8 @@ def transform(self, single_input: dict) -> dict:
         elif isinstance(single_input['calib'], np.ndarray):
             cam2img = single_input['calib']
         else:
-            raise ValueError('Unsupported input type: '
-                             f'{type(single_input)}')
+            raise ValueError('Unsupported input calib type: '
+                             f"{type(single_input['calib'])}")

         if isinstance(single_input['img'], str):
             inputs = dict(
@@ -1056,8 +1128,8 @@ def transform(self, single_input: dict) -> dict:
                 box_type_3d=box_type_3d,
                 box_mode_3d=box_mode_3d)
         else:
-            raise ValueError('Unsupported input type: '
-                             f'{type(single_input)}')
+            raise ValueError('Unsupported input image type: '
+                             f"{type(single_input['img'])}")

         if 'img' in inputs:
             return self.from_ndarray(inputs)
diff --git a/mmdet3d/registry.py b/mmdet3d/registry.py
index 6807567fd8..4e568c53d6 100644
--- a/mmdet3d/registry.py
+++ b/mmdet3d/registry.py
@@ -10,6 +10,7 @@
 from mmengine.registry import DATASETS as MMENGINE_DATASETS
 from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
 from mmengine.registry import HOOKS as MMENGINE_HOOKS
+from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS
 from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
 from mmengine.registry import LOOPS as MMENGINE_LOOPS
 from mmengine.registry import METRICS as MMENGINE_METRICS
@@ -80,3 +81,6 @@
 # manage logprocessor
 LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
+
+# manage inferencer
+INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS)
diff --git a/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py b/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py
new file mode 100644
index 0000000000..5bcb2148b5
--- /dev/null
+++ b/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import mmengine
+import numpy as np
+import torch
+from mmengine.utils import is_list_of
+
+from mmdet3d.apis import LidarDet3DInferencer
+from mmdet3d.structures import Det3DDataSample
+
+
+class TestLidarDet3DInferencer(TestCase):
+
+    def setUp(self):
+        # init from alias
+        self.inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')
+
+    def test_init(self):
+        # init from metafile
+        LidarDet3DInferencer('pointpillars_waymod5-3class')
+        # init from cfg
+        LidarDet3DInferencer(
+            'configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py',  # noqa
+            weights=  # noqa
+            'https://download.openmmlab.com/mmdetection3d/v1.0.0_models/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class_20220301_150306-37dc2420.pth'  # noqa
+        )
+
+    def assert_predictions_equal(self, preds1, preds2):
+        for pred1, pred2 in zip(preds1, preds2):
+            if 'bboxes_3d' in pred1:
+                self.assertTrue(
+                    np.allclose(pred1['bboxes_3d'], pred2['bboxes_3d'], 0.1))
+            if 'scores_3d' in pred1:
+                self.assertTrue(
+                    np.allclose(pred1['scores_3d'], pred2['scores_3d'], 0.1))
+            if 'labels_3d' in pred1:
+                self.assertTrue(
+                    np.allclose(pred1['labels_3d'], pred2['labels_3d']))
+
+    def test_call(self):
+        if not torch.cuda.is_available():
+            return
+        # single point cloud file
+        inputs = dict(points='tests/data/kitti/training/velodyne/000000.bin')
+        res_path = self.inferencer(inputs, return_vis=True)
+        # ndarray
+        pts_bytes = mmengine.fileio.get(inputs['points'])
+        points = np.frombuffer(pts_bytes, dtype=np.float32)
+        points = points.reshape(-1, 4)
+        inputs = dict(points=points)
+        res_ndarray = self.inferencer(inputs, return_vis=True)
+        self.assert_predictions_equal(res_path['predictions'],
+                                      res_ndarray['predictions'])
+        self.assertIn('visualization', res_path)
+        self.assertIn('visualization', res_ndarray)
+
+        # multiple point cloud files
+        inputs = [
+            dict(points='tests/data/kitti/training/velodyne/000000.bin'),
+            dict(points='tests/data/kitti/training/velodyne/000000.bin')
+        ]
+        res_path = self.inferencer(inputs, return_vis=True)
+        # list of ndarray
+        all_points = []
+        for p in inputs:
+            pts_bytes = mmengine.fileio.get(p['points'])
+            points = np.frombuffer(pts_bytes, dtype=np.float32)
+            points = points.reshape(-1, 4)
+            all_points.append(dict(points=points))
+        res_ndarray = self.inferencer(all_points, return_vis=True)
+        self.assert_predictions_equal(res_path['predictions'],
+                                      res_ndarray['predictions'])
+        self.assertIn('visualization', res_path)
+        self.assertIn('visualization', res_ndarray)
+
+        # point cloud dir, test batch size larger than 1
+        pc_dir = dict(points='tests/data/kitti/training/velodyne/')
+        res_bs2 = self.inferencer(pc_dir, batch_size=2, return_vis=True)
+        self.assertIn('visualization', res_bs2)
+        self.assertIn('predictions', res_bs2)
+
+    def test_visualize(self):
+        if not torch.cuda.is_available():
+            return
+        inputs = dict(points='tests/data/kitti/training/velodyne/000000.bin')
+        # img_out_dir
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            self.inferencer(inputs, img_out_dir=tmp_dir)
+            # TODO: For LiDAR-based detection, the saved image only exists
+            # when show=True.
+            # self.assertTrue(osp.exists(osp.join(tmp_dir, '000000.png')))
+
+    def test_postprocess(self):
+        if not torch.cuda.is_available():
+            return
+        # return_datasamples
+        inputs = dict(points='tests/data/kitti/training/velodyne/000000.bin')
+        res = self.inferencer(inputs, return_datasamples=True)
+        self.assertTrue(is_list_of(res['predictions'], Det3DDataSample))
+
+        # pred_out_file
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pred_out_file = osp.join(tmp_dir, 'tmp.json')
+            res = self.inferencer(
+                inputs, print_result=True, pred_out_file=pred_out_file)
+            dumped_res = mmengine.load(pred_out_file)
+            self.assert_predictions_equal(res['predictions'],
+                                          dumped_res['predictions'])
diff --git a/tests/test_apis/test_inferencers/test_mono3d_det_inferencer.py b/tests/test_apis/test_inferencers/test_mono_det3d_inferencer.py
similarity index 100%
rename from tests/test_apis/test_inferencers/test_mono3d_det_inferencer.py
rename to tests/test_apis/test_inferencers/test_mono_det3d_inferencer.py
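
A minimal usage sketch for the new inferencer (the alias, the input format
and the keyword arguments below all come from this diff; the sample .bin
path is the repo's test data and the JSON output filename is an arbitrary
example):

    from mmdet3d.apis import LidarDet3DInferencer

    # The alias resolves through the metafile entry added above, so the
    # matching checkpoint is downloaded automatically.
    inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')

    # Inputs are a dict with the 'points' key: a .bin file path, a directory
    # of .bin files, or an (N, load_dim) float32 ndarray.
    results = inferencer(
        dict(points='tests/data/kitti/training/velodyne/000000.bin'),
        print_result=True,  # echo bboxes_3d/labels_3d/scores_3d to stdout
        pred_out_file='pointpillars_kitti_preds.json')  # dump the same dict

`results['predictions']` holds the dicts produced by `pred2dict` above;
passing `return_datasamples=True` returns `Det3DDataSample` objects instead.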