Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added find_detection_score_threshold #973

Empty file.
52 changes: 52 additions & 0 deletions src/super_gradients/scripts/find_detection_score_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Find the best confidence score threshold for each class in object detection tasks
Use this script when you have a trained model and want to analyze / optimize its performance
The thresholds can be used later when performing NMS
Usage is similar to src/super_gradients/evaluate_from_recipe.py

Notes:
This script does NOT run TRAINING, so make sure in the recipe that you load a PRETRAINED MODEL
either from one of your checkpoint or from a pretrained model.

General use: python -m super_gradients.scripts.find_detection_score_threshold --config-name="DESIRED_RECIPE" architecture="DESIRED_ARCH"
checkpoint_params.pretrained_weights="DESIRED_DATASET"

Example: python -m super_gradients.scripts.find_detection_score_threshold --config-name=coco2017_yolox architecture=yolox_n
checkpoint_params.pretrained_weights=coco
"""

import hydra
import pkg_resources
from omegaconf import DictConfig

from super_gradients.training.dataloaders import dataloaders
from super_gradients.common.environment.cfg_utils import add_params_to_cfg
from super_gradients import Trainer, init_trainer


@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""), version_base="1.2")
def main(cfg: DictConfig) -> None:
add_params_to_cfg(cfg.training_hyperparams.valid_metrics_list[0].DetectionMetrics, params=["calc_best_score_thresholds=True"])
_, valid_metrics_dict = Trainer.evaluate_from_recipe(cfg)

# INSTANTIATE DATA LOADERS
val_dataloader = dataloaders.get(name=cfg.val_dataloader, dataset_params={}, dataloader_params={"num_workers": 2})
class_names = val_dataloader.dataset.classes
prefix = "Best_score_threshold_cls_"
best_thresholds = {int(k[len(prefix) :]): v for k, v in valid_metrics_dict.items() if k.startswith(prefix)}
assert len(best_thresholds) == len(class_names)
print("-----Best_score_thresholds-----")
print(f"Best score threshold overall: {valid_metrics_dict['Best_score_threshold']:.2f}")
print("Best score thresholds per class:")
max_class_name = max(len(class_name) for class_name in class_names)
for k, v in best_thresholds.items():
print(f"{class_names[k]:<{max_class_name}} (class {k}):\t{v:.2f}")


def run():
init_trainer()
main()


if __name__ == "__main__":
run()
20 changes: 17 additions & 3 deletions src/super_gradients/training/metrics/detection_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class DetectionMetrics(Metric):
:param dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
:param accumulate_on_cpu: Run on CPU regardless of device used in other parts.
This is to avoid "CUDA out of memory" that might happen on GPU.
:param calc_best_score_thresholds Whether to calculate the best score threshold overall and per class
If True, the compute() function will return a metrics dictionary that not
only includes the average metrics calculated across all classes,
but also the optimal score threshold overall and for each individual class.
"""

def __init__(
Expand All @@ -44,6 +48,7 @@ def __init__(
top_k_predictions: int = 100,
dist_sync_on_step: bool = False,
accumulate_on_cpu: bool = True,
calc_best_score_thresholds: bool = False,
):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.num_cls = num_cls
Expand All @@ -65,6 +70,10 @@ def __init__(
f"F1{self._get_range_str()}": True,
}
self.component_names = list(self.greater_component_is_better.keys())
self.calc_best_score_thresholds = calc_best_score_thresholds
if self.calc_best_score_thresholds:
self.component_names.append("Best_score_threshold")
self.component_names += [f"Best_score_threshold_cls_{i}" for i in range(self.num_cls)]
self.components = len(self.component_names)

self.post_prediction_callback = post_prediction_callback
Expand Down Expand Up @@ -119,18 +128,19 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
"""Compute the metrics for all the accumulated results.
:return: Metrics of interest
"""
mean_ap, mean_precision, mean_recall, mean_f1 = -1.0, -1.0, -1.0, -1.0
mean_ap, mean_precision, mean_recall, mean_f1, best_score_threshold, best_score_threshold_per_cls = -1.0, -1.0, -1.0, -1.0, -1.0, None
accumulated_matching_info = getattr(self, f"matching_info{self._get_range_str()}")

if len(accumulated_matching_info):
matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*accumulated_matching_info))]

# shape (n_class, nb_iou_thresh)
ap, precision, recall, f1, unique_classes = compute_detection_metrics(
ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls = compute_detection_metrics(
*matching_info_tensors,
recall_thresholds=self.recall_thresholds,
score_threshold=self.score_threshold,
device="cpu" if self.accumulate_on_cpu else self.device,
calc_best_score_thresholds=self.calc_best_score_thresholds,
)

# Precision, recall and f1 are computed for IoU threshold range, averaged over classes
Expand All @@ -140,12 +150,16 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
# MaP is averaged over IoU thresholds and over classes
mean_ap = ap.mean()

return {
output_dict = {
f"Precision{self._get_range_str()}": mean_precision,
f"Recall{self._get_range_str()}": mean_recall,
f"mAP{self._get_range_str()}": mean_ap,
f"F1{self._get_range_str()}": mean_f1,
}
if self.calc_best_score_thresholds:
output_dict["Best_score_threshold"] = best_score_threshold
output_dict.update(best_score_threshold_per_cls)
return output_dict

def _sync_dist(self, dist_sync_fn=None, process_group=None):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
preds_scores = torch.cat([x[2].cpu() for x in predictions], dim=0)
n_targets = sum([x[3] for x in predictions])

cls_precision, _, cls_recall = compute_detection_metrics_per_cls(
cls_precision, _, cls_recall, _, _ = compute_detection_metrics_per_cls(
preds_matched=preds_matched,
preds_to_ignore=preds_to_ignore,
preds_scores=preds_scores,
Expand Down
73 changes: 67 additions & 6 deletions src/super_gradients/training/utils/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,7 @@ def compute_detection_metrics(
device: str,
recall_thresholds: Optional[torch.Tensor] = None,
score_threshold: Optional[float] = 0.1,
calc_best_score_thresholds: bool = False,
) -> Tuple:
"""
Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.
Expand All @@ -1103,10 +1104,16 @@ def compute_detection_metrics(
:param score_threshold: Minimum confidence score to consider a prediction for the computation of
precision, recall and f1 (not MaP)
:param device: Device
:param calc_best_score_thresholds: If True, the best confidence score threshold is computed for each class

:return:
:ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)
:unique_classes: Vector with all unique target classes
:best_score_threshold: torch.float with the best overall score threshold if calc_best_score_thresholds
is True else None
:best_score_threshold_per_cls: dict that stores the best score threshold for each class , if
calc_best_score_thresholds is True else None

"""
preds_matched, preds_to_ignore = preds_matched.to(device), preds_to_ignore.to(device)
preds_scores, preds_cls, targets_cls = preds_scores.to(device), preds_cls.to(device), targets_cls.to(device)
Expand All @@ -1120,24 +1127,39 @@ def compute_detection_metrics(
precision = torch.zeros((n_class, nb_iou_thrs), device=device)
recall = torch.zeros((n_class, nb_iou_thrs), device=device)

nb_score_thrs = 101
all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
f1_per_class_per_threshold = torch.zeros((n_class, nb_score_thrs), device=device) if calc_best_score_thresholds else None
best_score_threshold_per_cls = dict() if calc_best_score_thresholds else None

for cls_i, cls in enumerate(unique_classes):
cls_preds_idx, cls_targets_idx = (preds_cls == cls), (targets_cls == cls)
cls_ap, cls_precision, cls_recall = compute_detection_metrics_per_cls(
cls_ap, cls_precision, cls_recall, cls_f1_per_threshold, cls_best_score_threshold = compute_detection_metrics_per_cls(
preds_matched=preds_matched[cls_preds_idx],
preds_to_ignore=preds_to_ignore[cls_preds_idx],
preds_scores=preds_scores[cls_preds_idx],
n_targets=cls_targets_idx.sum(),
recall_thresholds=recall_thresholds,
score_threshold=score_threshold,
device=device,
calc_best_score_thresholds=calc_best_score_thresholds,
nb_score_thrs=nb_score_thrs,
)
ap[cls_i, :] = cls_ap
precision[cls_i, :] = cls_precision
recall[cls_i, :] = cls_recall
if calc_best_score_thresholds:
f1_per_class_per_threshold[cls_i, :] = cls_f1_per_threshold
best_score_threshold_per_cls[f"Best_score_threshold_cls_{int(cls)}"] = cls_best_score_threshold

f1 = 2 * precision * recall / (precision + recall + 1e-16)
if calc_best_score_thresholds:
mean_f1_across_classes = torch.mean(f1_per_class_per_threshold, dim=0)
best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_across_classes)]
else:
best_score_threshold = None

return ap, precision, recall, f1, unique_classes
return ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls


def compute_detection_metrics_per_cls(
Expand All @@ -1148,6 +1170,8 @@ def compute_detection_metrics_per_cls(
recall_thresholds: torch.Tensor,
score_threshold: float,
device: str,
calc_best_score_thresholds: bool = False,
nb_score_thrs: int = 101,
):
"""
Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.
Expand All @@ -1164,16 +1188,30 @@ def compute_detection_metrics_per_cls(
:param score_threshold: Minimum confidence score to consider a prediction for the computation of
precision and recall (not MaP)
:param device: Device
:param calc_best_score_thresholds: If True, the best confidence score threshold is computed for this class
:param nb_score_thrs: Number of score thresholds to consider when calc_best_score_thresholds is True

:return ap, precision, recall: Tensors of shape (nb_iou_thrs)
:return:
:ap, precision, recall: Tensors of shape (nb_iou_thrs)
:mean_f1_per_threshold: Tensor of shape (nb_score_thresholds) if calc_best_score_thresholds is True else None
:best_score_threshold: torch.float if calc_best_score_thresholds is True else None
"""
nb_iou_thrs = preds_matched.shape[-1]

mean_f1_per_threshold = torch.zeros(nb_score_thrs, device=device) if calc_best_score_thresholds else None
best_score_threshold = torch.tensor(0.0, dtype=torch.float, device=device) if calc_best_score_thresholds else None

tps = preds_matched
fps = torch.logical_and(torch.logical_not(preds_matched), torch.logical_not(preds_to_ignore))

if len(tps) == 0:
return torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device)
return (
torch.zeros(nb_iou_thrs, device=device),
torch.zeros(nb_iou_thrs, device=device),
torch.zeros(nb_iou_thrs, device=device),
mean_f1_per_threshold,
best_score_threshold,
)

# Sort by decreasing score
dtype = torch.uint8 if preds_scores.is_cuda and preds_scores.dtype is torch.bool else preds_scores.dtype
Expand All @@ -1197,7 +1235,8 @@ def compute_detection_metrics_per_cls(

# We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
# Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=False)
# Note2: right=True due to negation
lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=True)

if lowest_score_above_threshold == 0: # Here score_threshold > preds_scores[0], so no pred is above the threshold
recall = torch.zeros(nb_iou_thrs, device=device)
Expand All @@ -1206,6 +1245,28 @@ def compute_detection_metrics_per_cls(
recall = rolling_recalls[lowest_score_above_threshold - 1]
precision = rolling_precisions[lowest_score_above_threshold - 1]

# ==================
# BEST CONFIDENCE SCORE THRESHOLD PER CLASS
if calc_best_score_thresholds:
all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)

# We want the rolling precision/recall at index i so that: preds_scores[i-1] > score_threshold >= preds_scores[i]
# Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)

# When score_threshold > preds_scores[0], then no pred is above the threshold, so we pad with zeros
rolling_recalls_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_recalls), dim=0)
rolling_precisions_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_precisions), dim=0)

# shape = (n_score_thresholds, nb_iou_thrs)
recalls_per_threshold = torch.index_select(input=rolling_recalls_padded, dim=0, index=lowest_scores_above_thresholds)
precisions_per_threshold = torch.index_select(input=rolling_precisions_padded, dim=0, index=lowest_scores_above_thresholds)

# shape (n_score_thresholds, nb_iou_thrs)
f1_per_threshold = 2 * recalls_per_threshold * precisions_per_threshold / (recalls_per_threshold + precisions_per_threshold + 1e-16)
mean_f1_per_threshold = torch.mean(f1_per_threshold, dim=1) # average over iou thresholds
best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_per_threshold)]

# ==================
# AVERAGE PRECISION

Expand All @@ -1226,4 +1287,4 @@ def compute_detection_metrics_per_cls(
# Average over the recall_thresholds
ap = sampled_precision_points.mean(0)

return ap, precision, recall
return ap, precision, recall, mean_f1_per_threshold, best_score_threshold