diff --git a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py index 352a34247b..a4f4cb6951 100644 --- a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py +++ b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py @@ -255,6 +255,7 @@ def _load_sample_annotation(self, sample_id: int) -> Dict[str, Union[np.ndarray, # Filter out classes that are not in self.class_inclusion_list if self.class_inclusion_list is not None: sample_annotations = self._sub_class_annotation(annotation=sample_annotations) + return sample_annotations def _load_all_annotations(self, n_samples: int) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[str, Any]]]: diff --git a/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py b/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py index 999f48dcfb..987561c815 100644 --- a/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py @@ -184,6 +184,7 @@ def _load_annotation(self, sample_id: int) -> dict: yolo_format_target, invalid_labels = self._parse_yolo_label_file( label_file_path=label_path, + num_classes=len(self.all_classes_list), ignore_invalid_labels=self.ignore_invalid_labels, show_warnings=self.show_all_warnings, ) @@ -210,13 +211,20 @@ def _load_annotation(self, sample_id: int) -> dict: return annotation @staticmethod - def _parse_yolo_label_file(label_file_path: str, ignore_invalid_labels: bool = True, show_warnings: bool = True) -> Tuple[np.ndarray, List[str]]: + def _parse_yolo_label_file( + label_file_path: str, + ignore_invalid_labels: bool = True, + show_warnings: bool = True, + num_classes: Optional[int] = None, + ) -> Tuple[np.ndarray, List[str]]: """Parse a single label file in yolo format. 
#TODO: Add support for additional fields (with ConcatenatedTensorFormat) :param label_file_path: Path to the label file in yolo format. :param ignore_invalid_labels: Whether to ignore labels that fail to be parsed. If True ignores and logs a warning, otherwise raise an error. :param show_warnings: Whether to show the warnings or not. + :param num_classes: Number of classes in the dataset. Used to ensure that class ids are within the range [0, num_classes - 1]. + If None, the class id range check is skipped. :return: - labels: np.ndarray of shape (n_labels, 5) in yolo format (LABEL_NORMALIZED_CXCYWH) @@ -229,12 +237,21 @@ def _parse_yolo_label_file(label_file_path: str, ignore_invalid_labels: bool = T for line in filter(lambda x: x != "\n", lines): try: label_id, cx, cw, w, h = line.split() - labels_yolo_format.append([int(label_id), float(cx), float(cw), float(w), float(h)]) + label_id, cx, cw, w, h = int(label_id), float(cx), float(cw), float(w), float(h) + + if (num_classes is not None) and (label_id not in range(num_classes)): + raise ValueError(f"`class_id={label_id}` invalid. 
It should be between (0 - {num_classes - 1}).") + + labels_yolo_format.append([label_id, cx, cw, w, h]) except Exception as e: + error_msg = ( + f"Line `{line}` of file {label_file_path} will be ignored because not cannot be parsed to (label, cx, cy, w, h) format, " + f"with Exception:\n{e}" + ) if ignore_invalid_labels: invalid_labels.append(line) if show_warnings: - logger.warning(f"Line `{line}` of file {label_file_path} will be ignored because not in LABEL_NORMALIZED_CXCYWH format: {e}") + logger.warning(error_msg) else: - raise e + raise RuntimeError(error_msg) return np.array(labels_yolo_format) if labels_yolo_format else np.zeros((0, 5)), invalid_labels diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index 6ddbb9ef91..b9abce904e 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -27,7 +27,7 @@ ) from tests.end_to_end_tests import TestTrainer from tests.unit_tests.detection_utils_test import TestDetectionUtils -from tests.unit_tests.detection_dataset_test import DetectionDatasetTest +from tests.unit_tests.detection_dataset_test import DetectionDatasetTest, TestParseYoloLabelFile from tests.unit_tests.export_detection_model_test import TestDetectionModelExport from tests.unit_tests.export_onnx_test import TestModelsONNXExport from tests.unit_tests.export_pose_estimation_model_test import TestPoseEstimationModelExport @@ -136,6 +136,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestRepVGGBlock)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LocalCkptHeadReplacementTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionDatasetTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestParseYoloLabelFile)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelsONNXExport)) 
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MaxBatchesLoopBreakTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestTrainingUtils)) diff --git a/tests/unit_tests/detection_dataset_test.py b/tests/unit_tests/detection_dataset_test.py index 1856cfabcb..e1b51e62be 100644 --- a/tests/unit_tests/detection_dataset_test.py +++ b/tests/unit_tests/detection_dataset_test.py @@ -1,13 +1,15 @@ import unittest +from unittest.mock import patch, mock_open from pathlib import Path from typing import Dict +import numpy as np from torch.utils.data import DataLoader from super_gradients import Trainer from super_gradients.training import models, dataloaders from super_gradients.training.dataloaders import coco2017_train_yolo_nas, get_data_loader -from super_gradients.training.datasets import COCODetectionDataset +from super_gradients.training.datasets import COCODetectionDataset, YoloDarknetFormatDetectionDataset from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException @@ -194,5 +196,31 @@ def test_coco_detection_metrics_with_classwise_ap(self): trainer.train(model=model, training_params=detection_train_params_yolox, train_loader=train_loader, valid_loader=valid_loader) +class TestParseYoloLabelFile(unittest.TestCase): + def setUp(self): + self.num_classes = 3 + self.sample_data_valid = "0 0.5 0.5 0.1 0.1\n1 0.6 0.6 0.2 0.2" + self.sample_data_invalid_format = "0 0.5\n1 0.6 0.6 0.2 0.2" + self.sample_data_invalid_class = "-1 0.5 0.5 0.1 0.1\n3 0.6 0.6 0.2 0.2" + + def test_valid_label(self): + with patch("builtins.open", mock_open(read_data=self.sample_data_valid)): + labels, invalid_labels = YoloDarknetFormatDetectionDataset._parse_yolo_label_file("mock_path", num_classes=3) + 
np.testing.assert_array_equal(labels, np.array([[0, 0.5, 0.5, 0.1, 0.1], [1, 0.6, 0.6, 0.2, 0.2]])) + self.assertEqual(invalid_labels, []) + + def test_invalid_format(self): + with patch("builtins.open", mock_open(read_data=self.sample_data_invalid_format)): + labels, invalid_labels = YoloDarknetFormatDetectionDataset._parse_yolo_label_file("mock_path", num_classes=3) + np.testing.assert_array_equal(labels, np.array([[1, 0.6, 0.6, 0.2, 0.2]])) + self.assertEqual(invalid_labels, ["0 0.5\n"]) + + def test_invalid_class(self): + with patch("builtins.open", mock_open(read_data=self.sample_data_invalid_class)): + labels, invalid_labels = YoloDarknetFormatDetectionDataset._parse_yolo_label_file("mock_path", num_classes=3) + self.assertEqual(len(labels), 0) + self.assertEqual(invalid_labels, ["-1 0.5 0.5 0.1 0.1\n", "3 0.6 0.6 0.2 0.2"]) + + if __name__ == "__main__": unittest.main()