From 526b8f2e1c17d8efa13c4bd00ca83ad4832d7769 Mon Sep 17 00:00:00 2001 From: xzq Date: Fri, 17 Feb 2023 14:33:10 +0800 Subject: [PATCH 01/18] add cylinder decode head --- mmdet3d/models/decode_heads/__init__.py | 3 +- .../models/decode_heads/cylinder3d_head.py | 121 ++++++++++++++++++ mmdet3d/models/decode_heads/decode_head.py | 13 +- .../test_decode_heads/test_cylinder3d_head.py | 59 +++++++++ 4 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 mmdet3d/models/decode_heads/cylinder3d_head.py create mode 100644 tests/test_models/test_decode_heads/test_cylinder3d_head.py diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index 2e86c7c8a9..2a1f07a338 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .cylinder3d_head import Cylinder3DHead from .dgcnn_head import DGCNNHead from .paconv_head import PAConvHead from .pointnet2_head import PointNet2Head -__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead'] +__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead'] diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py new file mode 100644 index 0000000000..630c33927c --- /dev/null +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import spconv +import torch + +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils.typing_utils import ConfigType +from .decode_head import Base3DDecodeHead + + +@MODELS.register_module() +class Cylinder3DHead(Base3DDecodeHead): + """Cylinder3D decoder head. + + Decoder head used in `Cylinder3D `_. + Refer to the + `official code `_. + + Args: + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + dropout_ratio (float): Ratio of dropout layer. Defaults to 0.5. + conv_cfg (dict or :obj:`ConfigDict`): Config of conv layers. + Defaults to dict(type='Conv1d'). + norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers. + Defaults to dict(type='BN1d'). + act_cfg (dict or :obj:`ConfigDict`): Config of activation layers. + Defaults to dict(type='ReLU'). + loss_ce (dict or :obj:`ConfigDict`): Config of CrossEntropy loss. + Defaults to dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0). + loss_lovasz (dict or :obj:`ConfigDict`): Config of Lovasz loss. + Defaults to dict(type='mmseg.LovaszLoss', loss_weight=1.0). + conv_seg_kernel_size (int): The kernel size used in conv_seg. + Defaults to 1. + ignore_index (int): The label index to be ignored. When using masked + BCE loss, ignore_index should be set to None. Defaults to 255. + init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], + optional): Initialization config dict. Defaults to None. + """ + + def __init__(self, + channels: int, + num_classes: int, + dropout_ratio: float = 0, + conv_cfg: ConfigType = dict(type='Conv1d'), + norm_cfg: ConfigType = dict(type='BN1d'), + act_cfg: ConfigType = dict(type='ReLU'), + loss_ce: ConfigType = dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + loss_lovasz: ConfigType = dict( + type='mmseg.LovaszLoss', loss_weight=1.0), + conv_seg_kernel_size: int = 3, + ignore_index: int = 0, + init_cfg=None) -> None: + super(Cylinder3DHead, self).__init__( + channels=channels, + num_classes=num_classes, + dropout_ratio=dropout_ratio, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + conv_seg_kernel_size=conv_seg_kernel_size, + init_cfg=init_cfg) + + self.loss_lovasz = MODELS.build(loss_lovasz) + self.loss_ce = MODELS.build(loss_ce) + self.ignore_index = ignore_index + + def build_conv_seg(self, channels: int, num_classes: int, + kernel_size: int) -> spconv.SubMConv3d: + return spconv.SubMConv3d( + channels, + num_classes, + indice_key='logit', + kernel_size=kernel_size, + stride=1, + padding=1, + bias=True) + + def forward( + self, + sparse_voxels: spconv.SparseConvTensor) -> spconv.SparseConvTensor: + """Forward function.""" + sparse_logits = self.cls_seg(sparse_voxels) + return sparse_logits + + def loss_by_feat(self, seg_logit: spconv.SparseConvTensor, + batch_data_samples: SampleList) -> dict: + """Compute semantic segmentation loss. + + Args: + seg_logit (spconv.SparseConvTensor): Predicted per-voxel + segmentation logits of shape [num_voxels, num_classes] + stored in SparseConvTensor. + batch_data_samples (List[:obj:`Det3DDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_pts_seg`. + """ + + gt_semantic_segs = [ + data_sample.gt_pts_seg.voxel_semantic_mask + for data_sample in batch_data_samples + ] + seg_label = torch.cat(gt_semantic_segs) + seg_logit_feat = seg_logit.features + loss = dict() + loss['loss_ce'] = self.loss_ce( + seg_logit_feat, seg_label, ignore_index=self.ignore_index) + seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :, + None] # pseudo BCHW + loss['loss_lovasz'] = self.loss_lovasz( + seg_logit_feat, seg_label, ignore_index=self.ignore_index) + + return loss diff --git a/mmdet3d/models/decode_heads/decode_head.py b/mmdet3d/models/decode_heads/decode_head.py index a9999e1f98..5c8bbd672f 100644 --- a/mmdet3d/models/decode_heads/decode_head.py +++ b/mmdet3d/models/decode_heads/decode_head.py @@ -51,6 +51,8 @@ class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): loss_decode (dict or :obj:`ConfigDict`): Config of decode loss. Defaults to dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False, class_weight=None, loss_weight=1.0). + conv_seg_kernel_size (int): The kernel size used in conv_seg. + Defaults to 1. ignore_index (int): The label index to be ignored. When using masked BCE loss, ignore_index should be set to None. Defaults to 255. init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], @@ -69,6 +71,7 @@ def __init__(self, use_sigmoid=False, class_weight=None, loss_weight=1.0), + conv_seg_kernel_size: int = 1, ignore_index: int = 255, init_cfg: OptMultiConfig = None) -> None: super(Base3DDecodeHead, self).__init__(init_cfg=init_cfg) @@ -81,7 +84,10 @@ def __init__(self, self.loss_decode = MODELS.build(loss_decode) self.ignore_index = ignore_index - self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1) + self.conv_seg = self.build_conv_seg( + channels=channels, + num_classes=num_classes, + kernel_size=conv_seg_kernel_size) if dropout_ratio > 0: self.dropout = nn.Dropout(dropout_ratio) else: @@ -97,6 +103,11 @@ def forward(self, feats_dict: dict) -> Tensor: """Placeholder of forward function.""" pass + def build_conv_seg(self, channels: int, num_classes: int, + kernel_size: int) -> nn.Conv1d: + """Build Convolutional Segmentation Layers.""" + return nn.Conv1d(channels, num_classes, kernel_size=kernel_size) + def cls_seg(self, feat: Tensor) -> Tensor: """Classify each points.""" if self.dropout is not None: diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py new file mode 100644 index 0000000000..b8107213ac --- /dev/null +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import spconv +import torch + +from mmdet3d.models.decode_heads import Cylinder3DHead +from mmdet3d.structures import Det3DDataSample, PointData + + +class TestCylinder3DHead(TestCase): + + def test_cylinder3d_head_loss(self): + """Tests DGCNN head loss.""" + + cylinder3d_head = Cylinder3DHead( + channels=128, + num_classes=20, + loss_ce=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + loss_lovasz=dict( + type='mmseg.LovaszLoss', loss_weight=1.0, reduction='none'), + ).cuda() + + # DGCNN head expects dict format features + voxel_feats = torch.rand(50, 128).cuda() + coorx = torch.randint(0, 480, (50, 1)).int().cuda() + coory = torch.randint(0, 360, (50, 1)).int().cuda() + coorz = torch.randint(0, 32, (50, 1)).int().cuda() + coorbatch0 = torch.zeros(50, 1).int().cuda() + coors = torch.cat([coorbatch0, coorx, coory, coorz], dim=1) + grid_size = [480, 360, 32] + batch_size = 1 + + sparse_voxels = spconv.SparseConvTensor(voxel_feats, coors, grid_size, + batch_size) + # Test forward + seg_logits = cylinder3d_head.forward(sparse_voxels) + + self.assertEqual(seg_logits.features.shape, torch.Size([50, 20])) + + # When truth is non-empty then losses + # should be nonzero for random inputs + voxel_semantic_mask = torch.randint(0, 20, (50, )).long().cuda() + gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask) + + datasample = Det3DDataSample() + datasample.gt_pts_seg = gt_pts_seg + + losses = cylinder3d_head.loss_by_feat(seg_logits, [datasample]) + + loss_ce = losses['loss_ce'].item() + loss_lovasz = losses['loss_lovasz'].item() + + self.assertGreater(loss_ce, 0, 'ce loss should be positive') + self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive') From 9506d155de84cfee424ddb7f71fd24b94d569bc0 Mon Sep 17 00:00:00 2001 From: xzq Date: Fri, 17 Feb 2023 19:13:12 +0800 Subject: [PATCH 02/18] update --- mmdet3d/models/decode_heads/cylinder3d_head.py | 12 +++++------- .../test_decode_heads/test_cylinder3d_head.py | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 630c33927c..a035311052 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -import spconv import torch +from mmcv.ops import SparseConvTensor, SubMConv3d from mmdet3d.registry import MODELS from mmdet3d.structures.det3d_data_sample import SampleList @@ -74,8 +74,8 @@ def __init__(self, self.ignore_index = ignore_index def build_conv_seg(self, channels: int, num_classes: int, - kernel_size: int) -> spconv.SubMConv3d: - return spconv.SubMConv3d( + kernel_size: int) -> SubMConv3d: + return SubMConv3d( channels, num_classes, indice_key='logit', @@ -84,14 +84,12 @@ def build_conv_seg(self, channels: int, num_classes: int, padding=1, bias=True) - def forward( - self, - sparse_voxels: spconv.SparseConvTensor) -> spconv.SparseConvTensor: + def forward(self, sparse_voxels: SparseConvTensor) -> SparseConvTensor: """Forward function.""" sparse_logits = self.cls_seg(sparse_voxels) return sparse_logits - def loss_by_feat(self, seg_logit: spconv.SparseConvTensor, + def loss_by_feat(self, seg_logit: SparseConvTensor, batch_data_samples: SampleList) -> dict: """Compute semantic segmentation loss. diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index b8107213ac..23993fcfff 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase -import spconv import torch +from mmcv.ops import SparseConvTensor from mmdet3d.models.decode_heads import Cylinder3DHead from mmdet3d.structures import Det3DDataSample, PointData @@ -35,8 +35,8 @@ def test_cylinder3d_head_loss(self): grid_size = [480, 360, 32] batch_size = 1 - sparse_voxels = spconv.SparseConvTensor(voxel_feats, coors, grid_size, - batch_size) + sparse_voxels = SparseConvTensor(voxel_feats, coors, grid_size, + batch_size) # Test forward seg_logits = cylinder3d_head.forward(sparse_voxels) From 8fbac9b777d85d8d86acf0cb502f64c9b6d6b904 Mon Sep 17 00:00:00 2001 From: xzq Date: Sun, 19 Feb 2023 16:10:23 +0800 Subject: [PATCH 03/18] update --- mmdet3d/models/losses/__init__.py | 3 ++- tests/test_models/test_decode_heads/test_cylinder3d_head.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mmdet3d/models/losses/__init__.py b/mmdet3d/models/losses/__init__.py index 84f9cea6ca..6956c7219d 100644 --- a/mmdet3d/models/losses/__init__.py +++ b/mmdet3d/models/losses/__init__.py @@ -3,6 +3,7 @@ from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss from .chamfer_distance import ChamferDistance, chamfer_distance +from .lovasz_loss import LovaszLoss from .multibin_loss import MultiBinLoss from .paconv_regularization_loss import PAConvRegularizationLoss from .rotated_iou_loss import RotatedIoU3DLoss, rotated_iou_3d_loss @@ -12,5 +13,5 @@ 'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance', 'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss', 'PAConvRegularizationLoss', 'UncertainL1Loss', 'UncertainSmoothL1Loss', - 'MultiBinLoss', 'RotatedIoU3DLoss', 'rotated_iou_3d_loss' + 'MultiBinLoss', 'RotatedIoU3DLoss', 'rotated_iou_3d_loss', 'LovaszLoss' ] diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index 23993fcfff..ab339a5889 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -22,7 +22,7 @@ def test_cylinder3d_head_loss(self): class_weight=None, loss_weight=1.0), loss_lovasz=dict( - type='mmseg.LovaszLoss', loss_weight=1.0, reduction='none'), + type='LovaszLoss', loss_weight=1.0, reduction='none'), ).cuda() # DGCNN head expects dict format features From bb65b62a261257a9263799327573b5897a6f782c Mon Sep 17 00:00:00 2001 From: xzq Date: Sun, 19 Feb 2023 17:03:48 +0800 Subject: [PATCH 04/18] add lovasz loss --- mmdet3d/models/losses/lovasz_loss.py | 323 +++++++++++++++++++++ mmdet3d/models/losses/lovasz_loss_utils.py | 127 ++++++++ 2 files changed, 450 insertions(+) create mode 100644 mmdet3d/models/losses/lovasz_loss.py create mode 100644 mmdet3d/models/losses/lovasz_loss_utils.py diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py new file mode 100644 index 0000000000..3fb9fe9797 --- /dev/null +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import is_list_of + +from mmdet3d.registry import MODELS +from .lovasz_loss_utils import get_class_weight, weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). + labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@MODELS.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_lovasz'. + """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_lovasz'): + super().__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmdet3d/models/losses/lovasz_loss_utils.py b/mmdet3d/models/losses/lovasz_loss_utils.py new file mode 100644 index 0000000000..250eccb9f4 --- /dev/null +++ b/mmdet3d/models/losses/lovasz_loss_utils.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Directly borrowed from mmsegmentation.""" +import functools + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.fileio import load + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper From 288c8b47eca21cdbe3ed38f9643c6ef70962a65f Mon Sep 17 00:00:00 2001 From: xzq Date: Mon, 20 Feb 2023 10:52:41 +0800 Subject: [PATCH 05/18] update --- mmdet3d/models/losses/lovasz_loss.py | 1 + tests/test_models/test_decode_heads/test_cylinder3d_head.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index 3fb9fe9797..d5034d9f10 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -317,6 +317,7 @@ def loss_name(self): by simple sum operation. In addition, if you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name. + Returns: str: The name of this loss item. """ diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index ab339a5889..9a35fd664a 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase +import pytest import torch from mmcv.ops import SparseConvTensor @@ -11,8 +12,9 @@ class TestCylinder3DHead(TestCase): def test_cylinder3d_head_loss(self): - """Tests DGCNN head loss.""" - + """Tests Cylinder3D head loss.""" + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') cylinder3d_head = Cylinder3DHead( channels=128, num_classes=20, From f7499b174f6f64aca7d4b10e3143ed9aeac576f5 Mon Sep 17 00:00:00 2001 From: xzq Date: Mon, 20 Feb 2023 10:53:04 +0800 Subject: [PATCH 06/18] update --- mmdet3d/models/losses/lovasz_loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index d5034d9f10..8385d71301 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +"""Directly borrowed from mmsegmentation. + +Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim -Berman 2018 ESAT-PSI KU Leuven (MIT License)""" +Berman 2018 ESAT-PSI KU Leuven (MIT License) +""" import torch import torch.nn as nn From e24688f429fbd7fdc5c1827dfb87bbc94465a072 Mon Sep 17 00:00:00 2001 From: xzq Date: Mon, 20 Feb 2023 20:13:15 +0800 Subject: [PATCH 07/18] update --- .../models/decode_heads/cylinder3d_head.py | 15 ++- mmdet3d/models/losses/lovasz_loss.py | 116 ++++++++++++------ mmdet3d/models/losses/lovasz_loss_utils.py | 23 ++-- .../test_decode_heads/test_cylinder3d_head.py | 1 - 4 files changed, 101 insertions(+), 54 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index a035311052..0a1c16ba3a 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -19,7 +19,7 @@ class Cylinder3DHead(Base3DDecodeHead): Args: channels (int): Channels after modules, before conv_seg. num_classes (int): Number of classes. - dropout_ratio (float): Ratio of dropout layer. Defaults to 0.5. + dropout_ratio (float): Ratio of dropout layer. Defaults to 0. conv_cfg (dict or :obj:`ConfigDict`): Config of conv layers. Defaults to dict(type='Conv1d'). norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers. @@ -33,11 +33,11 @@ class Cylinder3DHead(Base3DDecodeHead): class_weight=None, loss_weight=1.0). loss_lovasz (dict or :obj:`ConfigDict`): Config of Lovasz loss. - Defaults to dict(type='mmseg.LovaszLoss', loss_weight=1.0). + Defaults to dict(type='LovaszLoss', loss_weight=1.0). conv_seg_kernel_size (int): The kernel size used in conv_seg. - Defaults to 1. + Defaults to 3. ignore_index (int): The label index to be ignored. When using masked - BCE loss, ignore_index should be set to None. Defaults to 255. + BCE loss, ignore_index should be set to None. Defaults to 0. init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], optional): Initialization config dict. Defaults to None. """ @@ -55,7 +55,7 @@ def __init__(self, class_weight=None, loss_weight=1.0), loss_lovasz: ConfigType = dict( - type='mmseg.LovaszLoss', loss_weight=1.0), + type='LovaszLoss', loss_weight=1.0), conv_seg_kernel_size: int = 3, ignore_index: int = 0, init_cfg=None) -> None: @@ -74,7 +74,7 @@ def __init__(self, self.ignore_index = ignore_index def build_conv_seg(self, channels: int, num_classes: int, - kernel_size: int) -> SubMConv3d: + kernel_size: int) -> SparseConvTensor: return SubMConv3d( channels, num_classes, @@ -100,6 +100,9 @@ def loss_by_feat(self, seg_logit: SparseConvTensor, batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data samples. It usually includes information such as `metainfo` and `gt_pts_seg`. + + Returns: + Dict[str, Tensor]: A dictionary of loss components. """ gt_semantic_segs = [ diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index 8385d71301..dd060e5aef 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -6,6 +6,8 @@ Berman 2018 ESAT-PSI KU Leuven (MIT License) """ +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F @@ -15,10 +17,16 @@ from .lovasz_loss_utils import get_class_weight, weight_reduce_loss -def lovasz_grad(gt_sorted): +def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: """Computes gradient of the Lovasz extension w.r.t sorted errors. See Alg. 1 in paper. + + Args: + gt_sorted (torch.Tensor): Sorted ground truth. + + Return: + torch.Tensor: Gradient of the Lovasz extension. """ p = len(gt_sorted) gts = gt_sorted.sum() @@ -30,9 +38,22 @@ def lovasz_grad(gt_sorted): return jaccard -def flatten_binary_logits(logits, labels, ignore_index=None): - """Flattens predictions in the batch (binary case) Remove labels equal to - 'ignore_index'.""" +def flatten_binary_logits( + logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = None) -> tuple(torch.Tensor, torch.Tensor): + """Flattens predictions and labels in the batch (binary case). Remove + tensors whose labels equal to 'ignore_index'. + + Args: + probs (torch.Tensor): Predictions to be modified. + labels (torch.Tensor): Labels to be modified. + ignore_index (int | None): The label index to be ignored. + Defaults to None. + + Return: + tuple(torch.Tensor, torch.Tensor): Modified predictions and labels. + """ logits = logits.view(-1) labels = labels.view(-1) if ignore_index is None: @@ -43,8 +64,22 @@ def flatten_binary_logits(logits, labels, ignore_index=None): return vlogits, vlabels -def flatten_probs(probs, labels, ignore_index=None): - """Flattens predictions in the batch.""" +def flatten_probs( + probs: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = None) -> tuple(torch.Tensor, torch.Tensor): + """Flattens predictions and labels in the batch. Remove tensors whose + labels equal to 'ignore_index'. + + Args: + probs (torch.Tensor): Predictions to be modified. + labels (torch.Tensor): Labels to be modified. + ignore_index (int | None): The label index to be ignored. + Defaults to None. + + Return: + tuple(torch.Tensor, torch.Tensor): Modified predictions and labels. + """ if probs.dim() == 3: # assumes output of a sigmoid layer B, H, W = probs.size() @@ -60,7 +95,8 @@ def flatten_probs(probs, labels, ignore_index=None): return vprobs, vlabels -def lovasz_hinge_flat(logits, labels): +def lovasz_hinge_flat(logits: torch.Tensor, + labels: torch.Tensor) -> torch.Tensor: """Binary Lovasz hinge loss. Args: @@ -84,14 +120,14 @@ def lovasz_hinge_flat(logits, labels): return loss -def lovasz_hinge(logits, - labels, - classes='present', - per_image=False, - class_weight=None, - reduction='mean', - avg_factor=None, - ignore_index=255): +def lovasz_hinge(logits: torch.Tensor, + labels: torch.Tensor, + classes: str or List[int] = 'present', + per_image: bool = False, + class_weight: List[float] = None, + reduction: str = 'mean', + avg_factor: int = None, + ignore_index: int = 255) -> torch.Tensor: """Binary Lovasz hinge loss. Args: @@ -129,14 +165,17 @@ def lovasz_hinge(logits, return loss -def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): +def lovasz_softmax_flat(probs: torch.Tensor, + labels: torch.Tensor, + classes: str or List[int] = 'present', + class_weight: List[float] = None) -> torch.Tensor: """Multi-class Lovasz-Softmax loss. Args: probs (torch.Tensor): [P, C], class probabilities at each prediction (between 0 and 1). labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). - classes (str | list[int], optional): Classes chosen to calculate loss. + classes (str | list[int]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Default: 'present'. class_weight (list[float], optional): The weight for each class. @@ -172,14 +211,14 @@ def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): return torch.stack(losses).mean() -def lovasz_softmax(probs, - labels, - classes='present', - per_image=False, - class_weight=None, - reduction='mean', - avg_factor=None, - ignore_index=255): +def lovasz_softmax(probs: torch.Tensor, + labels: torch.Tensor, + classes: str or List[int] = 'present', + per_image: bool = False, + class_weight: List[float] = None, + reduction: str = 'mean', + avg_factor: int = None, + ignore_index: int = 255) -> torch.Tensor: """Multi-class Lovasz-Softmax loss. Args: @@ -253,13 +292,13 @@ class LovaszLoss(nn.Module): """ def __init__(self, - loss_type='multi_class', - classes='present', - per_image=False, - reduction='mean', - class_weight=None, - loss_weight=1.0, - loss_name='loss_lovasz'): + loss_type: str = 'multi_class', + classes: str or List[int] = 'present', + per_image: bool = False, + reduction: str = 'mean', + class_weight: List[float] or str = None, + loss_weight: float = 1.0, + loss_name: str = 'loss_lovasz'): super().__init__() assert loss_type in ('binary', 'multi_class'), "loss_type should be \ 'binary' or 'multi_class'." @@ -281,12 +320,11 @@ def __init__(self, self._loss_name = loss_name def forward(self, - cls_score, - label, - weight=None, - avg_factor=None, - reduction_override=None, - **kwargs): + cls_score: torch.Tensor, + label: torch.Tensor, + avg_factor: int = None, + reduction_override: str = None, + **kwargs) -> torch.Tensor: """Forward function.""" assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( @@ -312,7 +350,7 @@ def forward(self, return loss_cls @property - def loss_name(self): + def loss_name(self) -> str: """Loss Name. This function must be implemented and will return the name of this diff --git a/mmdet3d/models/losses/lovasz_loss_utils.py b/mmdet3d/models/losses/lovasz_loss_utils.py index 250eccb9f4..1cd89244a9 100644 --- a/mmdet3d/models/losses/lovasz_loss_utils.py +++ b/mmdet3d/models/losses/lovasz_loss_utils.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. """Directly borrowed from mmsegmentation.""" import functools +from typing import List import numpy as np import torch @@ -8,12 +9,15 @@ from mmengine.fileio import load -def get_class_weight(class_weight): +def get_class_weight(class_weight: List[float] or str) -> List[float]: """Get class weight for loss function. Args: class_weight (list[float] | str | None): If class_weight is a str, take it as a file name and read from it. + + Return: + list[float]: Loaded class_weight. """ if isinstance(class_weight, str): # take it as a file path @@ -26,15 +30,15 @@ def get_class_weight(class_weight): return class_weight -def reduce_loss(loss, reduction): +def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor: """Reduce loss as specified. Args: - loss (Tensor): Elementwise loss tensor. + loss (torch.Tensor): Elementwise loss tensor. reduction (str): Options are "none", "mean" and "sum". Return: - Tensor: Reduced loss tensor. + torch.Tensor: Reduced loss tensor. """ reduction_enum = F._Reduction.get_enum(reduction) # none: 0, elementwise_mean:1, sum: 2 @@ -46,17 +50,20 @@ def reduce_loss(loss, reduction): return loss.sum() -def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): +def weight_reduce_loss(loss: torch.Tensor, + weight: torch.Tensor = None, + reduction: str = 'mean', + avg_factor: float = None) -> torch.Tensor: """Apply element-wise weight and reduce loss. Args: - loss (Tensor): Element-wise loss. - weight (Tensor): Element-wise weights. + loss (torch.Tensor): Element-wise loss. + weight (torch.Tensor): Element-wise weights. reduction (str): Same as built-in losses of PyTorch. avg_factor (float): Average factor when computing the mean of losses. Returns: - Tensor: Processed loss values. + torch.Tensor: Processed loss values. """ # if weight is specified, apply element-wise weight if weight is not None: diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index 9a35fd664a..2bad4dc9ef 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -27,7 +27,6 @@ def test_cylinder3d_head_loss(self): type='LovaszLoss', loss_weight=1.0, reduction='none'), ).cuda() - # DGCNN head expects dict format features voxel_feats = torch.rand(50, 128).cuda() coorx = torch.randint(0, 480, (50, 1)).int().cuda() coory = torch.randint(0, 360, (50, 1)).int().cuda() From 73bffd1daa50cf10bf96aa2c95fb3727650d7649 Mon Sep 17 00:00:00 2001 From: xzq Date: Mon, 20 Feb 2023 20:14:06 +0800 Subject: [PATCH 08/18] update --- mmdet3d/models/losses/lovasz_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index dd060e5aef..2ad00e9a96 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -6,7 +6,7 @@ Berman 2018 ESAT-PSI KU Leuven (MIT License) """ -from typing import List +from typing import List, Tuple import torch import torch.nn as nn @@ -41,7 +41,7 @@ def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: def flatten_binary_logits( logits: torch.Tensor, labels: torch.Tensor, - ignore_index: int = None) -> tuple(torch.Tensor, torch.Tensor): + ignore_index: int = None) -> Tuple(torch.Tensor, torch.Tensor): """Flattens predictions and labels in the batch (binary case). Remove tensors whose labels equal to 'ignore_index'. @@ -67,7 +67,7 @@ def flatten_binary_logits( def flatten_probs( probs: torch.Tensor, labels: torch.Tensor, - ignore_index: int = None) -> tuple(torch.Tensor, torch.Tensor): + ignore_index: int = None) -> Tuple(torch.Tensor, torch.Tensor): """Flattens predictions and labels in the batch. Remove tensors whose labels equal to 'ignore_index'. From 0d2df754c576613129714562144bfafed41b5ee4 Mon Sep 17 00:00:00 2001 From: xzq Date: Mon, 20 Feb 2023 20:15:24 +0800 Subject: [PATCH 09/18] update --- mmdet3d/models/losses/lovasz_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index 2ad00e9a96..ef797723c6 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -41,7 +41,7 @@ def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: def flatten_binary_logits( logits: torch.Tensor, labels: torch.Tensor, - ignore_index: int = None) -> Tuple(torch.Tensor, torch.Tensor): + ignore_index: int = None) -> Tuple[torch.Tensor, torch.Tensor]: """Flattens predictions and labels in the batch (binary case). Remove tensors whose labels equal to 'ignore_index'. @@ -67,7 +67,7 @@ def flatten_binary_logits( def flatten_probs( probs: torch.Tensor, labels: torch.Tensor, - ignore_index: int = None) -> Tuple(torch.Tensor, torch.Tensor): + ignore_index: int = None) -> Tuple[torch.Tensor, torch.Tensor]: """Flattens predictions and labels in the batch. Remove tensors whose labels equal to 'ignore_index'. From 961efd3eaab4d1ed878144526f5394118eb628cc Mon Sep 17 00:00:00 2001 From: xzq Date: Tue, 21 Feb 2023 16:49:54 +0800 Subject: [PATCH 10/18] update --- .../models/decode_heads/cylinder3d_head.py | 5 +- mmdet3d/models/losses/lovasz_loss.py | 103 +++++++++--------- mmdet3d/models/losses/lovasz_loss_utils.py | 8 +- 3 files changed, 62 insertions(+), 54 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 0a1c16ba3a..80afcd4061 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch from mmcv.ops import SparseConvTensor, SubMConv3d from mmdet3d.registry import MODELS from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils import OptConfigType from mmdet3d.utils.typing_utils import ConfigType from .decode_head import Base3DDecodeHead @@ -58,7 +61,7 @@ def __init__(self, type='LovaszLoss', loss_weight=1.0), conv_seg_kernel_size: int = 3, ignore_index: int = 0, - init_cfg=None) -> None: + init_cfg: Optional[dict or OptConfigType] = None) -> None: super(Cylinder3DHead, self).__init__( channels=channels, num_classes=num_classes, diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index ef797723c6..1c854b2406 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -6,7 +6,7 @@ Berman 2018 ESAT-PSI KU Leuven (MIT License) """ -from typing import List, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -41,14 +41,15 @@ def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: def flatten_binary_logits( logits: torch.Tensor, labels: torch.Tensor, - ignore_index: int = None) -> Tuple[torch.Tensor, torch.Tensor]: + ignore_index: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: """Flattens predictions and labels in the batch (binary case). Remove tensors whose labels equal to 'ignore_index'. Args: probs (torch.Tensor): Predictions to be modified. labels (torch.Tensor): Labels to be modified. - ignore_index (int | None): The label index to be ignored. + ignore_index (int, optional): The label index to be ignored. Defaults to None. Return: @@ -67,14 +68,15 @@ def flatten_binary_logits( def flatten_probs( probs: torch.Tensor, labels: torch.Tensor, - ignore_index: int = None) -> Tuple[torch.Tensor, torch.Tensor]: + ignore_index: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor]: """Flattens predictions and labels in the batch. Remove tensors whose labels equal to 'ignore_index'. Args: probs (torch.Tensor): Predictions to be modified. labels (torch.Tensor): Labels to be modified. - ignore_index (int | None): The label index to be ignored. + ignore_index (int, optional): The label index to be ignored. Defaults to None. Return: @@ -122,11 +124,11 @@ def lovasz_hinge_flat(logits: torch.Tensor, def lovasz_hinge(logits: torch.Tensor, labels: torch.Tensor, - classes: str or List[int] = 'present', + classes: Optional[Union[str, List[int]]] = None, per_image: bool = False, - class_weight: List[float] = None, + class_weight: Optional[List[float]] = None, reduction: str = 'mean', - avg_factor: int = None, + avg_factor: Optional[int] = None, ignore_index: int = 255) -> torch.Tensor: """Binary Lovasz hinge loss. @@ -134,19 +136,20 @@ def lovasz_hinge(logits: torch.Tensor, logits (torch.Tensor): [B, H, W], logits at each pixel (between -infty and +infty). labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). - classes (str | list[int], optional): Placeholder, to be consistent with - other loss. Default: None. - per_image (bool, optional): If per_image is True, compute the loss per - image instead of per batch. Default: False. + classes (Union[str, list[int]], optional): Placeholder, to be + consistent with other loss. Defaults to None. + per_image (bool): If per_image is True, compute the loss per + image instead of per batch. Defaults to False. class_weight (list[float], optional): Placeholder, to be consistent - with other loss. Default: None. - reduction (str, optional): The method used to reduce the loss. Options + with other loss. Defaults to None. + reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when - per_image is True. Default: 'mean'. + per_image is True. Defaults to 'mean'. avg_factor (int, optional): Average factor that is used to average the loss. This parameter only works when per_image is True. - Default: None. - ignore_index (int | None): The label index to be ignored. Default: 255. + Defaults to None. + ignore_index (int, optional): The label index to be ignored. + Defaults to 255. Returns: torch.Tensor: The calculated loss. @@ -165,21 +168,22 @@ def lovasz_hinge(logits: torch.Tensor, return loss -def lovasz_softmax_flat(probs: torch.Tensor, - labels: torch.Tensor, - classes: str or List[int] = 'present', - class_weight: List[float] = None) -> torch.Tensor: +def lovasz_softmax_flat( + probs: torch.Tensor, + labels: torch.Tensor, + classes: Union[str, List[int]] = 'present', + class_weight: Optional[List[float]] = None) -> torch.Tensor: """Multi-class Lovasz-Softmax loss. Args: probs (torch.Tensor): [P, C], class probabilities at each prediction (between 0 and 1). labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). - classes (str | list[int]): Classes chosen to calculate loss. + classes (Union[str, list[int]]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or - a list of classes to average. Default: 'present'. + a list of classes to average. Defaults to 'present'. class_weight (list[float], optional): The weight for each class. - Default: None. + Defaults to None. Returns: torch.Tensor: The calculated loss. @@ -213,11 +217,11 @@ def lovasz_softmax_flat(probs: torch.Tensor, def lovasz_softmax(probs: torch.Tensor, labels: torch.Tensor, - classes: str or List[int] = 'present', + classes: Optional[Union[str, List[int]]] = 'present', per_image: bool = False, class_weight: List[float] = None, reduction: str = 'mean', - avg_factor: int = None, + avg_factor: Optional[int] = None, ignore_index: int = 255) -> torch.Tensor: """Multi-class Lovasz-Softmax loss. @@ -226,20 +230,21 @@ def lovasz_softmax(probs: torch.Tensor, prediction (between 0 and 1). labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and C - 1). - classes (str | list[int], optional): Classes chosen to calculate loss. + classes (Union[str, list[int]]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or - a list of classes to average. Default: 'present'. - per_image (bool, optional): If per_image is True, compute the loss per - image instead of per batch. Default: False. + a list of classes to average. Defaults to 'present'. + per_image (bool): If per_image is True, compute the loss per + image instead of per batch. Defaults to False. class_weight (list[float], optional): The weight for each class. - Default: None. - reduction (str, optional): The method used to reduce the loss. Options + Defaults to None. + reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when - per_image is True. Default: 'mean'. + per_image is True. Defaults to 'mean'. avg_factor (int, optional): Average factor that is used to average the loss. This parameter only works when per_image is True. - Default: None. - ignore_index (int | None): The label index to be ignored. Default: 255. + Defaults to None. + ignore_index (Union[int, None]): The label index to be ignored. + Defaults to 255. Returns: torch.Tensor: The calculated loss. @@ -273,30 +278,30 @@ class LovaszLoss(nn.Module): networks `_. Args: - loss_type (str, optional): Binary or multi-class loss. - Default: 'multi_class'. Options are "binary" and "multi_class". - classes (str | list[int], optional): Classes chosen to calculate loss. + loss_type (str): Binary or multi-class loss. + Defaults to 'multi_class'. Options are "binary" and "multi_class". + classes (Union[str, list[int]]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or - a list of classes to average. Default: 'present'. - per_image (bool, optional): If per_image is True, compute the loss per - image instead of per batch. Default: False. - reduction (str, optional): The method used to reduce the loss. Options + a list of classes to average. Defaults to 'present'. + per_image (bool): If per_image is True, compute the loss per + image instead of per batch. Defaults to False. + reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when - per_image is True. Default: 'mean'. - class_weight (list[float] | str, optional): Weight of each class. If in - str format, read them from a file. Defaults to None. - loss_weight (float, optional): Weight of the loss. Defaults to 1.0. - loss_name (str, optional): Name of the loss item. If you want this loss + per_image is True. Defaults to 'mean'. + class_weight (Union[list[float], str], optional): Weight of each class. + If in str format, read them from a file. Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + loss_name (str): Name of the loss item. If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name. Defaults to 'loss_lovasz'. """ def __init__(self, loss_type: str = 'multi_class', - classes: str or List[int] = 'present', + classes: Union[str, List[int]] = 'present', per_image: bool = False, reduction: str = 'mean', - class_weight: List[float] or str = None, + class_weight: Optional[Union[List[float], str]] = None, loss_weight: float = 1.0, loss_name: str = 'loss_lovasz'): super().__init__() diff --git a/mmdet3d/models/losses/lovasz_loss_utils.py b/mmdet3d/models/losses/lovasz_loss_utils.py index 1cd89244a9..1c88d63079 100644 --- a/mmdet3d/models/losses/lovasz_loss_utils.py +++ b/mmdet3d/models/losses/lovasz_loss_utils.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. """Directly borrowed from mmsegmentation.""" import functools -from typing import List +from typing import List, Union import numpy as np import torch @@ -9,12 +9,12 @@ from mmengine.fileio import load -def get_class_weight(class_weight: List[float] or str) -> List[float]: +def get_class_weight(class_weight: Union[List[float], str]) -> List[float]: """Get class weight for loss function. Args: - class_weight (list[float] | str | None): If class_weight is a str, - take it as a file name and read from it. + class_weight (Union[list[float], str], optional): If class_weight + is a str, take it as a file name and read from it. Return: list[float]: Loaded class_weight. From 8931d4e4bcfe2103ceb0273b1373543ea1649de0 Mon Sep 17 00:00:00 2001 From: xzq Date: Tue, 21 Feb 2023 17:10:11 +0800 Subject: [PATCH 11/18] update --- mmdet3d/models/losses/lovasz_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index 1c854b2406..cd01f47fd2 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -217,7 +217,7 @@ def lovasz_softmax_flat( def lovasz_softmax(probs: torch.Tensor, labels: torch.Tensor, - classes: Optional[Union[str, List[int]]] = 'present', + classes: Union[str, List[int]] = 'present', per_image: bool = False, class_weight: List[float] = None, reduction: str = 'mean', From 5788d78c590cb5164860d6e46ce5eb41e77dc362 Mon Sep 17 00:00:00 2001 From: xzq Date: Wed, 22 Feb 2023 19:18:24 +0800 Subject: [PATCH 12/18] update --- mmdet3d/models/losses/lovasz_loss.py | 120 +++++++++++++++++---------- 1 file changed, 74 insertions(+), 46 deletions(-) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index cd01f47fd2..7b4b1adbbf 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -8,19 +8,43 @@ from typing import List, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from mmdet.models import weight_reduce_loss +from mmengine.fileio import load from mmengine.utils import is_list_of from mmdet3d.registry import MODELS -from .lovasz_loss_utils import get_class_weight, weight_reduce_loss + + +def get_class_weight(class_weight: Union[List[float], str]) -> List[float]: + """Get class weight for loss function. + + Args: + class_weight (Union[list[float], str], optional): If class_weight + is a str, take it as a file name and read from it. + + Return: + list[float]: Loaded class_weight. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = load(class_weight) + + return class_weight def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: """Computes gradient of the Lovasz extension w.r.t sorted errors. See Alg. 1 in paper. + `The Lovasz-Softmax loss. `_. Args: gt_sorted (torch.Tensor): Sorted ground truth. @@ -43,7 +67,7 @@ def flatten_binary_logits( labels: torch.Tensor, ignore_index: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - """Flattens predictions and labels in the batch (binary case). Remove + """Flatten predictions and labels in the batch (binary case). Remove tensors whose labels equal to 'ignore_index'. Args: @@ -70,8 +94,8 @@ def flatten_probs( labels: torch.Tensor, ignore_index: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - """Flattens predictions and labels in the batch. Remove tensors whose - labels equal to 'ignore_index'. + """Flatten predictions and labels in the batch. Remove tensors whose labels + equal to 'ignore_index'. Args: probs (torch.Tensor): Predictions to be modified. @@ -82,13 +106,15 @@ def flatten_probs( Return: tuple(torch.Tensor, torch.Tensor): Modified predictions and labels. """ - if probs.dim() == 3: - # assumes output of a sigmoid layer - B, H, W = probs.size() - probs = probs.view(B, 1, H, W) - B, C, H, W = probs.size() - probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C - labels = labels.view(-1) + if probs.dim() != 2: # for input with P*C + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, + C) # B*H*W, C=P,C + labels = labels.view(-1) if ignore_index is None: return probs, labels valid = (labels != ignore_index) @@ -102,9 +128,10 @@ def lovasz_hinge_flat(logits: torch.Tensor, """Binary Lovasz hinge loss. Args: - logits (torch.Tensor): [P], logits at each prediction - (between -infty and +infty). - labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + logits (torch.Tensor): Logits at each prediction + (between -infty and +infty) with shape [P]. + labels (torch.Tensor): Binary ground truth labels (0 or 1) + with shape [P]. Returns: torch.Tensor: The calculated loss. @@ -125,7 +152,7 @@ def lovasz_hinge_flat(logits: torch.Tensor, def lovasz_hinge(logits: torch.Tensor, labels: torch.Tensor, classes: Optional[Union[str, List[int]]] = None, - per_image: bool = False, + per_sample: bool = False, class_weight: Optional[List[float]] = None, reduction: str = 'mean', avg_factor: Optional[int] = None, @@ -133,28 +160,29 @@ def lovasz_hinge(logits: torch.Tensor, """Binary Lovasz hinge loss. Args: - logits (torch.Tensor): [B, H, W], logits at each pixel - (between -infty and +infty). - labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + logits (torch.Tensor): Logits at each pixel + (between -infty and +infty) with shape [B, H, W]. + labels (torch.Tensor): Binary ground truth masks (0 or 1) + with shape [B, H, W]. classes (Union[str, list[int]], optional): Placeholder, to be consistent with other loss. Defaults to None. - per_image (bool): If per_image is True, compute the loss per - image instead of per batch. Defaults to False. + per_sample (bool): If per_sample is True, compute the loss per + sample instead of per batch. Defaults to False. class_weight (list[float], optional): Placeholder, to be consistent with other loss. Defaults to None. reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when - per_image is True. Defaults to 'mean'. + per_sample is True. Defaults to 'mean'. avg_factor (int, optional): Average factor that is used to average - the loss. This parameter only works when per_image is True. + the loss. This parameter only works when per_sample is True. Defaults to None. - ignore_index (int, optional): The label index to be ignored. + ignore_index (Union[int, None]): The label index to be ignored. Defaults to 255. Returns: torch.Tensor: The calculated loss. """ - if per_image: + if per_sample: loss = [ lovasz_hinge_flat(*flatten_binary_logits( logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) @@ -176,9 +204,10 @@ def lovasz_softmax_flat( """Multi-class Lovasz-Softmax loss. Args: - probs (torch.Tensor): [P, C], class probabilities at each prediction - (between 0 and 1). - labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + probs (torch.Tensor): Class probabilities at each prediction + (between 0 and 1) with shape [P, C] + labels (torch.Tensor): Ground truth labels (between 0 and C - 1) + with shape [P]. classes (Union[str, list[int]]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Defaults to 'present'. @@ -218,7 +247,7 @@ def lovasz_softmax_flat( def lovasz_softmax(probs: torch.Tensor, labels: torch.Tensor, classes: Union[str, List[int]] = 'present', - per_image: bool = False, + per_sample: bool = False, class_weight: List[float] = None, reduction: str = 'mean', avg_factor: Optional[int] = None, @@ -226,22 +255,22 @@ def lovasz_softmax(probs: torch.Tensor, """Multi-class Lovasz-Softmax loss. Args: - probs (torch.Tensor): [B, C, H, W], class probabilities at each - prediction (between 0 and 1). - labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and - C - 1). + probs (torch.Tensor): Class probabilities at each + prediction (between 0 and 1) with shape [B, C, H, W]. + labels (torch.Tensor): Ground truth labels (between 0 and + C - 1) with shape [B, H, W]. classes (Union[str, list[int]]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Defaults to 'present'. - per_image (bool): If per_image is True, compute the loss per - image instead of per batch. Defaults to False. + per_sample (bool): If per_sample is True, compute the loss per + sample instead of per batch. Defaults to False. class_weight (list[float], optional): The weight for each class. Defaults to None. reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when - per_image is True. Defaults to 'mean'. + per_sample is True. Defaults to 'mean'. avg_factor (int, optional): Average factor that is used to average - the loss. This parameter only works when per_image is True. + the loss. This parameter only works when per_sample is True. Defaults to None. ignore_index (Union[int, None]): The label index to be ignored. Defaults to 255. @@ -250,7 +279,7 @@ def lovasz_softmax(probs: torch.Tensor, torch.Tensor: The calculated loss. """ - if per_image: + if per_sample: loss = [ lovasz_softmax_flat( *flatten_probs( @@ -283,11 +312,11 @@ class LovaszLoss(nn.Module): classes (Union[str, list[int]]): Classes chosen to calculate loss. 'all' for all classes, 'present' for classes present in labels, or a list of classes to average. Defaults to 'present'. - per_image (bool): If per_image is True, compute the loss per - image instead of per batch. Defaults to False. + per_sample (bool): If per_sample is True, compute the loss per + sample instead of per batch. Defaults to False. reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when - per_image is True. Defaults to 'mean'. + per_sample is True. Defaults to 'mean'. class_weight (Union[list[float], str], optional): Weight of each class. If in str format, read them from a file. Defaults to None. loss_weight (float): Weight of the loss. Defaults to 1.0. @@ -299,7 +328,7 @@ class LovaszLoss(nn.Module): def __init__(self, loss_type: str = 'multi_class', classes: Union[str, List[int]] = 'present', - per_image: bool = False, + per_sample: bool = False, reduction: str = 'mean', class_weight: Optional[Union[List[float], str]] = None, loss_weight: float = 1.0, @@ -313,12 +342,12 @@ def __init__(self, else: self.cls_criterion = lovasz_softmax assert classes in ('all', 'present') or is_list_of(classes, int) - if not per_image: + if not per_sample: assert reduction == 'none', "reduction should be 'none' when \ - per_image is False." + per_sample is False." self.classes = classes - self.per_image = per_image + self.per_sample = per_sample self.reduction = reduction self.loss_weight = loss_weight self.class_weight = get_class_weight(class_weight) @@ -347,14 +376,13 @@ def forward(self, cls_score, label, self.classes, - self.per_image, + self.per_sample, class_weight=class_weight, reduction=reduction, avg_factor=avg_factor, **kwargs) return loss_cls - @property def loss_name(self) -> str: """Loss Name. From 034960cb280c51424aa6fdb2227542f1f3e8c91c Mon Sep 17 00:00:00 2001 From: xzq Date: Wed, 22 Feb 2023 19:19:11 +0800 Subject: [PATCH 13/18] update --- mmdet3d/models/losses/lovasz_loss_utils.py | 134 --------------------- 1 file changed, 134 deletions(-) delete mode 100644 mmdet3d/models/losses/lovasz_loss_utils.py diff --git a/mmdet3d/models/losses/lovasz_loss_utils.py b/mmdet3d/models/losses/lovasz_loss_utils.py deleted file mode 100644 index 1c88d63079..0000000000 --- a/mmdet3d/models/losses/lovasz_loss_utils.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""Directly borrowed from mmsegmentation.""" -import functools -from typing import List, Union - -import numpy as np -import torch -import torch.nn.functional as F -from mmengine.fileio import load - - -def get_class_weight(class_weight: Union[List[float], str]) -> List[float]: - """Get class weight for loss function. - - Args: - class_weight (Union[list[float], str], optional): If class_weight - is a str, take it as a file name and read from it. - - Return: - list[float]: Loaded class_weight. - """ - if isinstance(class_weight, str): - # take it as a file path - if class_weight.endswith('.npy'): - class_weight = np.load(class_weight) - else: - # pkl, json or yaml - class_weight = load(class_weight) - - return class_weight - - -def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor: - """Reduce loss as specified. - - Args: - loss (torch.Tensor): Elementwise loss tensor. - reduction (str): Options are "none", "mean" and "sum". - - Return: - torch.Tensor: Reduced loss tensor. - """ - reduction_enum = F._Reduction.get_enum(reduction) - # none: 0, elementwise_mean:1, sum: 2 - if reduction_enum == 0: - return loss - elif reduction_enum == 1: - return loss.mean() - elif reduction_enum == 2: - return loss.sum() - - -def weight_reduce_loss(loss: torch.Tensor, - weight: torch.Tensor = None, - reduction: str = 'mean', - avg_factor: float = None) -> torch.Tensor: - """Apply element-wise weight and reduce loss. - - Args: - loss (torch.Tensor): Element-wise loss. - weight (torch.Tensor): Element-wise weights. - reduction (str): Same as built-in losses of PyTorch. - avg_factor (float): Average factor when computing the mean of losses. - - Returns: - torch.Tensor: Processed loss values. - """ - # if weight is specified, apply element-wise weight - if weight is not None: - assert weight.dim() == loss.dim() - if weight.dim() > 1: - assert weight.size(1) == 1 or weight.size(1) == loss.size(1) - loss = loss * weight - - # if avg_factor is not specified, just reduce the loss - if avg_factor is None: - loss = reduce_loss(loss, reduction) - else: - # if reduction is mean, then average the loss by avg_factor - if reduction == 'mean': - # Avoid causing ZeroDivisionError when avg_factor is 0.0, - # i.e., all labels of an image belong to ignore index. - eps = torch.finfo(torch.float32).eps - loss = loss.sum() / (avg_factor + eps) - # if reduction is 'none', then do nothing, otherwise raise an error - elif reduction != 'none': - raise ValueError('avg_factor can not be used with reduction="sum"') - return loss - - -def weighted_loss(loss_func): - """Create a weighted version of a given loss function. - - To use this decorator, the loss function must have the signature like - `loss_func(pred, target, **kwargs)`. The function only needs to compute - element-wise loss without any reduction. This decorator will add weight - and reduction arguments to the function. The decorated function will have - the signature like `loss_func(pred, target, weight=None, reduction='mean', - avg_factor=None, **kwargs)`. - - :Example: - - >>> import torch - >>> @weighted_loss - >>> def l1_loss(pred, target): - >>> return (pred - target).abs() - - >>> pred = torch.Tensor([0, 2, 3]) - >>> target = torch.Tensor([1, 1, 1]) - >>> weight = torch.Tensor([1, 0, 1]) - - >>> l1_loss(pred, target) - tensor(1.3333) - >>> l1_loss(pred, target, weight) - tensor(1.) - >>> l1_loss(pred, target, reduction='none') - tensor([1., 1., 2.]) - >>> l1_loss(pred, target, weight, avg_factor=2) - tensor(1.5000) - """ - - @functools.wraps(loss_func) - def wrapper(pred, - target, - weight=None, - reduction='mean', - avg_factor=None, - **kwargs): - # get element-wise loss - loss = loss_func(pred, target, **kwargs) - loss = weight_reduce_loss(loss, weight, reduction, avg_factor) - return loss - - return wrapper From 545d8782d74c55f3019c23286541f1bdfc868f0f Mon Sep 17 00:00:00 2001 From: xzq Date: Thu, 23 Feb 2023 14:15:29 +0800 Subject: [PATCH 14/18] cylinder3d_head --- .../models/decode_heads/cylinder3d_head.py | 4 +- mmdet3d/models/decode_heads/decode_head.py | 2 +- mmdet3d/models/losses/lovasz_loss.py | 52 ++----------------- 3 files changed, 8 insertions(+), 50 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 80afcd4061..44ace4d1fe 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from mmcv.ops import SparseConvTensor, SubMConv3d +from mmcv.ops import SparseConvTensor, SparseModule, SubMConv3d from mmdet3d.registry import MODELS from mmdet3d.structures.det3d_data_sample import SampleList @@ -77,7 +77,7 @@ def __init__(self, self.ignore_index = ignore_index def build_conv_seg(self, channels: int, num_classes: int, - kernel_size: int) -> SparseConvTensor: + kernel_size: int) -> SparseModule: return SubMConv3d( channels, num_classes, diff --git a/mmdet3d/models/decode_heads/decode_head.py b/mmdet3d/models/decode_heads/decode_head.py index 5c8bbd672f..58688d8df5 100644 --- a/mmdet3d/models/decode_heads/decode_head.py +++ b/mmdet3d/models/decode_heads/decode_head.py @@ -104,7 +104,7 @@ def forward(self, feats_dict: dict) -> Tensor: pass def build_conv_seg(self, channels: int, num_classes: int, - kernel_size: int) -> nn.Conv1d: + kernel_size: int) -> nn.Module: """Build Convolutional Segmentation Layers.""" return nn.Conv1d(channels, num_classes, kernel_size=kernel_size) diff --git a/mmdet3d/models/losses/lovasz_loss.py b/mmdet3d/models/losses/lovasz_loss.py index 7b4b1adbbf..a9bcc270bd 100644 --- a/mmdet3d/models/losses/lovasz_loss.py +++ b/mmdet3d/models/losses/lovasz_loss.py @@ -8,38 +8,15 @@ from typing import List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmdet.models import weight_reduce_loss -from mmengine.fileio import load from mmengine.utils import is_list_of from mmdet3d.registry import MODELS -def get_class_weight(class_weight: Union[List[float], str]) -> List[float]: - """Get class weight for loss function. - - Args: - class_weight (Union[list[float], str], optional): If class_weight - is a str, take it as a file name and read from it. - - Return: - list[float]: Loaded class_weight. - """ - if isinstance(class_weight, str): - # take it as a file path - if class_weight.endswith('.npy'): - class_weight = np.load(class_weight) - else: - # pkl, json or yaml - class_weight = load(class_weight) - - return class_weight - - def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor: """Computes gradient of the Lovasz extension w.r.t sorted errors. @@ -317,12 +294,9 @@ class LovaszLoss(nn.Module): reduction (str): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_sample is True. Defaults to 'mean'. - class_weight (Union[list[float], str], optional): Weight of each class. - If in str format, read them from a file. Defaults to None. + class_weight ([list[float], optional): Weight of each class. + Defaults to None. loss_weight (float): Weight of the loss. Defaults to 1.0. - loss_name (str): Name of the loss item. If you want this loss - item to be included into the backward graph, `loss_` must be the - prefix of the name. Defaults to 'loss_lovasz'. """ def __init__(self, @@ -330,9 +304,8 @@ def __init__(self, classes: Union[str, List[int]] = 'present', per_sample: bool = False, reduction: str = 'mean', - class_weight: Optional[Union[List[float], str]] = None, - loss_weight: float = 1.0, - loss_name: str = 'loss_lovasz'): + class_weight: Optional[List[float]] = None, + loss_weight: float = 1.0): super().__init__() assert loss_type in ('binary', 'multi_class'), "loss_type should be \ 'binary' or 'multi_class'." @@ -350,8 +323,7 @@ def __init__(self, self.per_sample = per_sample self.reduction = reduction self.loss_weight = loss_weight - self.class_weight = get_class_weight(class_weight) - self._loss_name = loss_name + self.class_weight = class_weight def forward(self, cls_score: torch.Tensor, @@ -382,17 +354,3 @@ def forward(self, avg_factor=avg_factor, **kwargs) return loss_cls - - def loss_name(self) -> str: - """Loss Name. - - This function must be implemented and will return the name of this - loss function. This name will be used to combine different loss items - by simple sum operation. In addition, if you want this loss item to be - included into the backward graph, `loss_` must be the prefix of the - name. - - Returns: - str: The name of this loss item. - """ - return self._loss_name From 39ecf13b7aa8300daeffb50c6e61f9b760a94507 Mon Sep 17 00:00:00 2001 From: xzq Date: Thu, 23 Feb 2023 14:21:00 +0800 Subject: [PATCH 15/18] update --- mmdet3d/models/decode_heads/cylinder3d_head.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 44ace4d1fe..a6085b71f6 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional import torch from mmcv.ops import SparseConvTensor, SparseModule, SubMConv3d from mmdet3d.registry import MODELS from mmdet3d.structures.det3d_data_sample import SampleList -from mmdet3d.utils import OptConfigType +from mmdet3d.utils import OptMultiConfig from mmdet3d.utils.typing_utils import ConfigType from .decode_head import Base3DDecodeHead @@ -61,7 +60,7 @@ def __init__(self, type='LovaszLoss', loss_weight=1.0), conv_seg_kernel_size: int = 3, ignore_index: int = 0, - init_cfg: Optional[dict or OptConfigType] = None) -> None: + init_cfg: OptMultiConfig = None) -> None: super(Cylinder3DHead, self).__init__( channels=channels, num_classes=num_classes, @@ -97,7 +96,7 @@ def loss_by_feat(self, seg_logit: SparseConvTensor, """Compute semantic segmentation loss. Args: - seg_logit (spconv.SparseConvTensor): Predicted per-voxel + seg_logit (SparseConvTensor): Predicted per-voxel segmentation logits of shape [num_voxels, num_classes] stored in SparseConvTensor. batch_data_samples (List[:obj:`Det3DDataSample`]): The seg From 9c0e75a96b20d55c41645939b68290514b7e3e76 Mon Sep 17 00:00:00 2001 From: xzq Date: Thu, 23 Feb 2023 15:08:53 +0800 Subject: [PATCH 16/18] update --- .../models/decode_heads/cylinder3d_head.py | 34 +++++++++++++++++++ .../test_decode_heads/test_cylinder3d_head.py | 2 ++ 2 files changed, 36 insertions(+) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index a6085b71f6..f57b05cf02 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -122,3 +122,37 @@ def loss_by_feat(self, seg_logit: SparseConvTensor, seg_logit_feat, seg_label, ignore_index=self.ignore_index) return loss + + def predict( + self, + inputs: SparseConvTensor, + batch_inputs_dict: dict, + batch_data_samples: SampleList, + ) -> torch.Tensor: + """Forward function for testing. + + Args: + inputs (SparseConvTensor): Feature from backbone. + batch_inputs_dict (dict): Input sample dict which includes 'points' + and 'voxels' keys. + - points (List[Tensor]): Point cloud of each sample. + - voxels (List[Tensor]): Image tensor has shape (B, C, H, W). + batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data + samples. It usually includes information such as `metainfo` and + `gt_pts_seg`. We use `point2voxel_map` in this function. + + Returns: + torch.Tensor: Output point-wise segmentation logits. + """ + seg_logits = self.forward(inputs) + + seg_pred_list = [] + coors = batch_inputs_dict['voxels']['voxel_coors'] + for batch_idx in range(len(batch_data_samples)): + seg_logits_sample = seg_logits[coors[:, 0] == batch_idx] + point2voxel_map = batch_data_samples[ + batch_idx].gt_pts_seg.point2voxel_map.long() + point_seg_predicts = seg_logits_sample[point2voxel_map] + seg_pred_list.append(point_seg_predicts) + + return seg_logits diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index 2bad4dc9ef..ec50c01758 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -58,3 +58,5 @@ def test_cylinder3d_head_loss(self): self.assertGreater(loss_ce, 0, 'ce loss should be positive') self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive') + + cylinder3d_head.predict() From 3f71f15518dbdff54947f9716325289571cd5589 Mon Sep 17 00:00:00 2001 From: xzq Date: Thu, 23 Feb 2023 15:19:43 +0800 Subject: [PATCH 17/18] update --- .../data_preprocessors/data_preprocessor.py | 70 +++++++++++++++++-- .../models/decode_heads/cylinder3d_head.py | 4 +- .../test_decode_heads/test_cylinder3d_head.py | 1 + 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/mmdet3d/models/data_preprocessors/data_preprocessor.py b/mmdet3d/models/data_preprocessors/data_preprocessor.py index 28fa434c64..b1e8a0b4cf 100644 --- a/mmdet3d/models/data_preprocessors/data_preprocessor.py +++ b/mmdet3d/models/data_preprocessors/data_preprocessor.py @@ -5,15 +5,16 @@ import numpy as np import torch -from mmcv.ops import Voxelization from mmdet.models import DetDataPreprocessor from mmengine.model import stack_batch from mmengine.utils import is_list_of from torch.nn import functional as F from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.utils import OptConfigType from .utils import multiview_img_stack_batch +from .voxelize import VoxelizationByGridShape, dynamic_scatter_3d @MODELS.register_module() @@ -103,7 +104,7 @@ def __init__(self, self.voxel = voxel self.voxel_type = voxel_type if voxel: - self.voxel_layer = Voxelization(**voxel_layer) + self.voxel_layer = VoxelizationByGridShape(**voxel_layer) def forward(self, data: Union[dict, List[dict]], @@ -157,7 +158,7 @@ def simple_process(self, data: dict, training: bool = False) -> dict: batch_inputs['points'] = inputs['points'] if self.voxel: - voxel_dict = self.voxelize(inputs['points']) + voxel_dict = self.voxelize(inputs['points'], data_samples) batch_inputs['voxels'] = voxel_dict if 'imgs' in inputs: @@ -329,11 +330,14 @@ def _get_pad_shape(self, data: dict) -> List[tuple]: return batch_pad_shape @torch.no_grad() - def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]: + def voxelize(self, points: List[torch.Tensor], + data_samples: SampleList) -> Dict[str, torch.Tensor]: """Apply voxelization to point cloud. Args: points (List[Tensor]): Point cloud in one data batch. + data_samples: (list[:obj:`Det3DDataSample`]): The annotation data + of every samples. Add voxel-wise annotation for segmentation. Returns: Dict[str, Tensor]: Voxelization information. @@ -378,6 +382,39 @@ def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]: coors.append(res_coors) voxels = torch.cat(points, dim=0) coors = torch.cat(coors, dim=0) + elif self.voxel_type == 'cylindrical': + voxels, coors = [], [] + for i, (res, data_sample) in enumerate(zip(points, data_samples)): + rho = torch.sqrt(res[:, 0]**2 + res[:, 1]**2) + phi = torch.atan2(res[:, 1], res[:, 0]) + polar_res = torch.stack((rho, phi, res[:, 2]), dim=-1) + min_bound = polar_res.new_tensor( + self.voxel_layer.point_cloud_range[:3]) + max_bound = polar_res.new_tensor( + self.voxel_layer.point_cloud_range[3:]) + try: # only support PyTorch >= 1.9.0 + polar_res_clamp = torch.clamp(polar_res, min_bound, + max_bound) + except TypeError: + polar_res_clamp = polar_res.clone() + for coor_idx in range(3): + polar_res_clamp[:, coor_idx][ + polar_res[:, coor_idx] > + max_bound[coor_idx]] = max_bound[coor_idx] + polar_res_clamp[:, coor_idx][ + polar_res[:, coor_idx] < + min_bound[coor_idx]] = min_bound[coor_idx] + res_coors = torch.floor( + (polar_res_clamp - min_bound) / polar_res_clamp.new_tensor( + self.voxel_layer.voxel_size)).int() + self.get_voxel_seg(res_coors, data_sample, not self.training) + res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i) + res_voxels = torch.cat((polar_res, res[:, :2], res[:, 3:]), + dim=-1) + voxels.append(res_voxels) + coors.append(res_coors) + voxels = torch.cat(voxels, dim=0) + coors = torch.cat(coors, dim=0) else: raise ValueError(f'Invalid voxelization type {self.voxel_type}') @@ -385,3 +422,28 @@ def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]: voxel_dict['coors'] = coors return voxel_dict + + def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList, + test_mode: bool): + """Get voxel-wise segmentation label and point2voxel map. + + Args: + res_coors (Tensor): The voxel coordinates of points, Nx3. + data_sample: (:obj:`Det3DDataSample`): The annotation data of + every samples. Add voxel-wise annotation forsegmentation. + test_mode (bool): Whether in test mode or not. + """ + + if not test_mode: + pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask + voxel_semantic_mask, _, point2voxel_map = dynamic_scatter_3d( + F.one_hot(pts_semantic_mask.long()).float(), res_coors, 'mean', + True) + voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1) + data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask + data_sample.gt_pts_seg.point2voxel_map = point2voxel_map + else: + pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float() + _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor, + res_coors, 'mean', True) + data_sample.gt_pts_seg.point2voxel_map = point2voxel_map diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index f57b05cf02..9284ffd87f 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -135,8 +135,10 @@ def predict( inputs (SparseConvTensor): Feature from backbone. batch_inputs_dict (dict): Input sample dict which includes 'points' and 'voxels' keys. + - points (List[Tensor]): Point cloud of each sample. - - voxels (List[Tensor]): Image tensor has shape (B, C, H, W). + - voxels (dict): Dict of voxelized voxels and the corresponding + coordinates. batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data samples. It usually includes information such as `metainfo` and `gt_pts_seg`. We use `point2voxel_map` in this function. diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index ec50c01758..461a52ed93 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -59,4 +59,5 @@ def test_cylinder3d_head_loss(self): self.assertGreater(loss_ce, 0, 'ce loss should be positive') self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive') + # datasample.gt_pts_seg.point2voxel_map = cylinder3d_head.predict() From 57a02189160501bb44649df1e242c55d80f2b975 Mon Sep 17 00:00:00 2001 From: xzq Date: Thu, 23 Feb 2023 15:39:48 +0800 Subject: [PATCH 18/18] update --- mmdet3d/models/decode_heads/cylinder3d_head.py | 6 +++--- .../test_decode_heads/test_cylinder3d_head.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index 9284ffd87f..dd13fd4dd3 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -144,9 +144,9 @@ def predict( `gt_pts_seg`. We use `point2voxel_map` in this function. Returns: - torch.Tensor: Output point-wise segmentation logits. + List[torch.Tensor]: List of point-wise segmentation logits. """ - seg_logits = self.forward(inputs) + seg_logits = self.forward(inputs).features seg_pred_list = [] coors = batch_inputs_dict['voxels']['voxel_coors'] @@ -157,4 +157,4 @@ def predict( point_seg_predicts = seg_logits_sample[point2voxel_map] seg_pred_list.append(point_seg_predicts) - return seg_logits + return seg_pred_list diff --git a/tests/test_models/test_decode_heads/test_cylinder3d_head.py b/tests/test_models/test_decode_heads/test_cylinder3d_head.py index 1caa11d802..3bb62c5eef 100644 --- a/tests/test_models/test_decode_heads/test_cylinder3d_head.py +++ b/tests/test_models/test_decode_heads/test_cylinder3d_head.py @@ -59,8 +59,9 @@ def test_cylinder3d_head_loss(self): self.assertGreater(loss_ce, 0, 'ce loss should be positive') self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive') + batch_inputs_dict = dict(voxels=dict(voxel_coors=coors)) datasample.gt_pts_seg.point2voxel_map = torch.randint( - 0, 50, (100, 1)).int().cuda() - point_logits = cylinder3d_head.predict(sparse_voxels, coors, - datasample) - assert point_logits.shape == (100, 20) + 0, 50, (100, )).int().cuda() + point_logits = cylinder3d_head.predict(sparse_voxels, + batch_inputs_dict, [datasample]) + assert point_logits[0].shape == torch.Size([100, 20])