diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..1d2f0d8a75
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,8 @@
+unit_tests:
+	python -m unittest tests/deci_core_unit_test_suite_runner.py
+
+integration_tests:
+	python -m unittest tests/deci_core_integration_test_suite_runner.py
+
+yolo_nas_integration_tests:
+	python -m unittest tests/integration_tests/yolo_nas_integration_test.py
diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml
index f21c2329bf..aed33f25dd 100644
--- a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml
+++ b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml
@@ -31,7 +31,7 @@ train_dataset_params:
       prob: 0.5         # probability to apply per-sample mixup
       flip_prob: 0.5    # probability to apply horizontal flip
   - DetectionPaddedRescale:
-      input_dim: [640, 640]
+      input_dim: ${dataset_params.train_dataset_params.input_dim}
       max_targets: 120
       pad_value: 114
   - DetectionStandardize:
diff --git a/src/super_gradients/training/dataloaders/dataloaders.py b/src/super_gradients/training/dataloaders/dataloaders.py
index dc318ce11e..197b07025f 100644
--- a/src/super_gradients/training/dataloaders/dataloaders.py
+++ b/src/super_gradients/training/dataloaders/dataloaders.py
@@ -3,6 +3,7 @@
 import hydra
 import numpy as np
 import torch
+from omegaconf import OmegaConf, UnsupportedValueType
 from torch.utils.data import BatchSampler, DataLoader, TensorDataset, RandomSampler
 
 import super_gradients
@@ -81,14 +82,40 @@ def get_data_loader(config_name: str, dataset_cls: object, train: bool, dataset_
     return dataloader
 
 
-def _process_dataset_params(cfg, dataset_params, train):
-    default_dataset_params = cfg.train_dataset_params if train else cfg.val_dataset_params
-    default_dataset_params = hydra.utils.instantiate(default_dataset_params)
-    for key, val in default_dataset_params.items():
-        if key not in dataset_params.keys() or dataset_params[key] is None:
-            dataset_params[key] = val
+def _process_dataset_params(cfg, dataset_params, train: bool):
+    """
+    Merge the default dataset config with the user-provided overrides.
+    This function handles variable interpolation in the dataset config.
+
+    :param cfg:            Default dataset config
+    :param dataset_params: User-provided overrides
+    :param train:          Boolean flag indicating whether we are processing train or val dataset params
+    :return:               New dataset params (merged defaults and overrides, where overrides take precedence)
+    """
 
-    return dataset_params
+    try:
+        # No, we can't simplify the following lines to:
+        # >>> default_dataset_params = cfg.train_dataset_params if train else cfg.val_dataset_params
+        # >>> dataset_params = OmegaConf.merge(default_dataset_params, dataset_params)
+        # >>> return hydra.utils.instantiate(dataset_params)
+        # For some reason this breaks interpolation :shrug:
+
+        if train:
+            cfg.train_dataset_params = OmegaConf.merge(cfg.train_dataset_params, dataset_params)
+            return hydra.utils.instantiate(cfg.train_dataset_params)
+        else:
+            cfg.val_dataset_params = OmegaConf.merge(cfg.val_dataset_params, dataset_params)
+            return hydra.utils.instantiate(cfg.val_dataset_params)
+
+    except UnsupportedValueType:
+        # This is a somewhat ugly fallback for the case when the user provides overrides for the dataset params
+        # that contain non-primitive types (e.g. instantiated transforms).
+        # In this case interpolation is not possible, so we just override the default params with the user-provided ones.
+        default_dataset_params = hydra.utils.instantiate(cfg.train_dataset_params if train else cfg.val_dataset_params)
+        for key, val in default_dataset_params.items():
+            if key not in dataset_params.keys() or dataset_params[key] is None:
+                dataset_params[key] = val
+        return dataset_params
 
 
 def _process_dataloader_params(cfg, dataloader_params, dataset, train):
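Reviewer note: the merge-before-instantiate ordering above is what lets `${...}` references (like the `input_dim` interpolation added to the recipe YAML) resolve against user overrides. A minimal sketch of the difference, using made-up config keys rather than the real recipe structure, and approximating the old copy-keys behavior with `to_container`:

```python
from omegaconf import OmegaConf

defaults = OmegaConf.create(
    {
        "input_dim": [640, 640],
        # Refers to `input_dim` via interpolation, as the recipe YAML now does.
        "transforms": {"rescale_dim": "${input_dim}"},
    }
)
overrides = OmegaConf.create({"input_dim": [512, 512]})

# Merge first: interpolations stay lazy, so they resolve against the override.
merged = OmegaConf.merge(defaults, overrides)
print(merged.transforms.rescale_dim)  # [512, 512]

# Resolve first, then copy override keys one by one (the old behavior): the
# interpolated value is already frozen at 640, so the override never propagates.
resolved = OmegaConf.to_container(defaults, resolve=True)
resolved["input_dim"] = [512, 512]
print(resolved["transforms"]["rescale_dim"])  # [640, 640]
```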
""" - def __init__(self, input_dim: tuple, prob: float = 1.0, enable_mosaic: bool = True, border_value=114): + def __init__(self, input_dim: Union[int, Tuple[int, int]], prob: float = 1.0, enable_mosaic: bool = True, border_value=114): super(DetectionMosaic, self).__init__(additional_samples_count=3) self.prob = prob - self.input_dim = input_dim + self.input_dim = ensure_is_tuple_of_two(input_dim) self.enable_mosaic = enable_mosaic self.border_value = border_value @@ -566,7 +567,7 @@ def __init__( translate: Union[tuple, float] = 0.1, scales: Union[tuple, float] = 0.1, shear: Union[tuple, float] = 10, - target_size: Optional[Tuple[int, int]] = (640, 640), + target_size: Union[int, Tuple[int, int], None] = (640, 640), filter_box_candidates: bool = False, wh_thr: float = 2, ar_thr: float = 20, @@ -578,7 +579,7 @@ def __init__( self.translate = translate self.scale = scales self.shear = shear - self.target_size = target_size + self.target_size = ensure_is_tuple_of_two(target_size) self.enable = True self.filter_box_candidates = filter_box_candidates self.wh_thr = wh_thr @@ -624,9 +625,17 @@ class DetectionMixup(DetectionTransform): :param border_value: Value for filling borders after applying transform. """ - def __init__(self, input_dim: tuple, mixup_scale: tuple, prob: float = 1.0, enable_mixup: bool = True, flip_prob: float = 0.5, border_value: int = 114): + def __init__( + self, + input_dim: Union[int, Tuple[int, int], None], + mixup_scale: tuple, + prob: float = 1.0, + enable_mixup: bool = True, + flip_prob: float = 0.5, + border_value: int = 114, + ): super(DetectionMixup, self).__init__(additional_samples_count=1, non_empty_targets=True) - self.input_dim = input_dim + self.input_dim = ensure_is_tuple_of_two(input_dim) self.mixup_scale = mixup_scale self.prob = prob self.enable_mixup = enable_mixup @@ -736,7 +745,7 @@ class DetectionPadToSize(DetectionTransform): Note: This transformation assume that dimensions of input image is equal or less than `output_size`. """ - def __init__(self, output_size: Tuple[int, int], pad_value: int): + def __init__(self, output_size: Union[int, Tuple[int, int], None], pad_value: int): """ Constructor for DetectionPadToSize transform. @@ -744,7 +753,7 @@ def __init__(self, output_size: Tuple[int, int], pad_value: int): :param pad_value: Padding value for image """ super().__init__() - self.output_size = output_size + self.output_size = ensure_is_tuple_of_two(output_size) self.pad_value = pad_value def __call__(self, sample: dict) -> dict: @@ -775,9 +784,9 @@ class DetectionPaddedRescale(DetectionTransform): :param pad_value: Padding value for image. """ - def __init__(self, input_dim: Tuple, swap: Tuple[int, ...] = (2, 0, 1), max_targets: int = 50, pad_value: int = 114): + def __init__(self, input_dim: Union[int, Tuple[int, int], None], swap: Tuple[int, ...] 
= (2, 0, 1), max_targets: int = 50, pad_value: int = 114): self.swap = swap - self.input_dim = input_dim + self.input_dim = ensure_is_tuple_of_two(input_dim) self.max_targets = max_targets self.pad_value = pad_value @@ -834,14 +843,14 @@ class DetectionRescale(DetectionTransform): :param output_shape: (rows, cols) """ - def __init__(self, output_shape: Tuple[int, int]): + def __init__(self, output_shape: Union[int, Tuple[int, int]]): super().__init__() - self.output_shape = output_shape + self.output_shape = ensure_is_tuple_of_two(output_shape) def __call__(self, sample: dict) -> dict: image, targets, crowd_targets = sample["image"], sample["target"], sample.get("crowd_target") - sy, sx = (self.output_shape[0] / image.shape[0], self.output_shape[1] / image.shape[1]) + sy, sx = float(self.output_shape[0]) / float(image.shape[0]), float(self.output_shape[1]) / float(image.shape[1]) sample["image"] = _rescale_image(image=image, target_shape=self.output_shape) sample["target"] = _rescale_bboxes(targets, scale_factors=(sy, sx)) @@ -1010,7 +1019,7 @@ class DetectionTargetsFormatTransform(DetectionTransform): @resolve_param("output_format", ConcatenatedTensorFormatFactory()) def __init__( self, - input_dim: Optional[tuple] = None, + input_dim: Union[int, Tuple[int, int], None] = None, input_format: ConcatenatedTensorFormat = XYXY_LABEL, output_format: ConcatenatedTensorFormat = LABEL_CXCYWH, min_bbox_edge_size: float = 1, @@ -1031,6 +1040,7 @@ def __init__( self.input_dim = None if input_dim is not None: + input_dim = ensure_is_tuple_of_two(input_dim) self._setup_input_dim_related_params(input_dim) def _setup_input_dim_related_params(self, input_dim: tuple): diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py index c387c9e21d..abdcce3b21 100755 --- a/src/super_gradients/training/utils/utils.py +++ b/src/super_gradients/training/utils/utils.py @@ -1,29 +1,29 @@ +import collections +import math import os -import tarfile +import random import re -import math +import tarfile import time - import inspect from functools import lru_cache, wraps +from importlib import import_module +from itertools import islice + from pathlib import Path from typing import Mapping, Optional, Tuple, Union, List, Dict, Any, Iterable from zipfile import ZipFile -from jsonschema import validate -from itertools import islice -from PIL import Image, ExifTags +import numpy as np import torch import torch.nn as nn - -# These functions changed from torch 1.2 to torch 1.3 - -import random -import numpy as np -from importlib import import_module +from PIL import Image, ExifTags +from jsonschema import validate from super_gradients.common.abstractions.abstract_logger import get_logger +# These functions changed from torch 1.2 to torch 1.3 + logger = get_logger(__name__) @@ -581,3 +581,19 @@ def generate_batch(iterable: Iterable, batch_size: int) -> Iterable: yield batch else: return + + +def ensure_is_tuple_of_two(inputs: Union[Any, Iterable[Any], None]) -> Union[Tuple[Any, Any], None]: + """ + Checks input and converts it to a tuple of length two. If input is None returns None. + :param inputs: Input argument, either a number or a tuple of two numbers. + :return: Tuple of two numbers if input is not None, otherwise - None. 
+ """ + if inputs is None: + return None + + if isinstance(inputs, collections.Iterable) and not isinstance(inputs, str): + a, b = inputs + return a, b + + return inputs, inputs diff --git a/tests/unit_tests/detection_dataset_test.py b/tests/unit_tests/detection_dataset_test.py index 79c16a0a1d..fda8597f3d 100644 --- a/tests/unit_tests/detection_dataset_test.py +++ b/tests/unit_tests/detection_dataset_test.py @@ -1,8 +1,11 @@ import unittest from pathlib import Path +from super_gradients.training.dataloaders import coco2017_train_yolo_nas from super_gradients.training.datasets import COCODetectionDataset +from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException +from super_gradients.training.transforms import DetectionMosaic, DetectionTargetsFormatTransform, DetectionPaddedRescale class DetectionDatasetTest(unittest.TestCase): @@ -44,6 +47,47 @@ def test_coco_dataset_creation_with_subset_classes(self): with self.assertRaises(ParameterMismatchException): COCODetectionDataset(**train_dataset_params) + def test_coco_detection_dataset_override_image_size(self): + train_dataset_params = { + "data_dir": self.mini_coco_data_dir, + "input_dim": [512, 512], + } + train_dataloader_params = {"num_workers": 0} + dataloader = coco2017_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + batch = next(iter(dataloader)) + print(batch[0].shape) + self.assertEqual(batch[0].shape[2], 512) + self.assertEqual(batch[0].shape[3], 512) + + def test_coco_detection_dataset_override_image_size_single_scalar(self): + train_dataset_params = { + "data_dir": self.mini_coco_data_dir, + "input_dim": 384, + } + train_dataloader_params = {"num_workers": 0} + dataloader = coco2017_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + batch = next(iter(dataloader)) + print(batch[0].shape) + self.assertEqual(batch[0].shape[2], 384) + self.assertEqual(batch[0].shape[3], 384) + + def test_coco_detection_dataset_override_with_objects(self): + train_dataset_params = { + "data_dir": self.mini_coco_data_dir, + "input_dim": 384, + "transforms": [ + DetectionMosaic(input_dim=384), + DetectionPaddedRescale(input_dim=384, max_targets=10), + DetectionTargetsFormatTransform(max_targets=10, output_format=LABEL_CXCYWH), + ], + } + train_dataloader_params = {"num_workers": 0} + dataloader = coco2017_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params) + batch = next(iter(dataloader)) + print(batch[0].shape) + self.assertEqual(batch[0].shape[2], 384) + self.assertEqual(batch[0].shape[3], 384) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/detection_utils_test.py b/tests/unit_tests/detection_utils_test.py index cff3b02e1d..4af450c116 100644 --- a/tests/unit_tests/detection_utils_test.py +++ b/tests/unit_tests/detection_utils_test.py @@ -23,12 +23,13 @@ def setUp(self): @unittest.skipIf(not is_data_available(), "run only when /data is available") def test_visualization(self): - valid_loader = coco2017_val(dataloader_params={"batch_size": 16}) + valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0}) trainer = Trainer("visualization_test") post_prediction_callback = YoloPostPredictionCallback() # Simulate one iteration of validation subset - batch_i, (imgs, targets) = 0, next(iter(valid_loader)) + batch_i, batch = 
diff --git a/tests/unit_tests/detection_dataset_test.py b/tests/unit_tests/detection_dataset_test.py
index 79c16a0a1d..fda8597f3d 100644
--- a/tests/unit_tests/detection_dataset_test.py
+++ b/tests/unit_tests/detection_dataset_test.py
@@ -1,8 +1,11 @@
 import unittest
 from pathlib import Path
 
+from super_gradients.training.dataloaders import coco2017_train_yolo_nas
 from super_gradients.training.datasets import COCODetectionDataset
+from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH
 from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
+from super_gradients.training.transforms import DetectionMosaic, DetectionTargetsFormatTransform, DetectionPaddedRescale
 
 
 class DetectionDatasetTest(unittest.TestCase):
@@ -44,6 +47,47 @@ def test_coco_dataset_creation_with_subset_classes(self):
         with self.assertRaises(ParameterMismatchException):
             COCODetectionDataset(**train_dataset_params)
 
+    def test_coco_detection_dataset_override_image_size(self):
+        train_dataset_params = {
+            "data_dir": self.mini_coco_data_dir,
+            "input_dim": [512, 512],
+        }
+        train_dataloader_params = {"num_workers": 0}
+        dataloader = coco2017_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
+        batch = next(iter(dataloader))
+        print(batch[0].shape)
+        self.assertEqual(batch[0].shape[2], 512)
+        self.assertEqual(batch[0].shape[3], 512)
+
+    def test_coco_detection_dataset_override_image_size_single_scalar(self):
+        train_dataset_params = {
+            "data_dir": self.mini_coco_data_dir,
+            "input_dim": 384,
+        }
+        train_dataloader_params = {"num_workers": 0}
+        dataloader = coco2017_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
+        batch = next(iter(dataloader))
+        print(batch[0].shape)
+        self.assertEqual(batch[0].shape[2], 384)
+        self.assertEqual(batch[0].shape[3], 384)
+
+    def test_coco_detection_dataset_override_with_objects(self):
+        train_dataset_params = {
+            "data_dir": self.mini_coco_data_dir,
+            "input_dim": 384,
+            "transforms": [
+                DetectionMosaic(input_dim=384),
+                DetectionPaddedRescale(input_dim=384, max_targets=10),
+                DetectionTargetsFormatTransform(max_targets=10, output_format=LABEL_CXCYWH),
+            ],
+        }
+        train_dataloader_params = {"num_workers": 0}
+        dataloader = coco2017_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
+        batch = next(iter(dataloader))
+        print(batch[0].shape)
+        self.assertEqual(batch[0].shape[2], 384)
+        self.assertEqual(batch[0].shape[3], 384)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit_tests/detection_utils_test.py b/tests/unit_tests/detection_utils_test.py
index cff3b02e1d..4af450c116 100644
--- a/tests/unit_tests/detection_utils_test.py
+++ b/tests/unit_tests/detection_utils_test.py
@@ -23,12 +23,13 @@ def setUp(self):
 
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
     def test_visualization(self):
-        valid_loader = coco2017_val(dataloader_params={"batch_size": 16})
+        valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0})
         trainer = Trainer("visualization_test")
         post_prediction_callback = YoloPostPredictionCallback()
 
         # Simulate one iteration of validation subset
-        batch_i, (imgs, targets) = 0, next(iter(valid_loader))
+        batch_i, batch = 0, next(iter(valid_loader))
+        imgs, targets = batch[:2]
         imgs = core_utils.tensor_container_to_device(imgs, self.device)
         targets = core_utils.tensor_container_to_device(targets, self.device)
         output = self.model(imgs)
@@ -46,7 +47,7 @@ def test_visualization(self):
 
     @unittest.skipIf(not is_data_available(), "run only when /data is available")
     def test_detection_metrics(self):
-        valid_loader = coco2017_val(dataloader_params={"batch_size": 16})
+        valid_loader = coco2017_val(dataloader_params={"batch_size": 16, "num_workers": 0})
 
         metrics = [
             DetectionMetrics(num_cls=80, post_prediction_callback=YoloPostPredictionCallback(), normalize_targets=True),
@@ -55,14 +56,14 @@
         ]
 
         ref_values = [
-            np.array([0.24662896, 0.4024832, 0.34590888, 0.28435066]),
-            np.array([0.34606069, 0.56745648, 0.50594932, 0.40323338]),
+            np.array([0.24701539, 0.40294355, 0.34654024, 0.28485271]),
+            np.array([0.34666198, 0.56854934, 0.5079478, 0.40414381]),
             np.array([0.0, 0.0, 0.0, 0.0]),
         ]
 
         for met, ref_val in zip(metrics, ref_values):
             met.reset()
-            for i, (imgs, targets) in enumerate(valid_loader):
+            for i, (imgs, targets, extras) in enumerate(valid_loader):
                 if i > 5:
                     break
                 imgs = core_utils.tensor_container_to_device(imgs, self.device)
@@ -71,8 +72,7 @@
                 met.update(output, targets, device=self.device, inputs=imgs)
             results = met.compute()
             values = np.array([x.item() for x in list(results.values())])
-
-            self.assertTrue(np.allclose(values, ref_val))
+            self.assertTrue(np.allclose(values, ref_val, rtol=1e-3, atol=1e-4))
 
 
 if __name__ == "__main__":
diff --git a/tests/unit_tests/preprocessing_unit_test.py b/tests/unit_tests/preprocessing_unit_test.py
index eb7d416077..13238f3313 100644
--- a/tests/unit_tests/preprocessing_unit_test.py
+++ b/tests/unit_tests/preprocessing_unit_test.py
@@ -21,9 +21,9 @@ def test_getting_preprocessing_params(self):
             "ComposeProcessing": {
                 "processings": [
                     "ReverseImageChannels",
-                    {"DetectionLongestMaxSizeRescale": {"output_shape": [512, 512]}},
-                    {"DetectionLongestMaxSizeRescale": {"output_shape": [512, 512]}},
-                    {"DetectionBottomRightPadding": {"output_shape": [512, 512], "pad_value": 114}},
+                    {"DetectionLongestMaxSizeRescale": {"output_shape": (512, 512)}},
+                    {"DetectionLongestMaxSizeRescale": {"output_shape": (512, 512)}},
+                    {"DetectionBottomRightPadding": {"output_shape": (512, 512), "pad_value": 114}},
                     {"ImagePermute": {"permutation": (2, 0, 1)}},
                 ]
             }
diff --git a/tests/unit_tests/yolox_unit_test.py b/tests/unit_tests/yolox_unit_test.py
index bcc88617f0..e4575a1527 100644
--- a/tests/unit_tests/yolox_unit_test.py
+++ b/tests/unit_tests/yolox_unit_test.py
@@ -12,34 +12,34 @@ class TestYOLOX(unittest.TestCase):
     def setUp(self) -> None:
         self.arch_params = HpmStruct(num_classes=10)
         self.yolo_classes = [YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X]
+        self.devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
 
     def test_yolox_creation(self):
         """
        test_yolox_creation - Tests the creation of the models
         :return:
         """
-        dummy_input = torch.randn(1, 3, 320, 320)
+        for device in self.devices:
+            dummy_input = torch.randn(1, 3, 320, 320).to(device)
+            with torch.no_grad():
+                for yolo_cls in self.yolo_classes:
+                    yolo_model = yolo_cls(self.arch_params).to(device)
+                    # THIS SHOULD RUN THE FORWARD ONCE
+                    yolo_model.eval()
+                    output_standard = yolo_model(dummy_input)
+                    self.assertIsNotNone(output_standard)
 
-        with torch.no_grad():
+                    # THIS SHOULD RUN A TRAINING FORWARD
+                    yolo_model.train()
+                    output_train = yolo_model(dummy_input)
 
-            for yolo_cls in self.yolo_classes:
-                yolo_model = yolo_cls(self.arch_params)
-                # THIS SHOULD RUN THE FORWARD ONCE
-                yolo_model.eval()
-                output_standard = yolo_model(dummy_input)
-                self.assertIsNotNone(output_standard)
+                    self.assertIsNotNone(output_train)
 
-                # THIS SHOULD RUN A TRAINING FORWARD
-                yolo_model.train()
-                output_train = yolo_model(dummy_input)
-
-                self.assertIsNotNone(output_train)
-
-                # THIS SHOULD RUN THE FORWARD AUGMENT
-                yolo_model.eval()
-                yolo_model.augmented_inference = True
-                output_augment = yolo_model(dummy_input)
-                self.assertIsNotNone(output_augment)
+                    # THIS SHOULD RUN THE FORWARD AUGMENT
+                    yolo_model.eval()
+                    yolo_model.augmented_inference = True
+                    output_augment = yolo_model(dummy_input)
+                    self.assertIsNotNone(output_augment)
 
     def test_yolox_loss(self):
         samples = [
@@ -52,21 +52,22 @@
         collate = DetectionCollateFN()
         _, targets = collate(samples)
 
-        predictions = [
-            torch.randn((5, 1, 256 // 8, 256 // 8, 4 + 1 + 10)),
-            torch.randn((5, 1, 256 // 16, 256 // 16, 4 + 1 + 10)),
-            torch.randn((5, 1, 256 // 32, 256 // 32, 4 + 1 + 10)),
-        ]
+        for device in self.devices:
+            predictions = [
+                torch.randn((5, 1, 256 // 8, 256 // 8, 4 + 1 + 10)).to(device),
+                torch.randn((5, 1, 256 // 16, 256 // 16, 4 + 1 + 10)).to(device),
+                torch.randn((5, 1, 256 // 32, 256 // 32, 4 + 1 + 10)).to(device),
+            ]
 
-        for loss in [
-            YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="giou"),
-            YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="iou"),
-            YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
-            YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True),
-            YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
-        ]:
-            result = loss(predictions, targets)
-            print(result)
+            for loss in [
+                YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="giou"),
+                YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="iou"),
+                YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
+                YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True),
+                YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
+            ]:
+                result = loss(predictions, targets.to(device))
+                print(result)
 
 
 if __name__ == "__main__":