Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 901 extreme batch visualization for object detection #1339

Merged
merged 3 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 163 additions & 2 deletions src/super_gradients/training/utils/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import signal
import time
from abc import ABC, abstractmethod
from typing import List, Union, Optional, Sequence, Mapping
from typing import List, Union, Optional, Sequence, Mapping, Tuple

import csv
import cv2
Expand All @@ -27,7 +27,7 @@
from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber
from super_gradients.training.utils import get_param
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback, cxcywh2xyxy, xyxy2cxcywh
from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average, maybe_all_gather_np_images
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
Expand Down Expand Up @@ -1112,6 +1112,167 @@ def _is_more_extreme(self, score: float) -> bool:
return self.extreme_score > score


@register_callback("ExtremeBatchDetectionVisualizationCallback")
class ExtremeBatchDetectionVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
"""
ExtremeBatchSegVisualizationCallback

Visualizes worst/best batch in an epoch for Object detection.
For clarity, the batch is saved twice in the SG Logger, once with the model's predictions and once with
ground truth targets.

Assumptions on bbox dormats:
- After applying post_prediction_callback on context.preds, the predictions are a list/Tensor s.t:
predictions[i] is a tensor of shape nx6 - (x1, y1, x2, y2, confidence, class) where x and y are in pixel units.

- context.targets is a tensor of shape (total_num_targets, 6), in LABEL_CXCYWH format: (index, label, cx, cy, w, h).



Example usage in Yaml config:

training_hyperparams:
phase_callbacks:
- ExtremeBatchDetectionVisualizationCallback:
metric:
DetectionMetrics_050:
score_thres: 0.1
top_k_predictions: 300
num_cls: ${num_classes}
normalize_targets: True
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
score_threshold: 0.01
nms_top_k: 1000
max_predictions: 300
nms_threshold: 0.7
metric_component_name: '[email protected]'
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
score_threshold: 0.25
nms_top_k: 1000
max_predictions: 300
nms_threshold: 0.7
normalize_targets: True

:param metric: Metric, will be the metric which is monitored.

:param metric_component_name: In case metric returns multiple values (as Mapping),
the value at metric.compute()[metric_component_name] will be the one monitored.

:param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:

if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.

If a single item is returned rather then a tuple:
<LOSS_CLASS.__name__>.

When there is no such attributes and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/"Loss_"<IDX>

:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).

:param freq: int, epoch frequency to perform all of the above (default=1).

:param classes: List[str], a list of class names corresponding to the class indices for display.
When None, will try to fetch this through a "classes" attribute of the valdiation dataset. If such attribute does
not exist an error will be raised (default=None).

:param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
are in pixel values range, this needs to be set to True (default=False)

"""

def __init__(
self,
post_prediction_callback: DetectionPostPredictionCallback,
metric: Optional[Metric] = None,
metric_component_name: Optional[str] = None,
loss_to_monitor: Optional[str] = None,
max: bool = False,
freq: int = 1,
classes: Optional[List[str]] = None,
normalize_targets: bool = False,
):
super(ExtremeBatchDetectionVisualizationCallback, self).__init__(
metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
)
self.post_prediction_callback = post_prediction_callback
if classes is None:
logger.info(
"No classes have been passed to ExtremeBatchDetectionVisualizationCallback. "
"Will try to fetch them through context.valid_loader.dataset classes attribute if it exists."
)
self.classes = classes
self.normalize_targets = normalize_targets

@staticmethod
def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
"""
A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
:param inputs:
:return:
"""
inputs -= inputs.min()
inputs /= inputs.max()
inputs *= 255
inputs = inputs.to(torch.uint8)
inputs = inputs.cpu().numpy()
inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
return inputs

def process_extreme_batch(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Processes the extreme batch, and returns 2 image batches for visualization - one with predictions and one with GT boxes.
:return:Tuple[np.ndarray, np.ndarray], the predictions batch, the GT batch
"""
inputs = self.extreme_batch
preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
targets = self.extreme_targets.clone()
if self.normalize_targets:
target_bboxes = targets[:, 2:]
target_bboxes = cxcywh2xyxy(target_bboxes)
_, _, height, width = inputs.shape
target_bboxes[:, [0, 2]] /= width
target_bboxes[:, [1, 3]] /= height
target_bboxes = xyxy2cxcywh(target_bboxes)
targets[:, 2:] = target_bboxes

images_to_save_preds = DetectionVisualization.visualize_batch(
inputs, preds, targets, "extreme_batch_preds", self.classes, gt_alpha=0.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
)
images_to_save_preds = np.stack(images_to_save_preds)

images_to_save_gt = DetectionVisualization.visualize_batch(
inputs, None, targets, "extreme_batch_gt", self.classes, gt_alpha=1.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
)
images_to_save_gt = np.stack(images_to_save_gt)

return images_to_save_preds, images_to_save_gt

def on_validation_loader_end(self, context: PhaseContext) -> None:
if self.classes is None:
if hasattr(context.valid_loader.dataset, "classes"):
self.classes = context.valid_loader.dataset.classes

else:
raise RuntimeError("Couldn't fetch classes from valid_loader, please pass classes explicitly")
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
if context.epoch % self.freq == 0:
images_to_save_preds, images_to_save_gt = self.process_extreme_batch()
images_to_save_preds = maybe_all_gather_np_images(images_to_save_preds)
images_to_save_gt = maybe_all_gather_np_images(images_to_save_gt)

if not context.ddp_silent_mode:
context.sg_logger.add_images(tag=f"{self._tag}_preds", images=images_to_save_preds, global_step=context.epoch, data_format="NHWC")
context.sg_logger.add_images(tag=f"{self._tag}_GT", images=images_to_save_gt, global_step=context.epoch, data_format="NHWC")

self._reset()


@register_callback("ExtremeBatchSegVisualizationCallback")
class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/super_gradients/training/utils/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def visualize_batch(
:param image_tensor: rgb images, (B, H, W, 3)
:param pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6),
values on dim 1 are: x1, y1, x2, y2, confidence, class
:param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h
:param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, cx cy w h
(coordinates scaled to [0, 1])
:param batch_name: id of the current batch to use for image naming

Expand All @@ -518,6 +518,8 @@ def visualize_batch(
"""
image_np = undo_preprocessing_func(image_tensor.detach())
targets = DetectionVisualization._scaled_ccwh_to_xyxy(target_boxes.detach().cpu().numpy(), *image_np.shape[1:3], image_scale)
if pred_boxes is None:
pred_boxes = [None for _ in range(image_np.shape[0])]

out_images = []
for i in range(image_np.shape[0]):
Expand Down
86 changes: 71 additions & 15 deletions tests/unit_tests/extreme_batch_cb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@
from super_gradients import Trainer
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader
from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader, detection_test_dataloader
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
from super_gradients.training.metrics import IoU
from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchSegVisualizationCallback
from super_gradients.training.metrics import IoU, DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchSegVisualizationCallback, ExtremeBatchDetectionVisualizationCallback


# Helper method to set up Trainer and model with common parameters
def setup_trainer_and_model(experiment_name: str):
def setup_trainer_and_model_seg(experiment_name: str):
trainer = Trainer(experiment_name)
model = models.get(Models.DDRNET_23, arch_params={"use_aux_heads": True}, pretrained_weights="cityscapes")
return trainer, model


def setup_trainer_and_model_detection(experiment_name: str):
trainer = Trainer(experiment_name)
model = models.get(Models.YOLO_NAS_S, num_classes=1)
return trainer, model


class DummyIOU(IoU):
"""
Metric for testing the segmentation callback works with compound metrics
Expand All @@ -28,13 +36,12 @@ def compute(self):
class ExtremeBatchSanityTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.training_params = {
cls.seg_training_params = {
"max_epochs": 3,
"initial_lr": 1e-2,
"loss": DDRNetLoss(),
"lr_mode": "poly",
"ema": True,
"average_best_models": True,
"optimizer": "SGD",
"mixed_precision": False,
"optimizer_params": {"weight_decay": 5e-4, "momentum": 0.9},
Expand All @@ -45,25 +52,74 @@ def setUpClass(cls):
"greater_metric_to_watch_is_better": True,
}

cls.od_training_params = {
"max_epochs": 3,
"initial_lr": 1e-2,
"loss": PPYoloELoss(num_classes=1, use_static_assigner=False, reg_max=16),
"lr_mode": "poly",
"ema": True,
"optimizer": "SGD",
"mixed_precision": False,
"optimizer_params": {"weight_decay": 5e-4, "momentum": 0.9},
"load_opt_params": False,
"valid_metrics_list": [
DetectionMetrics_050(
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
num_cls=1,
)
],
"train_metrics_list": [],
"metric_to_watch": "[email protected]",
"greater_metric_to_watch_is_better": True,
}

def test_detection_extreme_batch_with_metric_sanity(self):
trainer, model = setup_trainer_and_model_detection("test_detection_extreme_batch_with_metric_sanity")
self.od_training_params["phase_callbacks"] = [
ExtremeBatchDetectionVisualizationCallback(
classes=["1"],
metric=DetectionMetrics_050(
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
num_cls=1,
),
metric_component_name="[email protected]",
post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
)
]
trainer.train(model=model, training_params=self.od_training_params, train_loader=detection_test_dataloader(), valid_loader=detection_test_dataloader())

def test_detection_extreme_batch_with_loss_sanity(self):
trainer, model = setup_trainer_and_model_detection("test_detection_extreme_batch_with_loss_sanity")
self.od_training_params["phase_callbacks"] = [
ExtremeBatchDetectionVisualizationCallback(
classes=["1"],
loss_to_monitor="PPYoloELoss/loss_cls",
post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
)
]
trainer.train(model=model, training_params=self.od_training_params, train_loader=detection_test_dataloader(), valid_loader=detection_test_dataloader())

def test_segmentation_extreme_batch_with_metric_sanity(self):
trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_metric_sanity")
self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(IoU(5))]
trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_with_metric_sanity")
self.seg_training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(IoU(5))]
trainer.train(
model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
)

def test_segmentation_extreme_batch_with_compound_metric_sanity(self):
trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_compound_metric_sanity")
self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(DummyIOU(5), metric_component_name="diou_minus")]
trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_with_compound_metric_sanity")
self.seg_training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(DummyIOU(5), metric_component_name="diou_minus")]
trainer.train(
model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
)

def test_segmentation_extreme_batch_with_loss_sanity(self):
trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_loss_sanity")
self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1")]
trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_with_loss_sanity")
self.seg_training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1")]
trainer.train(
model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
)


Expand Down