Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix class_inclusion_list in DetectionDataset #1327

Merged
merged 5 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .format_converter import ConcatenatedTensorFormatConverter
from .output_adapters import DetectionOutputAdapter
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from .formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem, LabelTensorSliceItem
from .bbox_formats import (
CXCYWHCoordinateFormat,
NormalizedCXCYWHCoordinateFormat,
Expand All @@ -21,6 +21,7 @@
"NormalizedXYWHCoordinateFormat",
"NormalizedXYXYCoordinateFormat",
"TensorSliceItem",
"LabelTensorSliceItem",
"XYWHCoordinateFormat",
"XYXYCoordinateFormat",
"YXYXCoordinateFormat",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from super_gradients.common.object_names import ConcatenatedTensorFormats
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, LabelTensorSliceItem
from super_gradients.training.datasets.data_formats.bbox_formats import (
XYXYCoordinateFormat,
XYWHCoordinateFormat,
Expand All @@ -12,72 +12,72 @@
XYXY_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
XYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
CXCYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
LABEL_XYXY = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=XYXYCoordinateFormat()),
)
)
LABEL_XYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=XYWHCoordinateFormat()),
)
)
LABEL_CXCYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
)
)
NORMALIZED_XYXY_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
NORMALIZED_XYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
NORMALIZED_CXCYWH_LABEL = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)
LABEL_NORMALIZED_XYXY = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYXYCoordinateFormat()),
)
)
LABEL_NORMALIZED_XYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
)
)
LABEL_NORMALIZED_CXCYWH = ConcatenatedTensorFormat(
layout=(
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedCXCYWHCoordinateFormat()),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __repr__(self):
return f"name={self.name} length={self.length} format={self.format}"


class LabelTensorSliceItem(TensorSliceItem):
    """Tensor slice item dedicated to the class label of a target.

    The slice is always created with length 1 and with the name stored in ``NAME``, so code that
    needs to locate the label slice can refer to ``LabelTensorSliceItem.NAME`` instead of
    repeating the string literal.
    """

    # Name under which the label slice is registered.
    NAME = "labels"

    def __init__(self):
        label_slice_length = 1
        super().__init__(length=label_slice_length, name=self.NAME)


class ConcatenatedTensorFormat(DetectionOutputFormat):
"""
Define the output format that return a single tensor of shape [N,M] (N - number of detections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.training.utils.detection_utils import get_cls_posx_in_target
from super_gradients.training.utils.detection_utils import get_class_index_in_target
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat
from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat, LabelTensorSliceItem
from super_gradients.training.utils.utils import ensure_is_tuple_of_two

logger = get_logger(__name__)
Expand Down Expand Up @@ -298,26 +298,26 @@ def _sub_class_annotation(self, annotation: dict) -> Union[dict, None]:
:param annotation: Dict representing the annotation of a specific image
:return: Subclassed annotation if non-empty after subclassing, otherwise None
"""
cls_posx = get_cls_posx_in_target(self.original_target_format)
class_index = _get_class_index_in_target(target_format=self.original_target_format)
for field in self.target_fields:
annotation[field] = self._sub_class_target(targets=annotation[field], cls_posx=cls_posx)
annotation[field] = self._sub_class_target(targets=annotation[field], class_index=class_index)
return annotation

def _sub_class_target(self, targets: np.ndarray, cls_posx: int) -> np.ndarray:
def _sub_class_target(self, targets: np.ndarray, class_index: int) -> np.ndarray:
"""Sublass targets of a specific image.

:param targets: Target array to subclass of shape [n_targets, 5], 5 representing a bbox
:param cls_posx: Position of the class id in a bbox
:param class_index: Position of the class id in a bbox
ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
:return: Subclassed target
"""
targets_kept = []
for target in targets:
cls_id = int(target[cls_posx])
cls_id = int(target[class_index])
cls_name = self.all_classes_list[cls_id]
if cls_name in self.class_inclusion_list:
# Replace the target cls_id in self.all_classes_list by cls_id in self.class_inclusion_list
target[cls_posx] = self.class_inclusion_list.index(cls_name)
target[class_index] = self.class_inclusion_list.index(cls_name)
targets_kept.append(target)

return np.array(targets_kept) if len(targets_kept) > 0 else np.zeros((0, 5), dtype=np.float32)
Expand Down Expand Up @@ -568,5 +568,16 @@ def get_dataset_preprocessing_params(self):
return params


# TODO
# - Integration Test
def _get_class_index_in_target(target_format: Union[ConcatenatedTensorFormat, DetectionTargetsFormat]) -> int:
    """Get the index of the class in the target format.

    :param target_format: Format of the target; either a ConcatenatedTensorFormat or a DetectionTargetsFormat.
                          E.g. XYXY_LABEL, LABEL_NORMALIZED_XYXY, etc.
    :return:              Index of the class in the target format. E.g. XYXY_LABEL -> 4, LABEL_NORMALIZED_XYXY -> 0, etc.
    :raises NotImplementedError: If target_format is neither a ConcatenatedTensorFormat nor a DetectionTargetsFormat.
    """
    if isinstance(target_format, ConcatenatedTensorFormat):
        # The label slice is looked up by LabelTensorSliceItem.NAME; its first index is the class position.
        return target_format.indexes[LabelTensorSliceItem.NAME][0]

    if isinstance(target_format, DetectionTargetsFormat):
        # DetectionTargetsFormat members are resolved by the helper imported from detection_utils.
        return get_class_index_in_target(target_format)

    raise NotImplementedError(
        f"{target_format} is not supported. Supported formats are: {ConcatenatedTensorFormat.__name__}, {DetectionTargetsFormat.__name__}"
    )
2 changes: 1 addition & 1 deletion src/super_gradients/training/utils/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DetectionTargetsFormat(Enum):
NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"


def get_cls_posx_in_target(target_format: DetectionTargetsFormat) -> int:
def get_class_index_in_target(target_format: DetectionTargetsFormat) -> int:
"""Get the label of a given target
:param target_format: Representation of the target (ex: LABEL_XYXY)
:return: Position of the class id in a bbox
Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests/detection_output_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
YXYXCoordinateFormat,
NormalizedCXCYWHCoordinateFormat,
DetectionOutputAdapter,
LabelTensorSliceItem,
)

from super_gradients.training.datasets.data_formats.bbox_formats.normalized_cxcywh import xyxy_to_normalized_cxcywh
Expand All @@ -25,22 +26,22 @@
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=NormalizedXYWHCoordinateFormat()),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)

CXCYWH_SCORES_LABELS = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
)
)

CXCYWH_LABELS_SCORES_DISTANCE_ATTR = ConcatenatedTensorFormat(
layout=(
BoundingBoxesTensorSliceItem(name="bboxes", format=CXCYWHCoordinateFormat()),
TensorSliceItem(length=1, name="labels"),
LabelTensorSliceItem(),
TensorSliceItem(length=1, name="scores"),
TensorSliceItem(length=1, name="distance"),
TensorSliceItem(length=4, name="attributes"),
Expand Down
37 changes: 29 additions & 8 deletions tests/unit_tests/detection_sub_classing_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException

import unittest
import numpy as np
from typing import Union

from super_gradients.training.datasets import DetectionDataset
from super_gradients.training.utils.detection_utils import DetectionTargetsFormat
from super_gradients.common.exceptions.dataset_exceptions import EmptyDatasetException, DatasetValidationException
from super_gradients.training.datasets.data_formats.formats import ConcatenatedTensorFormat
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL


class DummyDetectionDataset(DetectionDataset):
def __init__(self, input_dim, *args, **kwargs):
def __init__(self, input_dim, target_format: Union[DetectionTargetsFormat, ConcatenatedTensorFormat], *args, **kwargs):
"""Dummy Dataset testing subclassing, designed with no annotation that includes class_2."""

self.dummy_targets = [
Expand All @@ -17,7 +21,7 @@ def __init__(self, input_dim, *args, **kwargs):

self.image_size = input_dim
kwargs["all_classes_list"] = ["class_0", "class_1", "class_2"]
kwargs["original_target_format"] = DetectionTargetsFormat.XYXY_LABEL
kwargs["original_target_format"] = target_format
super().__init__(data_dir="", input_dim=input_dim, *args, **kwargs)

def _setup_data_source(self):
Expand Down Expand Up @@ -53,28 +57,45 @@ def setUp(self) -> None:
def test_subclass_keep_empty(self):
"""Check that subclassing only keeps annotations of wanted class"""
for config in self.config_keep_empty_annotation:
test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=False, class_inclusion_list=config["class_inclusion_list"])
test_dataset = DummyDetectionDataset(
input_dim=(640, 512), ignore_empty_annotations=False, class_inclusion_list=config["class_inclusion_list"], target_format=XYXY_LABEL
)
n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)

def test_subclass_drop_empty(self):
"""Check that empty annotations are not indexed (i.e. ignored) when ignore_empty_annotations=True"""
for config in self.config_ignore_empty_annotation:
test_dataset = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=config["class_inclusion_list"])
test_dataset = DummyDetectionDataset(
input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=config["class_inclusion_list"], target_format=XYXY_LABEL
)
n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)

# Check last case when class_2, which should raise EmptyDatasetException because not a single image has
# a target in class_inclusion_list
with self.assertRaises(EmptyDatasetException):
DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=["class_2"])
DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, class_inclusion_list=["class_2"], target_format=XYXY_LABEL)

def test_wrong_subclass(self):
"""Check that ValueError is raised when class_inclusion_list includes a class that does not exist."""
with self.assertRaises(DatasetValidationException):
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"])
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["non_existing_class"], target_format=XYXY_LABEL)
with self.assertRaises(DatasetValidationException):
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"])
DummyDetectionDataset(input_dim=(640, 512), class_inclusion_list=["class_0", "non_existing_class"], target_format=XYXY_LABEL)

def test_legacy_detection_targets_format(self):
    """Check that subclassing keeps the expected number of targets when the dataset is built with
    DetectionTargetsFormat.XYXY_LABEL as its target format."""

    for config in self.config_keep_empty_annotation:
        test_dataset = DummyDetectionDataset(
            input_dim=(640, 512),
            ignore_empty_annotations=False,
            class_inclusion_list=config["class_inclusion_list"],
            target_format=DetectionTargetsFormat.XYXY_LABEL,
        )
        n_targets_after_subclass = _count_targets_after_subclass_per_index(test_dataset)
        self.assertListEqual(config["expected_n_targets_after_subclass"], n_targets_after_subclass)


def _count_targets_after_subclass_per_index(test_dataset: DummyDetectionDataset):
Expand Down