diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index 2e86c7c8a9..2a1f07a338 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .cylinder3d_head import Cylinder3DHead from .dgcnn_head import DGCNNHead from .paconv_head import PAConvHead from .pointnet2_head import PointNet2Head -__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead'] +__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead'] diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py new file mode 100644 index 0000000000..dd13fd4dd3 --- /dev/null +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from mmcv.ops import SparseConvTensor, SparseModule, SubMConv3d + +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils import OptMultiConfig +from mmdet3d.utils.typing_utils import ConfigType +from .decode_head import Base3DDecodeHead + + +@MODELS.register_module() +class Cylinder3DHead(Base3DDecodeHead): + """Cylinder3D decoder head. + + Decoder head used in `Cylinder3D `_. + Refer to the + `official code `_. + + Args: + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + dropout_ratio (float): Ratio of dropout layer. Defaults to 0. + conv_cfg (dict or :obj:`ConfigDict`): Config of conv layers. + Defaults to dict(type='Conv1d'). + norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers. + Defaults to dict(type='BN1d'). + act_cfg (dict or :obj:`ConfigDict`): Config of activation layers. + Defaults to dict(type='ReLU'). + loss_ce (dict or :obj:`ConfigDict`): Config of CrossEntropy loss. + Defaults to dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0). + loss_lovasz (dict or :obj:`ConfigDict`): Config of Lovasz loss. + Defaults to dict(type='LovaszLoss', loss_weight=1.0). + conv_seg_kernel_size (int): The kernel size used in conv_seg. + Defaults to 3. + ignore_index (int): The label index to be ignored. When using masked + BCE loss, ignore_index should be set to None. Defaults to 0. + init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], + optional): Initialization config dict. Defaults to None. + """ + + def __init__(self, + channels: int, + num_classes: int, + dropout_ratio: float = 0, + conv_cfg: ConfigType = dict(type='Conv1d'), + norm_cfg: ConfigType = dict(type='BN1d'), + act_cfg: ConfigType = dict(type='ReLU'), + loss_ce: ConfigType = dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + loss_lovasz: ConfigType = dict( + type='LovaszLoss', loss_weight=1.0), + conv_seg_kernel_size: int = 3, + ignore_index: int = 0, + init_cfg: OptMultiConfig = None) -> None: + super(Cylinder3DHead, self).__init__( + channels=channels, + num_classes=num_classes, + dropout_ratio=dropout_ratio, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + conv_seg_kernel_size=conv_seg_kernel_size, + init_cfg=init_cfg) + + self.loss_lovasz = MODELS.build(loss_lovasz) + self.loss_ce = MODELS.build(loss_ce) + self.ignore_index = ignore_index + + def build_conv_seg(self, channels: int, num_classes: int, + kernel_size: int) -> SparseModule: + return SubMConv3d( + channels, + num_classes, + indice_key='logit', + kernel_size=kernel_size, + stride=1, + padding=1, + bias=True) + + def forward(self, sparse_voxels: SparseConvTensor) -> SparseConvTensor: + """Forward function.""" + sparse_logits = self.cls_seg(sparse_voxels) + return sparse_logits + + def loss_by_feat(self, seg_logit: SparseConvTensor, + batch_data_samples: SampleList) -> dict: + """Compute semantic segmentation loss. + + Args: + seg_logit (SparseConvTensor): Predicted per-voxel + segmentation logits of shape [num_voxels, num_classes] + stored in SparseConvTensor. + batch_data_samples (List[:obj:`Det3DDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_pts_seg`. + + Returns: + Dict[str, Tensor]: A dictionary of loss components. + """ + + gt_semantic_segs = [ + data_sample.gt_pts_seg.voxel_semantic_mask + for data_sample in batch_data_samples + ] + seg_label = torch.cat(gt_semantic_segs) + seg_logit_feat = seg_logit.features + loss = dict() + loss['loss_ce'] = self.loss_ce( + seg_logit_feat, seg_label, ignore_index=self.ignore_index) + seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :, + None] # pseudo BCHW + loss['loss_lovasz'] = self.loss_lovasz( + seg_logit_feat, seg_label, ignore_index=self.ignore_index) + + return loss + + def predict( + self, + inputs: SparseConvTensor, + batch_inputs_dict: dict, + batch_data_samples: SampleList, + ) -> torch.Tensor: + """Forward function for testing. + + Args: + inputs (SparseConvTensor): Feature from backbone. + batch_inputs_dict (dict): Input sample dict which includes 'points' + and 'voxels' keys. + + - points (List[Tensor]): Point cloud of each sample. + - voxels (dict): Dict of voxelized voxels and the corresponding + coordinates. + batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data + samples. It usually includes information such as `metainfo` and + `gt_pts_seg`. We use `point2voxel_map` in this function. + + Returns: + List[torch.Tensor]: List of point-wise segmentation logits. + """ + seg_logits = self.forward(inputs).features + + seg_pred_list = [] + coors = batch_inputs_dict['voxels']['voxel_coors'] + for batch_idx in range(len(batch_data_samples)): + seg_logits_sample = seg_logits[coors[:, 0] == batch_idx] + point2voxel_map = batch_data_samples[ + batch_idx].gt_pts_seg.point2voxel_map.long() + point_seg_predicts = seg_logits_sample[point2voxel_map] + seg_pred_list.append(point_seg_predicts) + + return seg_pred_list diff --git a/mmdet3d/models/decode_heads/decode_head.py b/mmdet3d/models/decode_heads/decode_head.py index a9999e1f98..58688d8df5 100644 --- a/mmdet3d/models/decode_heads/decode_head.py +++ b/mmdet3d/models/decode_heads/decode_head.py @@ -51,6 +51,8 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): loss_decode (dict or :obj:`ConfigDict`): Config of decode loss. Defaults to dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False, class_weight=None, loss_weight=1.0). + conv_seg_kernel_size (int): The kernel size used in conv_seg. + Defaults to 1. ignore_index (int): The label index to be ignored. When using masked BCE loss, ignore_index should be set to None. Defaults to 255. init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], @@ -69,6 +71,7 @@ def __init__(self, use_sigmoid=False, class_weight=None, loss_weight=1.0), + conv_seg_kernel_size: int = 1, ignore_index: int = 255, init_cfg: OptMultiConfig = None) -> None: super(Base3DDecodeHead, self).__init__(init_cfg=init_cfg) @@ -81,7 +84,10 @@ def __init__(self, self.loss_decode = MODELS.build(loss_decode) self.ignore_index = ignore_index - self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1) + self.conv_seg = self.build_conv_seg( + channels=channels, + num_classes=num_classes, + kernel_size=conv_seg_kernel_size) if dropout_ratio > 0: self.dropout = nn.Dropout(dropout_ratio) else: @@ -97,6 +103,11 @@ def forward(self, feats_dict: dict) -> Tensor: """Placeholder of forward function.""" pass + def build_conv_seg(self, channels: int, num_classes: int, + kernel_size: int) -> nn.Module: + """Build Convolutional Segmentation Layers.""" + return nn.Conv1d(channels, num_classes, kernel_size=kernel_size) + def cls_seg(self, feat: Tensor) -> Tensor: """Classify each points.""" if self.dropout is not None: diff --git a/mmdet3d/models/losses/__init__.py b/mmdet3d/models/losses/__init__.py index 84f9cea6ca..6956c7219d 100644 --- a/mmdet3d/models/losses/__init__.py +++ b/mmdet3d/models/losses/__init__.py @@ -3,6 +3,7 @@ from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss from .chamfer_distance import ChamferDistance, chamfer_distance +from .lovasz_loss import LovaszLoss from .multibin_loss import MultiBinLoss from .paconv_regularization_loss import PAConvRegularizationLoss from .rotated_iou_loss import RotatedIoU3DLoss, rotated_iou_3d_loss @@ -12,5 +13,5 @@ 'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance', 'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss', 'PAConvRegularizationLoss', 'UncertainL1Loss', 'UncertainSmoothL1Loss', - 'MultiBinLoss', 'RotatedIoU3DLoss', 'rotated_iou_3d_loss' + 'MultiBinLoss', 'RotatedIoU3DLoss', 'rotated_iou_3d_loss', 'LovaszLoss' ] diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py new file mode 100644 index 0000000000..a9bcc270bd --- /dev/null +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -0,0 +1,356 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Directly borrowed from mmsegmentation. + +Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License) +""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models import weight_reduce_loss +from mmengine.utils import is_list_of + +from mmdet3d.registry import MODELS + + +def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + `The Lovasz-Softmax loss. `_. + + Args: + gt_sorted (torch.Tensor): Sorted ground truth. + + Return: + torch.Tensor: Gradient of the Lovasz extension. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits( + logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """Flatten predictions and labels in the batch (binary case). Remove + tensors whose labels equal to 'ignore_index'. + + Args: + probs (torch.Tensor): Predictions to be modified. + labels (torch.Tensor): Labels to be modified. + ignore_index (int, optional): The label index to be ignored. + Defaults to None. + + Return: + tuple(torch.Tensor, torch.Tensor): Modified predictions and labels. + """ + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs( + probs: torch.Tensor, + labels: torch.Tensor, + ignore_index: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """Flatten predictions and labels in the batch. Remove tensors whose labels + equal to 'ignore_index'. + + Args: + probs (torch.Tensor): Predictions to be modified. + labels (torch.Tensor): Labels to be modified. + ignore_index (int, optional): The label index to be ignored. + Defaults to None. + + Return: + tuple(torch.Tensor, torch.Tensor): Modified predictions and labels. + """ + if probs.dim() != 2: # for input with P*C + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, + C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits: torch.Tensor, + labels: torch.Tensor) -> torch.Tensor: + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): Logits at each prediction + (between -infty and +infty) with shape [P]. + labels (torch.Tensor): Binary ground truth labels (0 or 1) + with shape [P]. + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits: torch.Tensor, + labels: torch.Tensor, + classes: Optional[Union[str, List[int]]] = None, + per_sample: bool = False, + class_weight: Optional[List[float]] = None, + reduction: str = 'mean', + avg_factor: Optional[int] = None, + ignore_index: int = 255) -> torch.Tensor: + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): Logits at each pixel + (between -infty and +infty) with shape [B, H, W]. + labels (torch.Tensor): Binary ground truth masks (0 or 1) + with shape [B, H, W]. + classes (Union[str, list[int]], optional): Placeholder, to be + consistent with other loss. Defaults to None. + per_sample (bool): If per_sample is True, compute the loss per + sample instead of per batch. Defaults to False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Defaults to None. + reduction (str): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_sample is True. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_sample is True. + Defaults to None. + ignore_index (Union[int, None]): The label index to be ignored. + Defaults to 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_sample: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat( + probs: torch.Tensor, + labels: torch.Tensor, + classes: Union[str, List[int]] = 'present', + class_weight: Optional[List[float]] = None) -> torch.Tensor: + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): Class probabilities at each prediction + (between 0 and 1) with shape [P, C] + labels (torch.Tensor): Ground truth labels (between 0 and C - 1) + with shape [P]. + classes (Union[str, list[int]]): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Defaults to 'present'. + class_weight (list[float], optional): The weight for each class. + Defaults to None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs: torch.Tensor, + labels: torch.Tensor, + classes: Union[str, List[int]] = 'present', + per_sample: bool = False, + class_weight: List[float] = None, + reduction: str = 'mean', + avg_factor: Optional[int] = None, + ignore_index: int = 255) -> torch.Tensor: + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): Class probabilities at each + prediction (between 0 and 1) with shape [B, C, H, W]. + labels (torch.Tensor): Ground truth labels (between 0 and + C - 1) with shape [B, H, W]. + classes (Union[str, list[int]]): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Defaults to 'present'. + per_sample (bool): If per_sample is True, compute the loss per + sample instead of per batch. Defaults to False. + class_weight (list[float], optional): The weight for each class. + Defaults to None. + reduction (str): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_sample is True. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_sample is True. + Defaults to None. + ignore_index (Union[int, None]): The label index to be ignored. + Defaults to 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_sample: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@MODELS.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str): Binary or multi-class loss. + Defaults to 'multi_class'. Options are "binary" and "multi_class". + classes (Union[str, list[int]]): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Defaults to 'present'. + per_sample (bool): If per_sample is True, compute the loss per + sample instead of per batch. Defaults to False. + reduction (str): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_sample is True. Defaults to 'mean'. + class_weight ([list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + """ + + def __init__(self, + loss_type: str = 'multi_class', + classes: Union[str, List[int]] = 'present', + per_sample: bool = False, + reduction: str = 'mean', + class_weight: Optional[List[float]] = None, + loss_weight: float = 1.0): + super().__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or is_list_of(classes, int) + if not per_sample: + assert reduction == 'none', "reduction should be 'none' when \ + per_sample is False." + + self.classes = classes + self.per_sample = per_sample + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + + def forward(self, + cls_score: torch.Tensor, + label: torch.Tensor, + avg_factor: int = None, + reduction_override: str = None, + **kwargs) -> torch.Tensor: + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_sample, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py new file mode 100644 index 0000000000..3bb62c5eef --- /dev/null +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch +from mmcv.ops import SparseConvTensor + +from mmdet3d.models.decode_heads import Cylinder3DHead +from mmdet3d.structures import Det3DDataSample, PointData + + +class TestCylinder3DHead(TestCase): + + def test_cylinder3d_head_loss(self): + """Tests Cylinder3D head loss.""" + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + cylinder3d_head = Cylinder3DHead( + channels=128, + num_classes=20, + loss_ce=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + loss_lovasz=dict( + type='LovaszLoss', loss_weight=1.0, reduction='none'), + ).cuda() + + voxel_feats = torch.rand(50, 128).cuda() + coorx = torch.randint(0, 480, (50, 1)).int().cuda() + coory = torch.randint(0, 360, (50, 1)).int().cuda() + coorz = torch.randint(0, 32, (50, 1)).int().cuda() + coorbatch0 = torch.zeros(50, 1).int().cuda() + coors = torch.cat([coorbatch0, coorx, coory, coorz], dim=1) + grid_size = [480, 360, 32] + batch_size = 1 + + sparse_voxels = SparseConvTensor(voxel_feats, coors, grid_size, + batch_size) + # Test forward + seg_logits = cylinder3d_head.forward(sparse_voxels) + + self.assertEqual(seg_logits.features.shape, torch.Size([50, 20])) + + # When truth is non-empty then losses + # should be nonzero for random inputs + voxel_semantic_mask = torch.randint(0, 20, (50, )).long().cuda() + gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask) + + datasample = Det3DDataSample() + datasample.gt_pts_seg = gt_pts_seg + + losses = cylinder3d_head.loss_by_feat(seg_logits, [datasample]) + + loss_ce = losses['loss_ce'].item() + loss_lovasz = losses['loss_lovasz'].item() + + self.assertGreater(loss_ce, 0, 'ce loss should be positive') + self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive') + + batch_inputs_dict = dict(voxels=dict(voxel_coors=coors)) + datasample.gt_pts_seg.point2voxel_map = torch.randint( + 0, 50, (100, )).int().cuda() + point_logits = cylinder3d_head.predict(sparse_voxels, + batch_inputs_dict, [datasample]) + assert point_logits[0].shape == torch.Size([100, 20])