Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xizaoqu committed Feb 23, 2023
1 parent 952fa9e commit 57a0218
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions mmdet3d/models/decode_heads/cylinder3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def predict(
`gt_pts_seg`. We use `point2voxel_map` in this function.
Returns:
torch.Tensor: Output point-wise segmentation logits.
List[torch.Tensor]: List of point-wise segmentation logits.
"""
seg_logits = self.forward(inputs)
seg_logits = self.forward(inputs).features

seg_pred_list = []
coors = batch_inputs_dict['voxels']['voxel_coors']
Expand All @@ -157,4 +157,4 @@ def predict(
point_seg_predicts = seg_logits_sample[point2voxel_map]
seg_pred_list.append(point_seg_predicts)

return seg_logits
return seg_pred_list
9 changes: 5 additions & 4 deletions tests/test_models/test_decode_heads/test_cylinder3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def test_cylinder3d_head_loss(self):
self.assertGreater(loss_ce, 0, 'ce loss should be positive')
self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive')

batch_inputs_dict = dict(voxels=dict(voxel_coors=coors))
datasample.gt_pts_seg.point2voxel_map = torch.randint(
0, 50, (100, 1)).int().cuda()
point_logits = cylinder3d_head.predict(sparse_voxels, coors,
datasample)
assert point_logits.shape == (100, 20)
0, 50, (100, )).int().cuda()
point_logits = cylinder3d_head.predict(sparse_voxels,
batch_inputs_dict, [datasample])
assert point_logits[0].shape == torch.Size([100, 20])

0 comments on commit 57a0218

Please sign in to comment.