From b4ebf368e0baf877d2074d176551d6e52cd88231 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Thu, 15 Jun 2023 10:28:18 +0300 Subject: [PATCH 01/11] added unit tests --- .../training/metrics/segmentation_metrics.py | 46 ++++++++++++++++++- ...gnore_indices_segmentation_metrics_test.py | 32 +++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index a4a847dbe2..14568eb1fb 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -2,7 +2,7 @@ import torch import torchmetrics from torchmetrics import Metric -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Union from torchmetrics.utilities.distributed import reduce from abc import ABC, abstractmethod @@ -122,6 +122,18 @@ def intersection_and_union(im_pred, im_lab, num_class): return area_inter, area_union +def _map_ignored_inds(target: torch.Tensor, ignore_index_list, unfiltered_num_classes) -> torch.Tensor: + target_copy = torch.zeros_like(target) + all_unfiltered_classes = list(range(unfiltered_num_classes)) + filtered_classes = [i for i in all_unfiltered_classes if i not in ignore_index_list] + for mapped_idx in range(len(filtered_classes)): + cls_to_map = filtered_classes[mapped_idx] + map_val = mapped_idx + 1 + target_copy[target == cls_to_map] = map_val + + return target_copy + + class AbstractMetricsArgsPrepFn(ABC): """ Abstract preprocess metrics arguments class. @@ -194,7 +206,7 @@ def __init__( self, num_classes: int, dist_sync_on_step: bool = False, - ignore_index: Optional[int] = None, + ignore_index: Optional[Union[int, List[int]]] = None, reduction: str = "elementwise_mean", threshold: float = 0.5, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None, @@ -203,12 +215,27 @@ def __init__( if num_classes <= 1: raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}") + if isinstance(ignore_index, list): + ignore_index_list = ignore_index + unfiltered_num_classes = num_classes + num_classes = num_classes - len(ignore_index_list) + 1 + ignore_index = 0 + else: + unfiltered_num_classes = num_classes + ignore_index_list = None + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + + self.unfiltered_num_classes = unfiltered_num_classes + self.ignore_index_list = ignore_index_list self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) self.greater_is_better = True def update(self, preds, target: torch.Tensor): preds, target = self.metrics_args_prep_fn(preds, target) + if self.ignore_index_list is not None: + target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes) + preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes) super().update(preds=preds, target=target) @@ -227,12 +254,27 @@ def __init__( if num_classes <= 1: raise ValueError(f"Dice class only for multi-class usage! For binary usage, please call {BinaryDice.__name__}") + if isinstance(ignore_index, list): + ignore_index_list = ignore_index + unfiltered_num_classes = num_classes + num_classes = num_classes - len(ignore_index_list) + 1 + ignore_index = 0 + if ignore_index not in ignore_index_list: + raise ValueError("ignore_index_mapping must be in ignore_index_list") + else: + ignore_index_list = None + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + + self.ignore_index_list = ignore_index_list + self.unfiltered_num_classes = unfiltered_num_classes self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) self.greater_is_better = True def update(self, preds, target: torch.Tensor): preds, target = self.metrics_args_prep_fn(preds, target) + if self.ignore_index_list is not None: + target = _map_ignored_inds(target, self.ignore_index_list, self.ignore_index) super().update(preds=preds, target=target) def compute(self) -> torch.Tensor: diff --git a/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py new file mode 100644 index 0000000000..812ae34d81 --- /dev/null +++ b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py @@ -0,0 +1,32 @@ +import unittest + +import torch + +from super_gradients.training.metrics import IoU + + +class MyTestCase(unittest.TestCase): + def test_iou_with_multiple_ignored_classes_and_absent_score(self): + metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + + # preds after onehot -> [4,4,4,4,4,4] + # (1 + 0)/2 : 1.0 for class 4 score and 0 for absent score for class 0 + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) + + def test_iou_with_multiple_ignored_classes_no_absent_score(self): + metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + pred[0, 0, 3] = 2 + + # preds after onehot -> [4,4,4,0,4,4] + # (1 + 1)/2 : 1.0 for class 4 score and 1 for class 0 + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1) + + +if __name__ == "__main__": + unittest.main() From 20af8b434914aef695c293c8f2815b0099fd61e2 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Thu, 15 Jun 2023 11:55:42 +0300 Subject: [PATCH 02/11] finalized unit tests --- .../training/metrics/segmentation_metrics.py | 48 +++++++++++-------- ...gnore_indices_segmentation_metrics_test.py | 29 ++++++++++- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index 14568eb1fb..d18b5d0891 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -176,7 +176,7 @@ def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten @register_metric(Metrics.PIXEL_ACCURACY) class PixelAccuracy(Metric): - def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): + def __init__(self, ignore_label: Union[int, List[int]] = -100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_label = ignore_label self.greater_is_better = True @@ -186,13 +186,26 @@ def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep def update(self, preds: torch.Tensor, target: torch.Tensor): predict, target = self.metrics_args_prep_fn(preds, target) + labeled_mask = self._handle_multiple_ignored_inds(target) - labeled_mask = target.ne(self.ignore_label) pixel_labeled = torch.sum(labeled_mask) pixel_correct = torch.sum((predict == target) * labeled_mask) self.total_correct += pixel_correct self.total_label += pixel_labeled + def _handle_multiple_ignored_inds(self, target): + if isinstance(self.ignore_label, list): + labeled_mask = None + for ignored_label in self.ignore_label: + if labeled_mask is None: + labeled_mask = target.ne(ignored_label) + else: + labeled_mask = torch.logical_and(labeled_mask, target.ne(ignored_label)) + else: + labeled_mask = target.ne(self.ignore_label) + + return labeled_mask + def compute(self): _total_correct = self.total_correct.cpu().detach().numpy().astype("int64") _total_label = self.total_label.cpu().detach().numpy().astype("int64") @@ -200,6 +213,18 @@ def compute(self): return pix_acc +def _handle_multiple_ignored_inds(ignore_index, num_classes): + if isinstance(ignore_index, list): + ignore_index_list = ignore_index + unfiltered_num_classes = num_classes + num_classes = num_classes - len(ignore_index_list) + 1 + ignore_index = 0 + else: + unfiltered_num_classes = num_classes + ignore_index_list = None + return ignore_index, ignore_index_list, num_classes, unfiltered_num_classes + + @register_metric(Metrics.IOU) class IoU(torchmetrics.JaccardIndex): def __init__( @@ -215,14 +240,7 @@ def __init__( if num_classes <= 1: raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}") - if isinstance(ignore_index, list): - ignore_index_list = ignore_index - unfiltered_num_classes = num_classes - num_classes = num_classes - len(ignore_index_list) + 1 - ignore_index = 0 - else: - unfiltered_num_classes = num_classes - ignore_index_list = None + ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes) super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) @@ -254,15 +272,7 @@ def __init__( if num_classes <= 1: raise ValueError(f"Dice class only for multi-class usage! For binary usage, please call {BinaryDice.__name__}") - if isinstance(ignore_index, list): - ignore_index_list = ignore_index - unfiltered_num_classes = num_classes - num_classes = num_classes - len(ignore_index_list) + 1 - ignore_index = 0 - if ignore_index not in ignore_index_list: - raise ValueError("ignore_index_mapping must be in ignore_index_list") - else: - ignore_index_list = None + ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes) super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) diff --git a/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py index 812ae34d81..f6e04db354 100644 --- a/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py +++ b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py @@ -2,10 +2,10 @@ import torch -from super_gradients.training.metrics import IoU +from super_gradients.training.metrics import IoU, PixelAccuracy, Dice -class MyTestCase(unittest.TestCase): +class TestSegmentationMetricsMultipleIgnored(unittest.TestCase): def test_iou_with_multiple_ignored_classes_and_absent_score(self): metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) @@ -27,6 +27,31 @@ def test_iou_with_multiple_ignored_classes_no_absent_score(self): # (1 + 1)/2 : 1.0 for class 4 score and 1 for class 0 self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1) + def test_dice_with_multiple_ignored_classes_and_absent_score(self): + metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) + + def test_dice_with_multiple_ignored_classes_no_absent_score(self): + metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + pred[0, 0, 3] = 2 + + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) + + def test_pixelaccuracy_with_multiple_ignored_classes(self): + metric_multi_ignored = PixelAccuracy(ignore_label=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0) + if __name__ == "__main__": unittest.main() From 115b804c55b1cfcf40ad592335e6778c6e6dd910 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Thu, 15 Jun 2023 19:54:23 +0300 Subject: [PATCH 03/11] updated docs and test suite --- .../training/metrics/segmentation_metrics.py | 61 +++++++++++++++++++ tests/deci_core_unit_test_suite_runner.py | 2 + 2 files changed, 63 insertions(+) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index d18b5d0891..8a57341725 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -176,6 +176,25 @@ def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten @register_metric(Metrics.PIXEL_ACCURACY) class PixelAccuracy(Metric): + """ + Pixel Accuracy + + Args: + ignore_label: Optional[Union[int, List[int]]], specifying a target class(es) to ignore. + If given, this class index does not contribute to the returned score, regardless of reduction method. + Has no effect if given an int that is not in the range [0, num_classes-1]. + By default, no index is ignored, and all classes are used. + IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error. + reduction: a method to reduce metric score over labels: + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics. + By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True) + """ + def __init__(self, ignore_label: Union[int, List[int]] = -100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_label = ignore_label @@ -227,6 +246,27 @@ def _handle_multiple_ignored_inds(ignore_index, num_classes): @register_metric(Metrics.IOU) class IoU(torchmetrics.JaccardIndex): + """ + IoU Metric + + Args: + num_classes: Number of classes in the dataset. + ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore. + If given, this class index does not contribute to the returned score, regardless of reduction method. + Has no effect if given an int that is not in the range [0, num_classes-1]. + By default, no index is ignored, and all classes are used. + IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error. + threshold: Threshold value for binary or multi-label probabilities. + reduction: a method to reduce metric score over labels: + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics. + By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True) + """ + def __init__( self, num_classes: int, @@ -259,6 +299,27 @@ def update(self, preds, target: torch.Tensor): @register_metric(Metrics.DICE) class Dice(torchmetrics.JaccardIndex): + """ + Dice Coefficient Metric + + Args: + num_classes: Number of classes in the dataset. + ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore. + If given, this class index does not contribute to the returned score, regardless of reduction method. + Has no effect if given an int that is not in the range [0, num_classes-1]. + By default, no index is ignored, and all classes are used. + IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error. + threshold: Threshold value for binary or multi-label probabilities. + reduction: a method to reduce metric score over labels: + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics. + By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True) + """ + def __init__( self, num_classes: int, diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index d7beaa7658..7d7289ac92 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -29,6 +29,7 @@ from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest +from tests.unit_tests.multiple_ignore_indices_segmentation_metrics_test import TestSegmentationMetricsMultipleIgnored from tests.unit_tests.phase_delegates_test import ContextMethodsTest from tests.unit_tests.pose_estimation_dataset_test import TestPoseEstimationDataset from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest @@ -141,6 +142,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DeprecationsUnitTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMinSamplesSingleNode)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored)) def _add_modules_to_end_to_end_tests_suite(self): """ From c681c262f3e962b1faa3317e7fca7dd7d7bdbb37 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Sun, 18 Jun 2023 15:19:05 +0300 Subject: [PATCH 04/11] updated docs --- .../training/metrics/segmentation_metrics.py | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index 8a57341725..afdb56cbb8 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -122,7 +122,24 @@ def intersection_and_union(im_pred, im_lab, num_class): return area_inter, area_union -def _map_ignored_inds(target: torch.Tensor, ignore_index_list, unfiltered_num_classes) -> torch.Tensor: +def _map_ignored_inds(target: torch.Tensor, ignore_index_list: List[int], unfiltered_num_classes: int) -> torch.Tensor: + """ + Creaetes a copy of target, mapping indices in range(unfiltered_num_classes) to range(unfiltered_num_classes-len( + ignore_index_list)+1). Indices in ignore_index_list are being mapped to 0, which can later on be used as + "ignore_index". + + Example: + >>>_map_ignored_inds(torch.tensor([0,1,2,3,4,5,6]), ignore_index_list=[3,5,1], unfiltered_num_classes=7) + >>> tensor([1, 0, 2, 0, 3, 0, 4]) + + + + :param target: torch.Tensor, tensor to perform the mapping on. + :param ignore_index_list: List[int], list of indices to map to 0 in the output tensor. + :param unfiltered_num_classes: int, Total number of possible class indices in target. + + :return: mapped tensor as described above. + """ target_copy = torch.zeros_like(target) all_unfiltered_classes = list(range(unfiltered_num_classes)) filtered_classes = [i for i in all_unfiltered_classes if i not in ignore_index_list] @@ -232,7 +249,24 @@ def compute(self): return pix_acc -def _handle_multiple_ignored_inds(ignore_index, num_classes): +def _handle_multiple_ignored_inds(ignore_index: Union[int, List[int]], num_classes: int): + """ + Helper method for variable assignment, prior to the + + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + + call in segmentation metrics inheriting from torchmetrics.JaccardIndex. + When ignore_index is list, the num_classes being passed to the torchmetrics.JaccardIndex c'tor is set to be the one after + mapping of the ignored indices in ignore_index_list to 0. Hence, we set: + ignore_index=0, + And since we map all of the ignored indices to 0, it is if we removed them and introduces a new index: + num_classes = num_classes - len(ignore_index_list) +1 + Unfiltered num_classes is used in .update() for mapping of the original indice values. + Sets ignore_index to 0 + :param ignore_index: list or single int representing the class ind(ices) to ignore. + :param num_classes: int, num_classes (original, before mapping) being passed to segmentation metric classesׄ + :return:ignore_index, ignore_index_list, num_classes, unfiltered_num_classesignore_index, ignore_index_list, num_classes, unfiltered_num_classes + """ if isinstance(ignore_index, list): ignore_index_list = ignore_index unfiltered_num_classes = num_classes @@ -279,7 +313,8 @@ def __init__( if num_classes <= 1: raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}") - + if isinstance(ignore_index, list) and reduction == "none": + raise ValueError("passing multiple ignore indices ") ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes) super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) From e55fc895b77641a9ffc8e7f646800a4f45f73ead Mon Sep 17 00:00:00 2001 From: shayaharon Date: Sun, 18 Jun 2023 15:25:04 +0300 Subject: [PATCH 05/11] updated update() --- src/super_gradients/training/metrics/segmentation_metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index afdb56cbb8..aa1b0643e4 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -380,7 +380,8 @@ def __init__( def update(self, preds, target: torch.Tensor): preds, target = self.metrics_args_prep_fn(preds, target) if self.ignore_index_list is not None: - target = _map_ignored_inds(target, self.ignore_index_list, self.ignore_index) + target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes) + preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes) super().update(preds=preds, target=target) def compute(self) -> torch.Tensor: From 3c1fcf2f059c5b396f043ab282d37857cbd8295c Mon Sep 17 00:00:00 2001 From: shayaharon Date: Sun, 18 Jun 2023 15:27:22 +0300 Subject: [PATCH 06/11] updated update() test --- .../multiple_ignore_indices_segmentation_metrics_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py index f6e04db354..34f8619f77 100644 --- a/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py +++ b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py @@ -42,7 +42,7 @@ def test_dice_with_multiple_ignored_classes_no_absent_score(self): pred[:, 4] = 1 pred[0, 0, 3] = 2 - self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0) def test_pixelaccuracy_with_multiple_ignored_classes(self): metric_multi_ignored = PixelAccuracy(ignore_label=[3, 1, 2]) From a9351123c4a0f04b854f552cf04291dacec9500f Mon Sep 17 00:00:00 2001 From: shayaharon Date: Sun, 18 Jun 2023 16:12:18 +0300 Subject: [PATCH 07/11] renamed var in _handle_multiple_ignored_inds --- .../training/metrics/segmentation_metrics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index aa1b0643e4..ef9e32da69 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -231,16 +231,16 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): def _handle_multiple_ignored_inds(self, target): if isinstance(self.ignore_label, list): - labeled_mask = None + evaluated_classes_mask = None for ignored_label in self.ignore_label: - if labeled_mask is None: - labeled_mask = target.ne(ignored_label) + if evaluated_classes_mask is None: + evaluated_classes_mask = target.ne(ignored_label) else: - labeled_mask = torch.logical_and(labeled_mask, target.ne(ignored_label)) + evaluated_classes_mask = torch.logical_and(evaluated_classes_mask, target.ne(ignored_label)) else: - labeled_mask = target.ne(self.ignore_label) + evaluated_classes_mask = target.ne(self.ignore_label) - return labeled_mask + return evaluated_classes_mask def compute(self): _total_correct = self.total_correct.cpu().detach().numpy().astype("int64") From c44f09983c76feafb379b13a5cfa4a1d9e068a6e Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 19 Jun 2023 09:55:46 +0300 Subject: [PATCH 08/11] faster index mapping for pixel accuracy --- .../training/metrics/segmentation_metrics.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index ef9e32da69..0573b92e21 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -231,12 +231,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): def _handle_multiple_ignored_inds(self, target): if isinstance(self.ignore_label, list): - evaluated_classes_mask = None + labeled_mask = torch.ones_like(target) for ignored_label in self.ignore_label: - if evaluated_classes_mask is None: - evaluated_classes_mask = target.ne(ignored_label) - else: - evaluated_classes_mask = torch.logical_and(evaluated_classes_mask, target.ne(ignored_label)) + labeled_mask.masked_fill(target.eq(ignored_label), 0) else: evaluated_classes_mask = target.ne(self.ignore_label) From dba36944db7885c042cbcf90b14a57c63b979e3a Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 19 Jun 2023 09:56:32 +0300 Subject: [PATCH 09/11] faster index mapping for pixel accuracy fix --- src/super_gradients/training/metrics/segmentation_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index 0573b92e21..9d2731f151 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -231,9 +231,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): def _handle_multiple_ignored_inds(self, target): if isinstance(self.ignore_label, list): - labeled_mask = torch.ones_like(target) + evaluated_classes_mask = torch.ones_like(target) for ignored_label in self.ignore_label: - labeled_mask.masked_fill(target.eq(ignored_label), 0) + evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0) else: evaluated_classes_mask = target.ne(self.ignore_label) From 6d8a3201c582dc715d2c6087da33499a80b3e699 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 19 Jun 2023 10:07:46 +0300 Subject: [PATCH 10/11] fixed non inplace op in pixel accuracy --- src/super_gradients/training/metrics/segmentation_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index 9d2731f151..e1d3001d93 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -233,7 +233,7 @@ def _handle_multiple_ignored_inds(self, target): if isinstance(self.ignore_label, list): evaluated_classes_mask = torch.ones_like(target) for ignored_label in self.ignore_label: - evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0) + evaluated_classes_mask = evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0) else: evaluated_classes_mask = target.ne(self.ignore_label) From 59620c8d909660766ccf813bc930e3d4b5dd9a44 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 19 Jun 2023 11:03:56 +0300 Subject: [PATCH 11/11] type checking for ignore index expanded to iterable --- .../training/metrics/segmentation_metrics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index e1d3001d93..d934bc1d27 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -1,3 +1,5 @@ +import typing + import numpy as np import torch import torchmetrics @@ -230,7 +232,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.total_label += pixel_labeled def _handle_multiple_ignored_inds(self, target): - if isinstance(self.ignore_label, list): + if isinstance(self.ignore_label, typing.Iterable): evaluated_classes_mask = torch.ones_like(target) for ignored_label in self.ignore_label: evaluated_classes_mask = evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0) @@ -264,7 +266,7 @@ def _handle_multiple_ignored_inds(ignore_index: Union[int, List[int]], num_class :param num_classes: int, num_classes (original, before mapping) being passed to segmentation metric classesׄ :return:ignore_index, ignore_index_list, num_classes, unfiltered_num_classesignore_index, ignore_index_list, num_classes, unfiltered_num_classes """ - if isinstance(ignore_index, list): + if isinstance(ignore_index, typing.Iterable): ignore_index_list = ignore_index unfiltered_num_classes = num_classes num_classes = num_classes - len(ignore_index_list) + 1 @@ -310,7 +312,7 @@ def __init__( if num_classes <= 1: raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}") - if isinstance(ignore_index, list) and reduction == "none": + if isinstance(ignore_index, typing.Iterable) and reduction == "none": raise ValueError("passing multiple ignore indices ") ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes)