diff --git a/Makefile b/Makefile index 1d2f0d8a75..6b950b02f0 100644 --- a/Makefile +++ b/Makefile @@ -6,3 +6,11 @@ integration_tests: yolo_nas_integration_tests: python -m unittest tests/integration_tests/yolo_nas_integration_test.py + +recipe_accuracy_tests: + python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test + python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test epochs=1 batch_size=4 val_batch_size=8 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4 + python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4 + python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_yolox experiment_name=shortened_coco2017_yolox_n_map_test epochs=10 architecture=yolox_n training_hyperparams.loss=yolox_fast_loss training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4 + python3.8 src/super_gradients/train_from_recipe.py --config-name=cityscapes_regseg48 experiment_name=shortened_cityscapes_regseg48_iou_test epochs=10 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4 + coverage run --source=super_gradients -m unittest tests/deci_core_recipe_test_suite_runner.py diff --git a/src/super_gradients/common/object_names.py b/src/super_gradients/common/object_names.py index 45921f5776..2ce232a3dd 100644 --- a/src/super_gradients/common/object_names.py +++ b/src/super_gradients/common/object_names.py @@ -110,6 +110,7 @@ class Transforms: # Keypoints KeypointsRandomAffineTransform = "KeypointsRandomAffineTransform" KeypointsImageNormalize = "KeypointsImageNormalize" + KeypointsImageStandardize = "KeypointsImageStandardize" KeypointsImageToTensor = "KeypointsImageToTensor" KeypointTransform = "KeypointTransform" KeypointsPadIfNeeded = "KeypointsPadIfNeeded" @@ -413,8 +414,10 @@ class Processings: DetectionCenterPadding = "DetectionCenterPadding" DetectionLongestMaxSizeRescale = "DetectionLongestMaxSizeRescale" DetectionBottomRightPadding = "DetectionBottomRightPadding" - ImagePermute = "ImagePermute" DetectionRescale = "DetectionRescale" + KeypointsLongestMaxSizeRescale = "KeypointsLongestMaxSizeRescale" + KeypointsBottomRightPadding = "KeypointsBottomRightPadding" + ImagePermute = "ImagePermute" ReverseImageChannels = "ReverseImageChannels" NormalizeImage = "NormalizeImage" ComposeProcessing = "ComposeProcessing" diff --git a/src/super_gradients/common/plugins/wandb/log_predictions.py b/src/super_gradients/common/plugins/wandb/log_predictions.py index 04fa007aa4..6043bbc0d3 100644 --- a/src/super_gradients/common/plugins/wandb/log_predictions.py +++ b/src/super_gradients/common/plugins/wandb/log_predictions.py @@ -3,7 +3,7 @@ except (ModuleNotFoundError, ImportError, NameError): pass # no action or logging - this is normal in most cases -from super_gradients.training.models.prediction_results import ImageDetectionPrediction, ImagesDetectionPrediction +from super_gradients.training.utils.predict import ImageDetectionPrediction, ImagesDetectionPrediction def _visualize_image_detection_prediction_on_wandb(prediction: 
ImageDetectionPrediction, show_confidence: bool): diff --git a/src/super_gradients/recipes/arch_params/pose_dekr_coco_rescoring_arch_params.yaml b/src/super_gradients/recipes/arch_params/pose_dekr_coco_rescoring_arch_params.yaml index 53d85a76ef..418b9c3908 100644 --- a/src/super_gradients/recipes/arch_params/pose_dekr_coco_rescoring_arch_params.yaml +++ b/src/super_gradients/recipes/arch_params/pose_dekr_coco_rescoring_arch_params.yaml @@ -1,7 +1,7 @@ num_classes: 17 hidden_channels: 256 num_layers: 2 -joint_links: +edge_links: - [ 0, 1 ] - [ 0, 2 ] - [ 1, 2 ] diff --git a/src/super_gradients/recipes/dataset_params/coco_pose_estimation_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_pose_estimation_dataset_params.yaml index 131fe08146..cc15e4b0e5 100644 --- a/src/super_gradients/recipes/dataset_params/coco_pose_estimation_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_pose_estimation_dataset_params.yaml @@ -3,10 +3,9 @@ num_joints: 17 # OKs sigma values take from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L523 oks_sigmas: [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 1.007, 1.007, 0.087, 0.087, 0.089, 0.089] -flip_indexes_heatmap: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17] -flip_indexes_offset: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15,] +flip_indexes: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15,] -joint_links: +edge_links: - [0, 1] - [0, 2] - [1, 2] @@ -27,6 +26,47 @@ joint_links: - [13, 15] - [14, 16] +edge_colors: + - [214, 39, 40] # Nose -> LeftEye + - [148, 103, 189] # Nose -> RightEye + - [44, 160, 44] # LeftEye -> RightEye + - [140, 86, 75] # LeftEye -> LeftEar + - [227, 119, 194] # RightEye -> RightEar + - [127, 127, 127] # LeftEar -> LeftShoulder + - [188, 189, 34] # RightEar -> RightShoulder + - [127, 127, 127] # Shoulders + - [188, 189, 34] # LeftShoulder -> LeftElbow + - [140, 86, 75] # LeftTorso + - [23, 190, 207] # RightShoulder -> RightElbow + - [227, 119, 194] # RightTorso + - [31, 119, 180] # LeftElbow -> LeftArm + - [255, 127, 14] # RightElbow -> RightArm + - [148, 103, 189] # Waist + - [255, 127, 14] # Left Hip -> Left Knee + - [214, 39, 40] # Right Hip -> Right Knee + - [31, 119, 180] # Left Knee -> Left Ankle + - [44, 160, 44] # Right Knee -> Right Ankle + + +keypoint_colors: + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + - [31, 119, 180] + - [148, 103, 189] + train_dataset_params: data_dir: /data/coco # root path to coco data @@ -36,6 +76,10 @@ train_dataset_params: include_empty_samples: False min_instance_area: 64 + edge_links: ${dataset_params.edge_links} + edge_colors: ${dataset_params.edge_colors} + keypoint_colors: ${dataset_params.keypoint_colors} + transforms: - KeypointsLongestMaxSize: max_height: 640 @@ -44,12 +88,12 @@ train_dataset_params: - KeypointsPadIfNeeded: min_height: 640 min_width: 640 - image_pad_value: [ 127, 127, 127 ] + image_pad_value: 127 mask_pad_value: 1 - KeypointsRandomHorizontalFlip: # Note these indexes are COCO-specific. If you're using a different dataset, you'll need to change these accordingly. 
- flip_index: [ 0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15 ] + flip_index: ${dataset_params.flip_indexes} prob: 0.5 - KeypointsRandomAffineTransform: @@ -57,16 +101,18 @@ train_dataset_params: min_scale: 0.5 max_scale: 2 max_translate: 0.2 - image_pad_value: [ 127, 127, 127 ] + image_pad_value: 127 mask_pad_value: 1 prob: 0.75 - - KeypointsImageToTensor + - KeypointsImageStandardize: + max_value: 255 - KeypointsImageNormalize: mean: [ 0.485, 0.456, 0.406 ] std: [ 0.229, 0.224, 0.225 ] + - KeypointsImageToTensor val_dataset_params: data_dir: /data/coco/ @@ -75,6 +121,11 @@ val_dataset_params: json_file: annotations/person_keypoints_val2017.json include_empty_samples: True min_instance_area: 128 + + edge_links: ${dataset_params.edge_links} + edge_colors: ${dataset_params.edge_colors} + keypoint_colors: ${dataset_params.keypoint_colors} + transforms: - KeypointsLongestMaxSize: max_height: 640 @@ -83,15 +134,17 @@ val_dataset_params: - KeypointsPadIfNeeded: min_height: 640 min_width: 640 - image_pad_value: [ 127, 127, 127 ] + image_pad_value: 127 mask_pad_value: 1 - - KeypointsImageToTensor + - KeypointsImageStandardize: + max_value: 255 - KeypointsImageNormalize: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] + mean: [ 0.485, 0.456, 0.406 ] + std: [ 0.229, 0.224, 0.225 ] + - KeypointsImageToTensor train_dataloader_params: shuffle: True diff --git a/src/super_gradients/recipes/dataset_params/coco_pose_estimation_rescoring_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_pose_estimation_rescoring_dataset_params.yaml index 9ed5000b88..1d9d721f34 100644 --- a/src/super_gradients/recipes/dataset_params/coco_pose_estimation_rescoring_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_pose_estimation_rescoring_dataset_params.yaml @@ -3,7 +3,7 @@ num_joints: 17 # OKs sigma values take from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L523 oks_sigmas: [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 1.007, 1.007, 0.087, 0.087, 0.089, 0.089] -joint_links: +edge_links: - [0, 1] - [0, 2] - [1, 2] diff --git a/src/super_gradients/scripts/generate_rescoring_training_data/__main__.py b/src/super_gradients/scripts/generate_rescoring_training_data/__main__.py index 81e2345ef0..8957e7c964 100644 --- a/src/super_gradients/scripts/generate_rescoring_training_data/__main__.py +++ b/src/super_gradients/scripts/generate_rescoring_training_data/__main__.py @@ -115,7 +115,7 @@ def main(cfg: DictConfig) -> None: ) # model = DEKRWrapper(model, apply_sigmoid=True).cuda().eval() - model = DEKRHorisontalFlipWrapper(model, cfg.dataset_params.flip_indexes_heatmap, cfg.dataset_params.flip_indexes_offset, apply_sigmoid=True).cuda().eval() + model = DEKRHorisontalFlipWrapper(model, cfg.dataset_params.flip_indexes, apply_sigmoid=True).cuda().eval() post_prediction_callback = cfg.post_prediction_callback diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py b/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py index d677955291..062f0b57d4 100644 --- a/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py +++ b/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py @@ -1,14 +1,16 @@ import abc -from typing import Tuple, List, Mapping, Any, Dict +from typing import Tuple, List, Mapping, Any, Dict, Union import numpy as np import torch from torch.utils.data.dataloader 
import default_collate, Dataset from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.object_names import Processings from super_gradients.common.registry.registry import register_collate_function from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform +from super_gradients.training.utils.visualization.utils import generate_color_mapping logger = get_logger(__name__) @@ -24,6 +26,10 @@ def __init__( target_generator: KeypointsTargetsGenerator, transforms: List[KeypointTransform], min_instance_area: float, + num_joints: int, + edge_links: Union[List[Tuple[int, int]], np.ndarray], + edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None], + keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None], ): """ @@ -31,11 +37,19 @@ def __init__( See DEKRTargetsGenerator for an example. :param transforms: Transforms to be applied to the image & keypoints :param min_instance_area: Minimum area of an instance to be included in the dataset + :param num_joints: Number of joints to be predicted + :param edge_links: Edge links between joints + :param edge_colors: Color of the edge links. If None, the color will be generated randomly. + :param keypoint_colors: Color of the keypoints. If None, the color will be generated randomly. """ super().__init__() self.target_generator = target_generator self.transforms = KeypointsCompose(transforms) self.min_instance_area = min_instance_area + self.num_joints = num_joints + self.edge_links = edge_links + self.edge_colors = edge_colors or generate_color_mapping(len(edge_links)) + self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints) @abc.abstractmethod def __len__(self) -> int: @@ -95,6 +109,21 @@ def filter_joints(self, joints: np.ndarray, image: np.ndarray) -> np.ndarray: return joints + def get_dataset_preprocessing_params(self): + """ + + :return: + """ + pipeline = self.transforms.get_equivalent_preprocessing() + params = dict( + conf=0.25, + image_processor={Processings.ComposeProcessing: {"processings": pipeline}}, + edge_links=self.edge_links, + edge_colors=self.edge_colors, + keypoint_colors=self.keypoint_colors, + ) + return params + @register_collate_function() class KeypointsCollate: diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py b/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py index 253a3c45e6..77ee826dc8 100644 --- a/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py +++ b/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py @@ -1,5 +1,5 @@ import os -from typing import Tuple, List, Mapping, Any +from typing import Tuple, List, Mapping, Any, Union import cv2 import numpy as np @@ -8,7 +8,7 @@ from torch import Tensor from super_gradients.common.abstractions.abstract_logger import get_logger -from super_gradients.common.object_names import Datasets +from super_gradients.common.object_names import Datasets, Processings from super_gradients.common.registry.registry import register_dataset from super_gradients.common.decorators.factory_decorator import resolve_param from super_gradients.common.factories.target_generator_factory import TargetGeneratorsFactory @@ -37,6 +37,9 @@ def __init__( target_generator, transforms: List[KeypointTransform], min_instance_area: float, + 
edge_links: Union[List[Tuple[int, int]], np.ndarray], + edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None], + keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None], ): """ @@ -49,20 +52,32 @@ def __init__( See DEKRTargetsGenerator for an example. :param transforms: Transforms to be applied to the image & keypoints :param min_instance_area: Minimum area of an instance to be included in the dataset + :param edge_links: Edge links between joints + :param edge_colors: Color of the edge links. If None, the color will be generated randomly. + :param keypoint_colors: Color of the keypoints. If None, the color will be generated randomly. """ - super().__init__(transforms=transforms, target_generator=target_generator, min_instance_area=min_instance_area) - self.root = data_dir - self.images_dir = os.path.join(data_dir, images_dir) - self.json_file = os.path.join(data_dir, json_file) - coco = COCO(self.json_file) + json_file = os.path.join(data_dir, json_file) + coco = COCO(json_file) if len(coco.dataset["categories"]) != 1: raise ValueError("Dataset must contain exactly one category") - + joints = coco.dataset["categories"][0]["keypoints"] + num_joints = len(joints) + + super().__init__( + transforms=transforms, + target_generator=target_generator, + min_instance_area=min_instance_area, + num_joints=num_joints, + edge_links=edge_links, + edge_colors=edge_colors, + keypoint_colors=keypoint_colors, + ) + self.root = data_dir + self.images_dir = os.path.join(data_dir, images_dir) self.coco = coco self.ids = list(self.coco.imgs.keys()) - self.joints = coco.dataset["categories"][0]["keypoints"] - self.num_joints = len(self.joints) + self.joints = joints if not include_empty_samples: subset = [img_id for img_id in self.ids if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0] @@ -190,3 +205,21 @@ def get_mask(self, anno, img_info) -> np.ndarray: m += mask return (m < 0.5).astype(np.float32) + + def get_dataset_preprocessing_params(self): + """ + + :return: + """ + # Since we are using cv2.imread to read images, our model in fact is trained on BGR images. + # In our pipelines the convention that input images are RGB, so we need to reverse the channels to get BGR + # to match with the expected input of the model. 
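+        # Note: the keys of the dict built below mirror the keyword arguments of the model's
+        # `set_dataset_processing_params(...)`, so the returned dict can be unpacked directly into that call.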
+ pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing() + params = dict( + conf=0.25, + image_processor={Processings.ComposeProcessing: {"processings": pipeline}}, + edge_links=self.edge_links, + edge_colors=self.edge_colors, + keypoint_colors=self.keypoint_colors, + ) + return params diff --git a/src/super_gradients/training/models/detection_models/customizable_detector.py b/src/super_gradients/training/models/detection_models/customizable_detector.py index 9de78abbfa..bcbc07f384 100644 --- a/src/super_gradients/training/models/detection_models/customizable_detector.py +++ b/src/super_gradients/training/models/detection_models/customizable_detector.py @@ -19,7 +19,7 @@ from super_gradients.training.utils.utils import HpmStruct, arch_params_deprecated from super_gradients.training.models.sg_module import SgModule import super_gradients.common.factories.detection_modules_factory as det_factory -from super_gradients.training.models.prediction_results import ImagesDetectionPrediction +from super_gradients.training.utils.predict import ImagesDetectionPrediction from super_gradients.training.pipelines.pipelines import DetectionPipeline from super_gradients.training.processing.processing import Processing from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback diff --git a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py index c4438a8123..a64a974ed8 100644 --- a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py +++ b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_e.py @@ -16,7 +16,7 @@ from super_gradients.training.utils import HpmStruct from super_gradients.training.models.arch_params_factory import get_arch_params from super_gradients.training.models.detection_models.pp_yolo_e.post_prediction_callback import PPYoloEPostPredictionCallback, DetectionPostPredictionCallback -from super_gradients.training.models.prediction_results import ImagesDetectionPrediction +from super_gradients.training.utils.predict import ImagesDetectionPrediction from super_gradients.training.pipelines.pipelines import DetectionPipeline from super_gradients.training.processing.processing import Processing from super_gradients.training.utils.media.image import ImageSource diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index e42fbdd2ad..69d2ba0165 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py @@ -14,7 +14,7 @@ from super_gradients.training.utils import torch_version_is_greater_or_equal from super_gradients.training.utils.detection_utils import non_max_suppression, matrix_non_max_suppression, NMS_Type, DetectionPostPredictionCallback, Anchors from super_gradients.training.utils.utils import HpmStruct, check_img_size_divisibility, get_param -from super_gradients.training.models.prediction_results import ImagesDetectionPrediction +from super_gradients.training.utils.predict import ImagesDetectionPrediction from super_gradients.training.pipelines.pipelines import DetectionPipeline from super_gradients.training.processing.processing import Processing from super_gradients.training.utils.media.image import ImageSource diff --git a/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py 
b/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py index da7c0d7e59..edcb4b9f5c 100644 --- a/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py +++ b/src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py @@ -11,22 +11,32 @@ from __future__ import print_function import copy -from typing import Mapping, Any, Tuple +from functools import lru_cache +from typing import Mapping, Any, Tuple, Optional, List, Union +import numpy as np import torch import torch.nn.functional as F import torchvision from torch import nn +from super_gradients.common.decorators.factory_decorator import resolve_param +from super_gradients.common.factories.processing_factory import ProcessingFactory from super_gradients.common.registry.registry import register_model from super_gradients.common.object_names import Models from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.training.utils.predict import ImagesPoseEstimationPrediction from super_gradients.training.models.sg_module import SgModule from super_gradients.training.models.arch_params_factory import get_arch_params __all__ = ["DEKRPoseEstimationModel", "DEKRW32NODC"] -from super_gradients.training.utils import HpmStruct +from super_gradients.training.pipelines.pipelines import PoseEstimationPipeline + +from super_gradients.training.processing.processing import Processing + +from super_gradients.training.utils import HpmStruct, DEKRPoseEstimationDecodeCallback +from super_gradients.training.utils.media.image import ImageSource logger = get_logger(__name__) @@ -520,8 +530,98 @@ def init_weights(self): if hasattr(m, "bias"): nn.init.constant_(m.translation_conv.bias, 0) + @staticmethod + def get_post_prediction_callback(conf: float = 0.05): + return DEKRPoseEstimationDecodeCallback( + min_confidence=conf, + keypoint_threshold=0.05, + nms_threshold=0.05, + apply_sigmoid=True, + max_num_people=30, + nms_num_threshold=8, + output_stride=4, + ) + + @resolve_param("image_processor", ProcessingFactory()) + def set_dataset_processing_params( + self, + edge_links: Union[np.ndarray, List[Tuple[int, int]]], + edge_colors: Union[np.ndarray, List[Tuple[int, int, int]]], + keypoint_colors: Union[np.ndarray, List[Tuple[int, int, int]]], + image_processor: Optional[Processing] = None, + conf: Optional[float] = None, + ) -> None: + """Set the processing parameters for the dataset. + + :param image_processor: (Optional) Image processing objects to reproduce the dataset preprocessing used for training. + :param conf: (Optional) Below the confidence threshold, prediction are discarded + """ + self._edge_links = edge_links or self._edge_links + self._edge_colors = edge_colors or self._edge_colors + self._keypoint_colors = keypoint_colors or self._keypoint_colors + self._image_processor = image_processor or self._image_processor + self._default_nms_conf = conf or self._default_nms_conf + + @lru_cache(maxsize=1) + def _get_pipeline(self, conf: Optional[float] = None, fuse_model: bool = True) -> PoseEstimationPipeline: + """Instantiate the prediction pipeline of this model. + + :param conf: (Optional) Below the confidence threshold, prediction are discarded. + If None, the default value associated to the training is used. + :param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage. 
+ """ + if None in (self._edge_links, self._image_processor, self._default_nms_conf): + raise RuntimeError( + "You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first." + ) + + conf = conf or self._default_nms_conf + + if len(self._keypoint_colors) != self.num_joints: + raise RuntimeError( + "The number of colors for the keypoints ({}) does not match the number of joints ({})".format(len(self._keypoint_colors), self.num_joints) + ) + if len(self._edge_colors) != len(self._edge_links): + raise RuntimeError( + "The number of colors for the joints ({}) does not match the number of joint links ({})".format(len(self._edge_colors), len(self._edge_links)) + ) + + pipeline = PoseEstimationPipeline( + model=self, + image_processor=self._image_processor, + edge_links=self._edge_links, + edge_colors=self._edge_colors, + keypoint_colors=self._keypoint_colors, + post_prediction_callback=self.get_post_prediction_callback(conf=conf), + fuse_model=fuse_model, + ) + return pipeline + + def predict(self, images: ImageSource, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesPoseEstimationPrediction: + """Predict an image or a list of images. + + :param images: Images to predict. + :param conf: (Optional) Below the confidence threshold, prediction are discarded. + If None, the default value associated to the training is used. + :param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage. + """ + pipeline = self._get_pipeline(conf=conf, fuse_model=fuse_model) + return pipeline(images) # type: ignore + + def predict_webcam(self, conf: Optional[float] = None, fuse_model: bool = True): + """Predict using webcam. -POSE_DEKR_W32_NO_DC_ARCH_PARAMS = get_arch_params("pose_dekr_w32_no_dc_arch_params") + :param conf: (Optional) Below the confidence threshold, prediction are discarded. + If None, the default value associated to the training is used. + :param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage. 
+ """ + pipeline = self._get_pipeline(conf=conf, fuse_model=fuse_model) + pipeline.predict_webcam() + + def train(self, mode: bool = True): + self._get_pipeline.cache_clear() + torch.cuda.empty_cache() + return super().train(mode) @register_model(Models.DEKR_W32_NO_DC) @@ -531,6 +631,8 @@ class DEKRW32NODC(DEKRPoseEstimationModel): """ def __init__(self, arch_params): + POSE_DEKR_W32_NO_DC_ARCH_PARAMS = get_arch_params("pose_dekr_w32_no_dc_arch_params") + merged_arch_params = HpmStruct(**copy.deepcopy(POSE_DEKR_W32_NO_DC_ARCH_PARAMS)) merged_arch_params.override(**arch_params.to_dict()) super().__init__(merged_arch_params) @@ -552,11 +654,12 @@ def forward(self, inputs): class DEKRHorisontalFlipWrapper(nn.Module): - def __init__(self, model: DEKRPoseEstimationModel, flip_indexes_heatmap, flip_indexes_offset, apply_sigmoid=False): + def __init__(self, model: DEKRPoseEstimationModel, flip_indexes, apply_sigmoid=False): super().__init__() self.model = model - self.flip_indexes_heatmap = torch.tensor(flip_indexes_heatmap).long() - self.flip_indexes_offset = torch.tensor(flip_indexes_offset).long() + # In DEKR the heatmap has one more channel for the center point of the pose, which is the last channel and it is not flipped + self.flip_indexes_heatmap = torch.tensor(list(flip_indexes) + [len(flip_indexes)]).long() + self.flip_indexes_offset = torch.tensor(flip_indexes).long() self.apply_sigmoid = apply_sigmoid def forward(self, inputs): diff --git a/src/super_gradients/training/models/pose_estimation_models/rescoring_net.py b/src/super_gradients/training/models/pose_estimation_models/rescoring_net.py index d122dd9c59..87d0c9fad7 100644 --- a/src/super_gradients/training/models/pose_estimation_models/rescoring_net.py +++ b/src/super_gradients/training/models/pose_estimation_models/rescoring_net.py @@ -24,9 +24,9 @@ class PoseRescoringNet(SgModule): The output is a single scalar value. """ - def __init__(self, num_classes: int, hidden_channels: int, num_layers: int, joint_links: List[Tuple[int, int]]): + def __init__(self, num_classes: int, hidden_channels: int, num_layers: int, edge_links: List[Tuple[int, int]]): super(PoseRescoringNet, self).__init__() - in_channels = len(joint_links) * 2 + len(joint_links) + num_classes # [joint_relate, joint_length, visibility] + in_channels = len(edge_links) * 2 + len(edge_links) + num_classes # [joint_relate, joint_length, visibility] layers = [] for _ in range(num_layers): layers.append(nn.Linear(in_channels, hidden_channels, bias=True)) @@ -34,7 +34,7 @@ def __init__(self, num_classes: int, hidden_channels: int, num_layers: int, join in_channels = hidden_channels self.layers = nn.Sequential(*layers) self.final = nn.Linear(hidden_channels, 1, bias=True) - self.joint_links = torch.tensor(joint_links).long() + self.edge_links = torch.tensor(edge_links).long() def forward(self, poses: Tensor) -> Tuple[Tensor, Tensor]: """ @@ -43,7 +43,7 @@ def forward(self, poses: Tensor) -> Tuple[Tensor, Tensor]: :return: Tuple of input poses and corresponding scores """ - x = self.get_feature(poses, self.joint_links) + x = self.get_feature(poses, self.edge_links) x = self.layers(x) y_pred = self.final(x) return poses, y_pred @@ -55,19 +55,19 @@ def init_weights(self): nn.init.constant_(m.bias, 0) @classmethod - def get_feature(cls, poses: Tensor, joint_links: Tensor) -> Tensor: + def get_feature(cls, poses: Tensor, edge_links: Tensor) -> Tensor: """ Compute the feature vector input to the rescoring network. 
:param poses: [N, J, 3] Predicted poses - :param joint_links: [L,2] List of joint indices + :param edge_links: [L,2] List of joint indices :return: [N, L*2+L+J] Feature vector """ joint_xy = poses[..., :2] visibility = poses[..., 2] - joint_1 = joint_links[:, 0] - joint_2 = joint_links[:, 1] + joint_1 = edge_links[:, 0] + joint_2 = edge_links[:, 1] # To get the Delta x Delta y joint_relate = joint_xy[..., joint_1, :] - joint_xy[..., joint_2, :] # [N, L, 2] @@ -99,5 +99,5 @@ def __init__(self, arch_params): num_classes=merged_arch_params.num_classes, hidden_channels=merged_arch_params.hidden_channels, num_layers=merged_arch_params.num_layers, - joint_links=merged_arch_params.joint_links, + edge_links=merged_arch_params.edge_links, ) diff --git a/src/super_gradients/training/models/predictions.py b/src/super_gradients/training/models/predictions.py index a657bf2b56..d994d60ade 100644 --- a/src/super_gradients/training/models/predictions.py +++ b/src/super_gradients/training/models/predictions.py @@ -1,55 +1,11 @@ -from typing import Tuple -from abc import ABC -from dataclasses import dataclass +from super_gradients.training.utils.predict import Prediction, DetectionPrediction +import warnings -import numpy as np +warnings.warn( + "Importing from super_gradients.training.models.predictions is deprecated. " + "Please update your code to import from super_gradients.training.utils.predict instead.", + DeprecationWarning, +) -from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory -from super_gradients.training.datasets.data_formats.bbox_formats import convert_bboxes - -@dataclass -class Prediction(ABC): - pass - - -@dataclass -class DetectionPrediction(Prediction): - """Represents a detection prediction, with bboxes represented in xyxy format.""" - - bboxes_xyxy: np.ndarray - confidence: np.ndarray - labels: np.ndarray - - def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]): - """ - :param bboxes: BBoxes in the format specified by bbox_format - :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...) - :param confidence: Confidence scores for each bounding box - :param labels: Labels for each bounding box. - :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format - """ - self._validate_input(bboxes, confidence, labels) - - factory = BBoxFormatFactory() - bboxes_xyxy = convert_bboxes( - bboxes=bboxes, - image_shape=image_shape, - source_format=factory.get(bbox_format), - target_format=factory.get("xyxy"), - inplace=False, - ) - - self.bboxes_xyxy = bboxes_xyxy - self.confidence = confidence - self.labels = labels - - def _validate_input(self, bboxes: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None: - n_bboxes, n_confidences, n_labels = bboxes.shape[0], confidence.shape[0], labels.shape[0] - if n_bboxes != n_confidences != n_labels: - raise ValueError( - f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})." 
- ) - - def __len__(self): - return len(self.bboxes_xyxy) +__all__ = ["Prediction", "DetectionPrediction"] diff --git a/src/super_gradients/training/models/results.py b/src/super_gradients/training/models/results.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/super_gradients/training/pipelines/pipelines.py b/src/super_gradients/training/pipelines/pipelines.py index 1e716adea9..49e3114672 100644 --- a/src/super_gradients/training/pipelines/pipelines.py +++ b/src/super_gradients/training/pipelines/pipelines.py @@ -6,21 +6,27 @@ import numpy as np import torch -from super_gradients.training.utils.utils import generate_batch -from super_gradients.training.utils.media.video import load_video, includes_video_extension -from super_gradients.training.utils.media.image import ImageSource, check_image_typing -from super_gradients.training.utils.media.stream import WebcamStreaming -from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback -from super_gradients.training.models.sg_module import SgModule -from super_gradients.training.models.prediction_results import ( + +from super_gradients.training.utils.predict import ( + ImagePoseEstimationPrediction, + ImagesPoseEstimationPrediction, + VideoPoseEstimationPrediction, ImagesDetectionPrediction, VideoDetectionPrediction, ImagePrediction, ImageDetectionPrediction, ImagesPredictions, VideoPredictions, + Prediction, + DetectionPrediction, + PoseEstimationPrediction, ) -from super_gradients.training.models.predictions import Prediction, DetectionPrediction +from super_gradients.training.utils.utils import generate_batch +from super_gradients.training.utils.media.video import load_video, includes_video_extension +from super_gradients.training.utils.media.image import ImageSource, check_image_typing +from super_gradients.training.utils.media.stream import WebcamStreaming +from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback +from super_gradients.training.models.sg_module import SgModule from super_gradients.training.processing.processing import Processing, ComposeProcessing from super_gradients.common.abstractions.abstract_logger import get_logger @@ -298,3 +304,76 @@ def _combine_image_prediction_to_video( ) -> VideoDetectionPrediction: images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")] return VideoDetectionPrediction(_images_prediction_lst=images_predictions, fps=fps) + + +class PoseEstimationPipeline(Pipeline): + """Pipeline specifically designed for pose estimation tasks. + The pipeline includes loading images, preprocessing, prediction, and postprocessing. + + :param model: The object detection model (instance of SgModule) used for making predictions. + :param post_prediction_callback: Callback function to process raw predictions from the model. + :param image_processor: Single image processor or a list of image processors for preprocessing and postprocessing the images. + :param device: The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support. + :param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage. 
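+    :param edge_links: Edge links between joints, given as a list of (start index, end index) pairs.
+    :param edge_colors: Color of each edge link, used when drawing predicted poses.
+    :param keypoint_colors: Color of each keypoint, used when drawing predicted poses.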
+ """ + + def __init__( + self, + model: SgModule, + edge_links: Union[np.ndarray, List[Tuple[int, int]]], + edge_colors: Union[np.ndarray, List[Tuple[int, int, int]]], + keypoint_colors: Union[np.ndarray, List[Tuple[int, int, int]]], + post_prediction_callback, + device: Optional[str] = None, + image_processor: Optional[Processing] = None, + fuse_model: bool = True, + ): + super().__init__(model=model, device=device, image_processor=image_processor, class_names=None, fuse_model=fuse_model) + self.post_prediction_callback = post_prediction_callback + self.edge_links = np.asarray(edge_links, dtype=int) + self.edge_colors = np.asarray(edge_colors, dtype=int) + self.keypoint_colors = np.asarray(keypoint_colors, dtype=int) + + def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[PoseEstimationPrediction]: + """Decode the model output, by applying post prediction callback. This includes NMS. + + :param model_output: Direct output of the model, without any post-processing. + :param model_input: Model input (i.e. images after preprocessing). + :return: Predicted Bboxes. + """ + all_poses, all_scores = self.post_prediction_callback(model_output) + + predictions = [] + for poses, scores, image in zip(all_poses, all_scores, model_input): + predictions.append( + PoseEstimationPrediction( + poses=poses, + scores=scores, + image_shape=image.shape, + edge_links=self.edge_links, + edge_colors=self.edge_colors, + keypoint_colors=self.keypoint_colors, + ) + ) + + return predictions + + def _instantiate_image_prediction(self, image: np.ndarray, prediction: PoseEstimationPrediction) -> ImagePrediction: + return ImagePoseEstimationPrediction(image=image, prediction=prediction, class_names=self.class_names) + + def _combine_image_prediction_to_images( + self, images_predictions: Iterable[PoseEstimationPrediction], n_images: Optional[int] = None + ) -> ImagesPoseEstimationPrediction: + if n_images is not None and n_images == 1: + # Do not show tqdm progress bar if there is only one image + images_predictions = [next(iter(images_predictions))] + else: + images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")] + + return ImagesPoseEstimationPrediction(_images_prediction_lst=images_predictions) + + def _combine_image_prediction_to_video( + self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None + ) -> VideoPoseEstimationPrediction: + images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")] + return VideoPoseEstimationPrediction(_images_prediction_lst=images_predictions, fps=fps) diff --git a/src/super_gradients/training/processing/processing.py b/src/super_gradients/training/processing/processing.py index fef05b8667..da29c57502 100644 --- a/src/super_gradients/training/processing/processing.py +++ b/src/super_gradients/training/processing/processing.py @@ -6,7 +6,7 @@ from super_gradients.common.registry.registry import register_processing from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST -from super_gradients.training.models.predictions import Prediction, DetectionPrediction +from super_gradients.training.utils.predict import Prediction, DetectionPrediction, PoseEstimationPrediction from super_gradients.training.transforms.utils import ( _rescale_image, _rescale_bboxes, @@ -15,6 +15,8 @@ _pad_image, _shift_bboxes, 
PaddingCoordinates, + _rescale_keypoints, + _shift_keypoints, ) from super_gradients.common.object_names import Processings @@ -196,6 +198,37 @@ def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinate pass +class _KeypointsPadding(Processing, ABC): + """Base class for keypoints padding methods. One should implement the `_get_padding_params` method to work with a custom padding method. + + Note: This transformation assume that dimensions of input image is equal or less than `output_shape`. + + :param output_shape: Output image shape (H, W) + :param pad_value: Padding value for image + """ + + def __init__(self, output_shape: Tuple[int, int], pad_value: int): + self.output_shape = output_shape + self.pad_value = pad_value + + def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]: + padding_coordinates = self._get_padding_params(input_shape=image.shape) + processed_image = _pad_image(image=image, padding_coordinates=padding_coordinates, pad_value=self.pad_value) + return processed_image, DetectionPadToSizeMetadata(padding_coordinates=padding_coordinates) + + def postprocess_predictions(self, predictions: PoseEstimationPrediction, metadata: DetectionPadToSizeMetadata) -> PoseEstimationPrediction: + predictions.poses = _shift_keypoints( + targets=predictions.poses, + shift_h=-metadata.padding_coordinates.top, + shift_w=-metadata.padding_coordinates.left, + ) + return predictions + + @abstractmethod + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + pass + + @register_processing(Processings.DetectionCenterPadding) class DetectionCenterPadding(_DetectionPadding): def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: @@ -208,6 +241,12 @@ def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinate return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape) +@register_processing(Processings.KeypointsBottomRightPadding) +class KeypointsBottomRightPadding(_KeypointsPadding): + def _get_padding_params(self, input_shape: Tuple[int, int]) -> PaddingCoordinates: + return _get_bottom_right_padding_coordinates(input_shape=input_shape, output_shape=self.output_shape) + + class _Rescale(Processing, ABC): """Resize image to given image dimensions WITHOUT preserving aspect ratio. @@ -259,6 +298,13 @@ def postprocess_predictions(self, predictions: DetectionPrediction, metadata: Re return predictions +@register_processing(Processings.KeypointsLongestMaxSizeRescale) +class KeypointsLongestMaxSizeRescale(_LongestMaxSizeRescale): + def postprocess_predictions(self, predictions: PoseEstimationPrediction, metadata: RescaleMetadata) -> PoseEstimationPrediction: + predictions.poses = _rescale_keypoints(targets=predictions.poses, scale_factors=(1 / metadata.scale_factor_h, 1 / metadata.scale_factor_w)) + return predictions + + def default_yolox_coco_processing_params() -> dict: """Processing parameters commonly used for training YoloX on COCO dataset. 
TODO: remove once we load it from the checkpoint @@ -328,6 +374,87 @@ def default_yolo_nas_coco_processing_params() -> dict: return params +def default_dekr_coco_processing_params() -> dict: + """Processing parameters commonly used for training DEKR on COCO dataset.""" + + image_processor = ComposeProcessing( + [ + ReverseImageChannels(), + KeypointsLongestMaxSizeRescale(output_shape=(640, 640)), + KeypointsBottomRightPadding(output_shape=(640, 640), pad_value=127), + StandardizeImage(max_value=255.0), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ImagePermute(permutation=(2, 0, 1)), + ] + ) + + edge_links = [ + [0, 1], + [0, 2], + [1, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [5, 6], + [5, 7], + [5, 11], + [6, 8], + [6, 12], + [7, 9], + [8, 10], + [11, 12], + [11, 13], + [12, 14], + [13, 15], + [14, 16], + ] + + edge_colors = [ + (214, 39, 40), # Nose -> LeftEye + (148, 103, 189), # Nose -> RightEye + (44, 160, 44), # LeftEye -> RightEye + (140, 86, 75), # LeftEye -> LeftEar + (227, 119, 194), # RightEye -> RightEar + (127, 127, 127), # LeftEar -> LeftShoulder + (188, 189, 34), # RightEar -> RightShoulder + (127, 127, 127), # Shoulders + (188, 189, 34), # LeftShoulder -> LeftElbow + (140, 86, 75), # LeftTorso + (23, 190, 207), # RightShoulder -> RightElbow + (227, 119, 194), # RightTorso + (31, 119, 180), # LeftElbow -> LeftArm + (255, 127, 14), # RightElbow -> RightArm + (148, 103, 189), # Waist + (255, 127, 14), # Left Hip -> Left Knee + (214, 39, 40), # Right Hip -> Right Knee + (31, 119, 180), # Left Knee -> Left Ankle + (44, 160, 44), # Right Knee -> Right Ankle + ] + + keypoint_colors = [ + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + (31, 119, 180), + (148, 103, 189), + ] + params = dict(image_processor=image_processor, conf=0.05, edge_links=edge_links, edge_colors=edge_colors, keypoint_colors=keypoint_colors) + return params + + def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -> dict: """Get the processing parameters for a pretrained model. 
TODO: remove once we load it from the checkpoint @@ -339,4 +466,8 @@ def get_pretrained_processing_params(model_name: str, pretrained_weights: str) - return default_ppyoloe_coco_processing_params() elif "yolo_nas" in model_name: return default_yolo_nas_coco_processing_params() + + if pretrained_weights == "coco_pose" and model_name in ("dekr_w32_no_dc", "dekr_custom"): + return default_dekr_coco_processing_params() + return dict() diff --git a/src/super_gradients/training/transforms/keypoint_transforms.py b/src/super_gradients/training/transforms/keypoint_transforms.py index 3935c766cb..b805d6315f 100644 --- a/src/super_gradients/training/transforms/keypoint_transforms.py +++ b/src/super_gradients/training/transforms/keypoint_transforms.py @@ -1,17 +1,18 @@ import random from abc import abstractmethod -from typing import Tuple, List, Iterable, Union, Optional +from typing import Tuple, List, Iterable, Union, Optional, Dict import cv2 import numpy as np +import torch from torch import Tensor -from torchvision.transforms import functional as F -from super_gradients.common.object_names import Transforms +from super_gradients.common.object_names import Transforms, Processings from super_gradients.common.registry.registry import register_transform __all__ = [ "KeypointsImageNormalize", + "KeypointsImageStandardize", "KeypointsImageToTensor", "KeypointsPadIfNeeded", "KeypointsLongestMaxSize", @@ -28,7 +29,7 @@ @register_transform(Transforms.KeypointTransform) class KeypointTransform(object): """ - Base class for all transforms for keypoints augmnetation. + Base class for all transforms for keypoints augmentation. All transforms subclassing it should implement __call__ method which takes image, mask and keypoints as input and returns transformed image, mask and keypoints. """ @@ -49,6 +50,9 @@ def __call__( """ raise NotImplementedError + def get_equivalent_preprocessing(self) -> List: + raise NotImplementedError + class KeypointsCompose(KeypointTransform): def __init__(self, transforms: List[KeypointTransform]): @@ -61,33 +65,66 @@ def __call__( image, mask, joints, areas, bboxes = t(image, mask, joints, areas, bboxes) return image, mask, joints, areas, bboxes + def get_equivalent_preprocessing(self) -> List: + preprocessing = [] + for t in self.transforms: + preprocessing += t.get_equivalent_preprocessing() + return preprocessing + @register_transform(Transforms.KeypointsImageToTensor) class KeypointsImageToTensor(KeypointTransform): """ Convert image from numpy array to tensor and permute axes to [C,H,W]. - This function also divides image by 255.0 to convert it to [0,1] range. """ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]): - return F.to_tensor(image), mask, joints, areas, bboxes + image = torch.from_numpy(np.transpose(image, (2, 0, 1))).float() + return image, mask, joints, areas, bboxes + + def get_equivalent_preprocessing(self) -> List: + return [ + {Processings.ImagePermute: {"permutation": (2, 0, 1)}}, + ] + + +@register_transform(Transforms.KeypointsImageStandardize) +class KeypointsImageStandardize(KeypointTransform): + """ + Standardize image pixel values with img/max_val + + :param max_val: Current maximum value of the image pixels. 
(usually 255)
+    """
+
+    def __init__(self, max_value: float = 255.0):
+        super().__init__()
+        self.max_value = max_value
+
+    def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]):
+        image = (image / self.max_value).astype(np.float32)
+        return image, mask, joints, areas, bboxes
+
+    def get_equivalent_preprocessing(self) -> List[Dict]:
+        return [{Processings.StandardizeImage: {"max_value": self.max_value}}]
 
 
 @register_transform(Transforms.KeypointsImageNormalize)
 class KeypointsImageNormalize(KeypointTransform):
     """
-    Normalize image with mean and std. Note this transform should come after KeypointsImageToTensor
-    since it operates on torch Tensor and not numpy array.
+    Normalize image with mean and std.
     """
 
     def __init__(self, mean, std):
-        self.mean = mean
-        self.std = std
+        self.mean = np.array(list(mean)).reshape((1, 1, -1)).astype(np.float32)
+        self.std = np.array(list(std)).reshape((1, 1, -1)).astype(np.float32)
 
     def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]):
-        image = F.normalize(image, mean=self.mean, std=self.std)
+        image = (image - self.mean) / self.std
         return image, mask, joints, areas, bboxes
 
+    def get_equivalent_preprocessing(self) -> List:
+        return [{Processings.NormalizeImage: {"mean": self.mean, "std": self.std}}]
+
 
 @register_transform(Transforms.KeypointsRandomHorizontalFlip)
 class KeypointsRandomHorizontalFlip(KeypointTransform):
@@ -136,6 +173,9 @@ def apply_to_bboxes(self, bboxes, cols):
         bboxes[:, 0] = cols - (bboxes[:, 0] + bboxes[:, 2])
         return bboxes
 
+    def get_equivalent_preprocessing(self) -> List:
+        raise RuntimeError("KeypointsRandomHorizontalFlip does not have equivalent preprocessing.")
+
 
 @register_transform(Transforms.KeypointsRandomVerticalFlip)
 class KeypointsRandomVerticalFlip(KeypointTransform):
@@ -175,6 +215,9 @@ def apply_to_bboxes(self, bboxes, rows):
         bboxes[:, 1] = rows - (bboxes[:, 1] + bboxes[:, 3]) - 1
         return bboxes
 
+    def get_equivalent_preprocessing(self) -> List:
+        raise RuntimeError("KeypointsRandomVerticalFlip does not have equivalent preprocessing.")
+
 
 @register_transform(Transforms.KeypointsLongestMaxSize)
 class KeypointsLongestMaxSize(KeypointTransform):
@@ -235,6 +278,9 @@ def apply_to_keypoints(cls, keypoints, scale):
     def apply_to_bboxes(cls, bboxes, scale):
         return bboxes * scale
 
+    def get_equivalent_preprocessing(self) -> List:
+        return [{Processings.KeypointsLongestMaxSizeRescale: {"output_shape": (self.max_height, self.max_width)}}]
+
 
 @register_transform(Transforms.KeypointsPadIfNeeded)
 class KeypointsPadIfNeeded(KeypointTransform):
@@ -252,7 +298,7 @@ def __init__(self, min_height: int, min_width: int, image_pad_value: int, mask_p
         """
         self.min_height = min_height
         self.min_width = min_width
-        self.image_pad_value = tuple(image_pad_value) if isinstance(image_pad_value, Iterable) else int(image_pad_value)
+        self.image_pad_value = image_pad_value
         self.mask_pad_value = mask_pad_value
 
     def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]):
@@ -261,7 +307,8 @@ def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Opt
         pad_bottom = max(0, self.min_height - height)
         pad_right = max(0, self.min_width - width)
 
-        image = cv2.copyMakeBorder(image, top=0, bottom=pad_bottom, left=0, right=pad_right, value=self.image_pad_value, borderType=cv2.BORDER_CONSTANT)
+        image_pad_value = tuple(self.image_pad_value) if 
isinstance(self.image_pad_value, Iterable) else tuple([self.image_pad_value] * image.shape[-1]) + image = cv2.copyMakeBorder(image, top=0, bottom=pad_bottom, left=0, right=pad_right, value=image_pad_value, borderType=cv2.BORDER_CONSTANT) original_dtype = mask.dtype mask = cv2.copyMakeBorder( @@ -271,6 +318,9 @@ def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Opt return image, mask, joints, areas, bboxes + def get_equivalent_preprocessing(self) -> List: + return [{Processings.KeypointsBottomRightPadding: {"output_shape": (self.min_height, self.min_width), "pad_value": self.image_pad_value}}] + @register_transform(Transforms.KeypointsRandomAffineTransform) class KeypointsRandomAffineTransform(KeypointTransform): @@ -299,7 +349,7 @@ def __init__( self.min_scale = min_scale self.max_scale = max_scale self.max_translate = max_translate - self.image_pad_value = tuple(image_pad_value) if isinstance(image_pad_value, Iterable) else int(image_pad_value) + self.image_pad_value = image_pad_value self.mask_pad_value = mask_pad_value self.prob = prob @@ -337,8 +387,10 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, area mat_output = self._get_affine_matrix(image, angle, scale, dx, dy) mat_output = mat_output[:2] + image_pad_value = tuple(self.image_pad_value) if isinstance(self.image_pad_value, Iterable) else tuple([self.image_pad_value] * image.shape[-1]) + mask = self.apply_to_image(mask, mat_output, cv2.INTER_NEAREST, self.mask_pad_value, cv2.BORDER_CONSTANT) - image = self.apply_to_image(image, mat_output, cv2.INTER_LINEAR, self.image_pad_value, cv2.BORDER_CONSTANT) + image = self.apply_to_image(image, mat_output, cv2.INTER_LINEAR, image_pad_value, cv2.BORDER_CONSTANT) joints = self.apply_to_keypoints(joints, mat_output, image.shape) @@ -403,3 +455,6 @@ def apply_to_image(cls, image, mat, interpolation, padding_value, padding_mode=c borderValue=padding_value, borderMode=padding_mode, ) + + def get_equivalent_preprocessing(self) -> List: + raise RuntimeError(f"{self.__class__} does not have equivalent preprocessing.") diff --git a/src/super_gradients/training/transforms/utils.py b/src/super_gradients/training/transforms/utils.py index 7379569b93..c07fd087fc 100644 --- a/src/super_gradients/training/transforms/utils.py +++ b/src/super_gradients/training/transforms/utils.py @@ -26,7 +26,7 @@ def _rescale_image(image: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarr return cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_LINEAR).astype(np.uint8) -def _rescale_bboxes(targets: np.array, scale_factors: Tuple[float, float]) -> np.array: +def _rescale_bboxes(targets: np.ndarray, scale_factors: Tuple[float, float]) -> np.ndarray: """Rescale bboxes to given scale factors, without preserving aspect ratio. :param targets: Targets to rescale (N, 4+), where target[:, :4] is the bounding box coordinates. @@ -41,6 +41,24 @@ def _rescale_bboxes(targets: np.array, scale_factors: Tuple[float, float]) -> np return targets +def _rescale_keypoints(targets: np.ndarray, scale_factors: Tuple[float, float]) -> np.ndarray: + """Rescale keypoints to given scale factors, without preserving aspect ratio. + + :param targets: Array of keypoints to rescale. Can have arbitrary shape [N,2], [N,K,2], etc. + Last dimension encodes XY coordinates: target[..., 0] is the X coordinates and + targets[..., 1] is the Y coordinate. + :param scale_factors: Tuple of (scale_factor_h, scale_factor_w) scale factors to rescale to. + :return: Rescaled targets. 
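+                     Illustrative example: with scale_factors=(2.0, 0.5), a keypoint at (x=10, y=20) maps to (x=5, y=40),
+                     since the X coordinate is scaled by scale_factor_w and the Y coordinate by scale_factor_h.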
+ """ + + targets = targets.astype(np.float32, copy=True) + + sy, sx = scale_factors + targets[..., 0] *= sx + targets[..., 1] *= sy + return targets + + def _get_center_padding_coordinates(input_shape: Tuple[int, int], output_shape: Tuple[int, int]) -> PaddingCoordinates: """Get parameters for padding an image to given output shape, in center mode. @@ -82,6 +100,7 @@ def _pad_image(image: np.ndarray, padding_coordinates: PaddingCoordinates, pad_v """ pad_h = (padding_coordinates.top, padding_coordinates.bottom) pad_w = (padding_coordinates.left, padding_coordinates.right) + if len(image.shape) == 3: return np.pad(image, (pad_h, pad_w, (0, 0)), "constant", constant_values=pad_value) else: @@ -102,6 +121,20 @@ def _shift_bboxes(targets: np.array, shift_w: float, shift_h: float) -> np.array return np.concatenate((boxes, labels), 1) +def _shift_keypoints(targets: np.array, shift_w: float, shift_h: float) -> np.array: + """Shift keypoints with respect to padding values. + + :param targets: Keypoints to transform of shape (N, 2+), or (N, K, 2+), in format [x1, y1, ...] + :param shift_w: shift width. + :param shift_h: shift height. + :return: Transformed keypoints of the same shape as input. + """ + targets = targets.copy() + targets[..., 0] += shift_w + targets[..., 1] += shift_h + return targets + + def _rescale_xyxy_bboxes(targets: np.array, r: float) -> np.array: """Scale targets to given scale factors. diff --git a/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py b/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py index 7f491c8a17..2c90d8b080 100644 --- a/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py +++ b/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py @@ -79,13 +79,13 @@ def _hierarchical_pool(heatmap, pool_threshold1=300, pool_threshold2=200): return maxm -def _get_maximum_from_heatmap(heatmap, max_num_people: int, keypoint_threshold: float): +def _get_maximum_from_heatmap(heatmap, max_num_people: int, pose_center_score_threshold: float) -> Tuple[Tensor, Tensor]: """ :param heatmap: [1, H, W] Single-channel heatmap - :param max_num_people: (int) - :param keypoint_threshold: (float) - :return: + :param max_num_people: (int) Maximum number of poses to return + :param pose_center_score_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate + :return: Tuple of (indexes of poses, scores) """ maxm = _hierarchical_pool(heatmap) maxm = torch.eq(maxm, heatmap).float() @@ -93,7 +93,7 @@ def _get_maximum_from_heatmap(heatmap, max_num_people: int, keypoint_threshold: scores = heatmap.view(-1) scores, pos_ind = scores.topk(max_num_people) - select_ind = (scores > (keypoint_threshold)).nonzero() + select_ind = (scores > pose_center_score_threshold).nonzero() scores = scores[select_ind][:, 0] pos_ind = pos_ind[select_ind][:, 0] @@ -117,6 +117,17 @@ def _cal_area_2_torch(v): def _nms_core(pose_coord, heat_score, nms_threshold: float, nms_num_threshold: int): + """ + Non-maximum suppression for predicted poses. + Removes poses that has certain number of joints that are too close to each other. + + :param pose_coord: Array of shape [num_people, num_joints, 2] with pose coordinates + :param heat_score: Scores of each joint + :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose. + Given in terms of a percentage of a square root of the area of the pose bounding box. 
+ :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one. + :return: Indexes of poses that should be kept + """ num_people, num_joints, _ = pose_coord.shape pose_area = _cal_area_2_torch(pose_coord)[:, None].repeat(1, num_people * num_joints) pose_area = pose_area.reshape(num_people, num_people, num_joints) @@ -159,15 +170,19 @@ def _get_heat_value(pose_coord, heatmap): return heatval -def pose_nms(heatmap_avg, poses, max_num_people: int, nms_threshold: float, nms_num_threshold: int) -> Tuple[np.ndarray, np.ndarray]: +def pose_nms( + heatmap_avg, poses, max_num_people: int, nms_threshold: float, nms_num_threshold: int, pose_score_threshold: float +) -> Tuple[np.ndarray, np.ndarray]: """ NMS for the regressed poses results. - :param heatmap_avg (Tensor): Avg of the heatmaps at all scales (1, 1+num_joints, w, h) - :param poses (List): Gather of the pose proposals [(num_people, num_joints, 3)] - :param max_num_people (int): Maximum number of decoded poses - :param nms_threshold (float) Minimum confidence threshold for joint - :param nms_num_threshold (int): Minimum number of joints per pose above the nms_threshold for pose to be considered a valid candidate + :param Tensor heatmap_avg: Avg of the heatmaps at all scales (1, 1+num_joints, w, h) + :param List poses: Gather of the pose proposals [(num_people, num_joints, 3)] + :param int max_num_people: Maximum number of decoded poses + :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose. + Given in terms of a percentage of a square root of the area of the pose bounding box. + :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one. + :param float pose_score_threshold: Minimum confidence threshold for pose. Pose with confidence lower than this threshold will be discarded. :return Tuple of (poses, scores) """ @@ -198,12 +213,18 @@ def pose_nms(heatmap_avg, poses, max_num_people: int, nms_threshold: float, nms_ poses = poses.numpy() if len(poses): scores = poses[:, :, 2].mean(axis=1) + + mask = scores >= pose_score_threshold + poses = poses[mask] + scores = scores[mask] else: return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32) return poses, scores -def aggregate_results(heatmap: Tensor, posemap: Tensor, output_stride: int, keypoint_threshold: float, max_num_people: int) -> Tuple[Tensor, List[Tensor]]: +def aggregate_results( + heatmap: Tensor, posemap: Tensor, output_stride: int, pose_center_score_threshold: float, max_num_people: int +) -> Tuple[Tensor, List[Tensor]]: """ Get initial pose proposals and aggregate the results of all scale. Not this implementation works only for batch size of 1. 
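The confidence filtering that pose_nms gains above reduces to a few NumPy operations; the following standalone sketch (shapes and threshold value are illustrative, not part of the diff) shows the intended semantics of the new pose_score_threshold argument:

    import numpy as np

    # [num_poses, num_joints, (x, y, confidence)] -- toy data standing in for decoded DEKR poses
    poses = np.random.rand(5, 17, 3).astype(np.float32)
    pose_score_threshold = 0.4  # hypothetical cut-off

    scores = poses[:, :, 2].mean(axis=1)     # pose score = mean of per-joint confidences
    keep = scores >= pose_score_threshold    # same masking step as the new code in pose_nms
    poses, scores = poses[keep], scores[keep]
    print(f"kept {len(poses)} poses")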
@@ -211,7 +232,7 @@ def aggregate_results(heatmap: Tensor, posemap: Tensor, output_stride: int, keyp :param heatmap: Heatmap at this scale (B, 1+num_joints, w, h) :param posemap: Posemap at this scale (B, 2*num_joints, w, h) :param output_stride: Ratio of input size / predictions size - :param keypoint_threshold: (float) + :param pose_center_score_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate :param max_num_people: (int) :return: @@ -225,7 +246,7 @@ def aggregate_results(heatmap: Tensor, posemap: Tensor, output_stride: int, keyp heatmap_sum = _up_interpolate(heatmap, size=(int(output_stride * w), int(output_stride * h))) center_heatmap = heatmap[0, -1:] - pose_ind, ctr_score = _get_maximum_from_heatmap(center_heatmap, keypoint_threshold=keypoint_threshold, max_num_people=max_num_people) + pose_ind, ctr_score = _get_maximum_from_heatmap(center_heatmap, pose_center_score_threshold=pose_center_score_threshold, max_num_people=max_num_people) posemap = posemap[0].permute(1, 2, 0).view(h * w, -1, 2) pose = output_stride * posemap[pose_ind] ctr_score = ctr_score[:, None].expand(-1, pose.shape[-2])[:, :, None] @@ -239,16 +260,27 @@ class DEKRPoseEstimationDecodeCallback(nn.Module): Class that implements decoding logic of DEKR's model predictions into poses. """ - def __init__(self, output_stride: int, max_num_people: int, keypoint_threshold: float, nms_threshold: float, nms_num_threshold: int, apply_sigmoid: bool): + def __init__( + self, + output_stride: int, + max_num_people: int, + keypoint_threshold: float, + nms_threshold: float, + nms_num_threshold: int, + apply_sigmoid: bool, + min_confidence: float = 0.0, + ): """ - :param output_stride: - :param max_num_people: - :param keypoint_threshold: - :param nms_threshold: - :param nms_num_threshold: - :param apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not + :param output_stride: Output stride of the model + :param int max_num_people: Maximum number of decoded poses + :param float keypoint_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate + :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose. + Given in terms of a percentage of a square root of the area of the pose bounding box. + :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one. + :param bool apply_sigmoid: If True, apply the sigmoid activation on heatmap. 
This is needed when heatmap is not bound to [0..1] range and trained with logits (E.g focal loss) + :param float min_confidence: Minimum confidence threshold for pose """ super().__init__() self.keypoint_threshold = keypoint_threshold @@ -257,14 +289,15 @@ def __init__(self, output_stride: int, max_num_people: int, keypoint_threshold: self.nms_threshold = nms_threshold self.nms_num_threshold = nms_num_threshold self.apply_sigmoid = apply_sigmoid + self.min_confidence = min_confidence @torch.no_grad() def forward(self, predictions: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ - :param predictions: Either tuple (heatmap, offset): - heatmap - [1, NumJoints+1,H,W] - offset - [1, NumJoints*2,H,W] + :param predictions: Tuple (heatmap, offset): + heatmap - [BatchSize, NumJoints+1,H,W] + offset - [BatchSize, NumJoints*2,H,W] :return: Tuple """ @@ -292,13 +325,18 @@ def decode_one_sized_batch(self, predictions: Tuple[Tensor, Tensor]) -> Tuple[Te heatmap_sum, poses_sum = aggregate_results( heatmap, posemap, - keypoint_threshold=self.keypoint_threshold, + pose_center_score_threshold=self.keypoint_threshold, max_num_people=self.max_num_people, output_stride=self.output_stride, ) poses, scores = pose_nms( - heatmap_sum, poses_sum, max_num_people=self.max_num_people, nms_threshold=self.nms_threshold, nms_num_threshold=self.nms_num_threshold + heatmap_sum, + poses_sum, + max_num_people=self.max_num_people, + nms_threshold=self.nms_threshold, + nms_num_threshold=self.nms_num_threshold, + pose_score_threshold=self.min_confidence, ) if len(poses) != len(scores): diff --git a/src/super_gradients/training/utils/predict/__init__.py b/src/super_gradients/training/utils/predict/__init__.py new file mode 100644 index 0000000000..c50f9d8011 --- /dev/null +++ b/src/super_gradients/training/utils/predict/__init__.py @@ -0,0 +1,30 @@ +from .predictions import Prediction, DetectionPrediction, PoseEstimationPrediction +from .prediction_results import ( + ImageDetectionPrediction, + ImagesDetectionPrediction, + VideoDetectionPrediction, + ImagePrediction, + ImagesPredictions, + VideoPredictions, +) +from .prediction_pose_estimation_results import ( + ImagePoseEstimationPrediction, + VideoPoseEstimationPrediction, + ImagesPoseEstimationPrediction, +) + + +__all__ = [ + "Prediction", + "DetectionPrediction", + "ImagePrediction", + "ImagesPredictions", + "VideoPredictions", + "ImageDetectionPrediction", + "ImagesDetectionPrediction", + "VideoDetectionPrediction", + "PoseEstimationPrediction", + "ImagePoseEstimationPrediction", + "ImagesPoseEstimationPrediction", + "VideoPoseEstimationPrediction", +] diff --git a/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py b/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py new file mode 100644 index 0000000000..9748bc8128 --- /dev/null +++ b/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py @@ -0,0 +1,320 @@ +import os +from dataclasses import dataclass +from typing import List + +import numpy as np + +from super_gradients.training.utils.predict import ImagePrediction, ImagesPredictions, VideoPredictions, PoseEstimationPrediction +from super_gradients.training.utils.media.image import show_image, save_image +from super_gradients.training.utils.media.video import show_video_from_frames, save_video +from super_gradients.training.utils.visualization.pose_estimation import draw_skeleton + + +@dataclass +class 
ImagePoseEstimationPrediction(ImagePrediction): + """Object wrapping an image and a pose estimation model's prediction. + + :attr image: Input image + :attr prediction: Predictions of the model + """ + + image: np.ndarray + prediction: PoseEstimationPrediction + + def draw( + self, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> np.ndarray: + """Draw the predicted bboxes on the image. + + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes. + :return: Image with predicted bboxes. Note that this does not modify the original image. + """ + image = self.image.copy() + + for pred_i in np.argsort(self.prediction.scores): + image = draw_skeleton( + image=image, + keypoints=self.prediction.poses[pred_i], + score=self.prediction.scores[pred_i], + show_confidence=show_confidence, + edge_links=self.prediction.edge_links, + edge_colors=edge_colors or self.prediction.edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors or self.prediction.keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + ) + + return image + + def show( + self, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> None: + """Display the image with predicted bboxes. + + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes. + """ + image = self.draw( + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) + show_image(image) + + def save( + self, + output_path: str, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> None: + """Save the predicted bboxes on the images. + + :param output_path: Path to the output image file. + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton.
+ :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes. + """ + image = self.draw(edge_colors=edge_colors, joint_thickness=joint_thickness, keypoint_colors=keypoint_colors, keypoint_radius=keypoint_radius, box_thickness=box_thickness, show_confidence=show_confidence) + save_image(image=image, path=output_path) + + +@dataclass +class ImagesPoseEstimationPrediction(ImagesPredictions): + """Object wrapping the list of image pose estimation predictions. + + :attr _images_prediction_lst: List of the predictions results + """ + + _images_prediction_lst: List[ImagePoseEstimationPrediction] + + def show( + self, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> None: + """Display the predicted bboxes on the images. + + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes. + """ + for prediction in self._images_prediction_lst: + prediction.show( + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) + + def save( + self, + output_folder: str, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> None: + """Save the predicted bboxes on the images. + + :param output_folder: Folder path, where the images will be saved. + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes.
+ """ + if output_folder: + os.makedirs(output_folder, exist_ok=True) + + for i, prediction in enumerate(self._images_prediction_lst): + image_output_path = os.path.join(output_folder, f"pred_{i}.jpg") + prediction.save( + output_path=image_output_path, + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) + + +@dataclass +class VideoPoseEstimationPrediction(VideoPredictions): + """Object wrapping the list of image pose estimation predictions as a video. + + :attr _images_prediction_lst: List of the predictions results + :attr fps: Frames per second of the video + """ + + _images_prediction_lst: List[ImagePoseEstimationPrediction] + fps: int + + def draw( + self, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> List[np.ndarray]: + """Draw the predicted bboxes on the images. + + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes. + + :return: List of images with predicted bboxes. Note that this does not modify the original image. + """ + frames_with_bbox = [ + result.draw( + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) + for result in self._images_prediction_lst + ] + return frames_with_bbox + + def show( + self, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> None: + """Display the predicted bboxes on the images. + + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes.
+ """ + frames = self.draw( + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) + show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps) + + def save( + self, + output_path: str, + edge_colors=None, + joint_thickness: int = 2, + keypoint_colors=None, + keypoint_radius: int = 5, + box_thickness: int = 2, + show_confidence: bool = False, + ) -> None: + """Save the predicted bboxes on the images. + + :param output_path: Path to the output video file. + :param edge_colors: Optional list of tuples representing the colors for each joint. + If None, default colors are used. + If not None the length must be equal to the number of joint links in the skeleton. + :param joint_thickness: Thickness of the joint links (in pixels). + :param keypoint_colors: Optional list of tuples representing the colors for each keypoint. + If None, default colors are used. + If not None the length must be equal to the number of joints in the skeleton. + :param keypoint_radius: Radius of the keypoints (in pixels). + :param show_confidence: Whether to show confidence scores on the image. + :param box_thickness: Thickness of bounding boxes. + """ + frames = self.draw( + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_colors=keypoint_colors, + keypoint_radius=keypoint_radius, + box_thickness=box_thickness, + show_confidence=show_confidence, + ) + save_video(output_path=output_path, frames=frames, fps=self.fps) diff --git a/src/super_gradients/training/models/prediction_results.py b/src/super_gradients/training/utils/predict/prediction_results.py similarity index 99% rename from src/super_gradients/training/models/prediction_results.py rename to src/super_gradients/training/utils/predict/prediction_results.py index 92d70bb6cc..e92bc7e27b 100644 --- a/src/super_gradients/training/models/prediction_results.py +++ b/src/super_gradients/training/utils/predict/prediction_results.py @@ -5,7 +5,7 @@ import numpy as np -from super_gradients.training.models.predictions import Prediction, DetectionPrediction +from .predictions import Prediction, DetectionPrediction from super_gradients.training.utils.media.video import show_video_from_frames, save_video from super_gradients.training.utils.media.image import show_image, save_image from super_gradients.training.utils.visualization.utils import generate_color_mapping diff --git a/src/super_gradients/training/utils/predict/predictions.py b/src/super_gradients/training/utils/predict/predictions.py new file mode 100644 index 0000000000..c88324d10d --- /dev/null +++ b/src/super_gradients/training/utils/predict/predictions.py @@ -0,0 +1,108 @@ +from typing import Tuple +from abc import ABC +from dataclasses import dataclass + +import numpy as np + +from super_gradients.common.factories.bbox_format_factory import BBoxFormatFactory +from super_gradients.training.datasets.data_formats.bbox_formats import convert_bboxes + + +@dataclass +class Prediction(ABC): + pass + + +@dataclass +class DetectionPrediction(Prediction): + """Represents a detection prediction, with bboxes represented in xyxy format.""" + + bboxes_xyxy: np.ndarray + confidence: np.ndarray + labels: np.ndarray + + def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]): + """ + :param bboxes: BBoxes in the format specified by bbox_format + :param bbox_format: 
BBoxes format that can be a string ("xyxy", "cxywh", ...) + :param confidence: Confidence scores for each bounding box + :param labels: Labels for each bounding box. + :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format + """ + self._validate_input(bboxes, confidence, labels) + + factory = BBoxFormatFactory() + bboxes_xyxy = convert_bboxes( + bboxes=bboxes, + image_shape=image_shape, + source_format=factory.get(bbox_format), + target_format=factory.get("xyxy"), + inplace=False, + ) + + self.bboxes_xyxy = bboxes_xyxy + self.confidence = confidence + self.labels = labels + + def _validate_input(self, bboxes: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None: + n_bboxes, n_confidences, n_labels = bboxes.shape[0], confidence.shape[0], labels.shape[0] + if n_bboxes != n_confidences != n_labels: + raise ValueError( + f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})." + ) + + def __len__(self): + return len(self.bboxes_xyxy) + + +@dataclass +class PoseEstimationPrediction(Prediction): + """Represents a pose estimation prediction. + + :attr poses: Numpy array of [Num Poses, Num Joints, 2] shape + :attr scores: Numpy array of [Num Poses] shape + """ + + poses: np.ndarray + scores: np.ndarray + edge_links: np.ndarray + edge_colors: np.ndarray + keypoint_colors: np.ndarray + image_shape: Tuple[int, int] + + def __init__( + self, + poses: np.ndarray, + scores: np.ndarray, + edge_links: np.ndarray, + edge_colors: np.ndarray, + keypoint_colors: np.ndarray, + image_shape: Tuple[int, int], + ): + """ + :param poses: + :param scores: + :param image_shape: Shape of the image the prediction is made on, (H, W). 
This is used to convert bboxes to xyxy format + """ + self._validate_input(poses, scores, edge_links, edge_colors, keypoint_colors) + self.poses = poses + self.scores = scores + self.edge_links = edge_links + self.edge_colors = edge_colors + self.image_shape = image_shape + self.keypoint_colors = keypoint_colors + + def _validate_input(self, poses: np.ndarray, scores: np.ndarray, edge_links, edge_colors, keypoint_colors) -> None: + if not isinstance(poses, np.ndarray): + raise ValueError(f"Argument poses must be a numpy array, not {type(poses)}") + if not isinstance(scores, np.ndarray): + raise ValueError(f"Argument scores must be a numpy array, not {type(scores)}") + if not isinstance(keypoint_colors, np.ndarray): + raise ValueError(f"Argument keypoint_colors must be a numpy array, not {type(keypoint_colors)}") + if len(poses) != len(scores) != len(keypoint_colors): + raise ValueError(f"The number of poses ({len(poses)}) does not match the number of scores ({len(scores)}).") + if len(edge_links) != len(edge_colors): + raise ValueError(f"The number of joint links ({len(edge_links)}) does not match the number of joint colors ({len(edge_colors)}).") + + def __len__(self): + return len(self.poses) diff --git a/src/super_gradients/training/utils/visualization/pose_estimation.py b/src/super_gradients/training/utils/visualization/pose_estimation.py new file mode 100644 index 0000000000..7bdc58e26b --- /dev/null +++ b/src/super_gradients/training/utils/visualization/pose_estimation.py @@ -0,0 +1,64 @@ +from typing import Union, List, Tuple + +import cv2 +import numpy as np + +from super_gradients.training.utils.visualization.detection import draw_bbox + + +def draw_skeleton( + image: np.ndarray, + keypoints: np.ndarray, + score: float, + edge_links: np.ndarray, + edge_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]], + joint_thickness: int, + keypoint_colors: Union[None, np.ndarray, List[Tuple[int, int, int]]], + keypoint_radius: int, + show_confidence: bool, + box_thickness: int, +): + """ + Draw a skeleton on an image. + + :param image: Input image (will not be modified) + :param keypoints: Array of [Num Joints, 2] or [Num Joints, 3] containing the keypoints to draw. + First two values are the x and y coordinates, the third (optional, not used) is the confidence score. + :param score: Confidence score of the whole pose + :param edge_links: Array of [Num Links, 2] containing the links between joints to draw. + :param edge_colors: Array of shape [Num Links, 3] or list of tuples containing the (r,g,b) colors for each joint link. + :param joint_thickness: Thickness of the joint links + :param keypoint_colors: Array of shape [Num Joints, 3] or list of tuples containing the (r,g,b) colors for each keypoint. + :param keypoint_radius: Radius of the keypoints (in pixels) + :param show_confidence: Whether to show the bounding box around the pose and confidence score on top of it. + :param box_thickness: Thickness of bounding boxes. 
+ + :return: A new image with the skeleton drawn on it + """ + image = image.copy() + if edge_colors is None: + edge_colors = [(255, 0, 255)] * len(edge_links) + + if keypoint_colors is None: + keypoint_colors = [(0, 255, 0)] * len(keypoints) + + if len(edge_links) != len(edge_colors): + raise ValueError("edge_colors and edge_links must have the same length") + + keypoints = keypoints[..., 0:2].astype(int) + + for keypoint, color in zip(keypoints, keypoint_colors): + color = tuple(map(int, color)) + cv2.circle(image, tuple(keypoint[:2]), radius=keypoint_radius, color=color, thickness=-1, lineType=cv2.LINE_AA) + + for joint, color in zip(edge_links, edge_colors): + p1 = tuple(keypoints[joint[0]][:2]) + p2 = tuple(keypoints[joint[1]][:2]) + color = tuple(map(int, color)) + cv2.line(image, p1, p2, color=color, thickness=joint_thickness, lineType=cv2.LINE_AA) + + if show_confidence: + x, y, w, h = cv2.boundingRect(keypoints) + image = draw_bbox(image, title=f"{score:.2f}", box_thickness=box_thickness, color=(255, 0, 255), x1=x, y1=y, x2=x + w, y2=y + h) + + return image diff --git a/tests/integration_tests/pose_estimation_models_test.py b/tests/integration_tests/pose_estimation_models_test.py index eceb21afb4..f6a5029d5f 100644 --- a/tests/integration_tests/pose_estimation_models_test.py +++ b/tests/integration_tests/pose_estimation_models_test.py @@ -15,26 +15,7 @@ class PoseEstimationModelsIntegrationTest(unittest.TestCase): def setUp(self): self.oks_sigmas = [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 1.007, 1.007, 0.087, 0.087, 0.089, 0.089] - self.flip_indexes_heatmap = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17] - self.flip_indexes_offset = [ - 0, - 2, - 1, - 4, - 3, - 6, - 5, - 8, - 7, - 10, - 9, - 12, - 11, - 14, - 13, - 16, - 15, - ] + self.flip_indexes = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] def test_dekr_model(self): val_loader = get_data_loader( @@ -77,7 +58,7 @@ def test_dekr_model_with_tta(self): ) model = models.get("dekr_w32_no_dc", pretrained_weights="coco_pose") - model = DEKRHorisontalFlipWrapper(model, self.flip_indexes_heatmap, self.flip_indexes_offset, apply_sigmoid=True).cuda().eval() + model = DEKRHorisontalFlipWrapper(model, self.flip_indexes, apply_sigmoid=True).cuda().eval() post_prediction_callback = DEKRPoseEstimationDecodeCallback( output_stride=4, max_num_people=30, apply_sigmoid=False, keypoint_threshold=0.05, nms_threshold=0.05, nms_num_threshold=8 @@ -107,7 +88,7 @@ def test_dekr_model_with_rescoring(self): ) model = models.get("dekr_w32_no_dc", pretrained_weights="coco_pose") - model = DEKRHorisontalFlipWrapper(model, self.flip_indexes_heatmap, self.flip_indexes_offset, apply_sigmoid=True).cuda().eval() + model = DEKRHorisontalFlipWrapper(model, self.flip_indexes, apply_sigmoid=True).cuda().eval() rescoring = models.get("pose_rescoring_coco", pretrained_weights="coco_pose").cuda().eval() diff --git a/tests/unit_tests/pose_estimation_dataset_test.py b/tests/unit_tests/pose_estimation_dataset_test.py index b27a9c87a7..4ff53d4541 100644 --- a/tests/unit_tests/pose_estimation_dataset_test.py +++ b/tests/unit_tests/pose_estimation_dataset_test.py @@ -1,7 +1,13 @@ +import os.path import unittest + import numpy as np import torch +from super_gradients.common.object_names import Models +from super_gradients.module_interfaces import HasPredict +from super_gradients.training import models +from super_gradients.training.dataloaders import coco2017_pose_val from 
super_gradients.training.datasets.pose_estimation_datasets import DEKRTargetsGenerator @@ -28,3 +34,14 @@ def test_dekr_target_generator(self): self.assertEqual(mask.shape, (18, 64, 64)) self.assertEqual(offset_map.shape, (34, 64, 64)) self.assertEqual(offset_weight.shape, (34, 64, 64)) + + def test_get_dataset_preprocessing_params(self): + data_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "coco2017") + + loader = coco2017_pose_val(dataset_params={"target_generator": None, "data_dir": data_dir, "json_file": "annotations/person_keypoints_val2017.json"}) + preprocessing_params = loader.dataset.get_dataset_preprocessing_params() + self.assertIsNotNone(preprocessing_params) + + dekr: HasPredict = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose") + dekr.set_dataset_processing_params(**preprocessing_params) + dekr.predict(np.zeros((640, 640, 3), dtype=np.uint8))
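The unit test above doubles as a recipe for the new pose-estimation predict flow. A hedged end-to-end sketch follows; the data_dir path is hypothetical and assumes a local copy of the COCO 2017 keypoint annotations, while everything else mirrors the test and the classes added in this diff:

    import numpy as np

    from super_gradients.common.object_names import Models
    from super_gradients.training import models
    from super_gradients.training.dataloaders import coco2017_pose_val

    # Pull preprocessing params from the validation dataset, as the test does.
    loader = coco2017_pose_val(
        dataset_params={
            "target_generator": None,
            "data_dir": "/data/coco",  # hypothetical local path
            "json_file": "annotations/person_keypoints_val2017.json",
        }
    )
    preprocessing_params = loader.dataset.get_dataset_preprocessing_params()

    model = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose")
    model.set_dataset_processing_params(**preprocessing_params)

    # predict() wraps its output in the pose-estimation prediction classes added here,
    # so the result can be visualized or saved directly.
    result = model.predict(np.zeros((640, 640, 3), dtype=np.uint8))
    result.show()

The returned wrapper also exposes save(), mirroring the detection prediction results API.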
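The draw_skeleton helper added under training/utils/visualization can also be called directly, outside the prediction wrappers. A minimal sketch with a made-up three-point skeleton (all values illustrative):

    import numpy as np

    from super_gradients.training.utils.visualization.pose_estimation import draw_skeleton

    canvas = np.zeros((256, 256, 3), dtype=np.uint8)  # blank image to draw on
    keypoints = np.array([[64.0, 64.0], [192.0, 64.0], [128.0, 200.0]])  # [num_joints, 2]
    edge_links = np.array([[0, 1], [1, 2], [2, 0]])  # toy triangle skeleton

    drawn = draw_skeleton(
        image=canvas,
        keypoints=keypoints,
        score=0.87,
        edge_links=edge_links,
        edge_colors=None,       # None falls back to magenta for every link
        joint_thickness=2,
        keypoint_colors=None,   # None falls back to green for every keypoint
        keypoint_radius=5,
        show_confidence=False,  # True additionally draws a box with the pose score
        box_thickness=2,
    )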
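For completeness, constructing the extended DEKRPoseEstimationDecodeCallback looks as follows. The threshold values mirror the integration test above; min_confidence=0.1 is an illustrative choice (it defaults to 0.0, which keeps the previous behaviour), and the import path is assumed from the module location in this diff:

    from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import DEKRPoseEstimationDecodeCallback

    callback = DEKRPoseEstimationDecodeCallback(
        output_stride=4,
        max_num_people=30,
        apply_sigmoid=False,
        keypoint_threshold=0.05,  # pose-center score threshold, forwarded to aggregate_results
        nms_threshold=0.05,
        nms_num_threshold=8,
        min_confidence=0.1,       # final per-pose score cut-off, forwarded to pose_nms
    )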