From 57a02189160501bb44649df1e242c55d80f2b975 Mon Sep 17 00:00:00 2001 From: xzq Date: Thu, 23 Feb 2023 15:39:48 +0800 Subject: [PATCH] update --- mmdet3d/models/decode_heads/cylinder3d_head.py | 6 +++--- .../test_decode_heads/test_cylinder3d_head.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 9284ffd87f..dd13fd4dd3 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -144,9 +144,9 @@ def predict( `gt_pts_seg`. We use `point2voxel_map` in this function. Returns: - torch.Tensor: Output point-wise segmentation logits. + List[torch.Tensor]: List of point-wise segmentation logits. """ - seg_logits = self.forward(inputs) + seg_logits = self.forward(inputs).features seg_pred_list = [] coors = batch_inputs_dict['voxels']['voxel_coors'] @@ -157,4 +157,4 @@ def predict( point_seg_predicts = seg_logits_sample[point2voxel_map] seg_pred_list.append(point_seg_predicts) - return seg_logits + return seg_pred_list diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index 1caa11d802..3bb62c5eef 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -59,8 +59,9 @@ def test_cylinder3d_head_loss(self): self.assertGreater(loss_ce, 0, 'ce loss should be positive') self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive') + batch_inputs_dict = dict(voxels=dict(voxel_coors=coors)) datasample.gt_pts_seg.point2voxel_map = torch.randint( - 0, 50, (100, 1)).int().cuda() - point_logits = cylinder3d_head.predict(sparse_voxels, coors, - datasample) - assert point_logits.shape == (100, 20) + 0, 50, (100, )).int().cuda() + point_logits = cylinder3d_head.predict(sparse_voxels, + batch_inputs_dict, [datasample]) + assert point_logits[0].shape == torch.Size([100, 20])