added find_detection_score_threshold #973

15 changes: 12 additions & 3 deletions src/super_gradients/training/metrics/detection_metrics.py
@@ -30,6 +30,7 @@ class DetectionMetrics(Metric):
:param dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
:param accumulate_on_cpu: Run on CPU regardless of device used in other parts.
This is to avoid "CUDA out of memory" that might happen on GPU.
:param calc_best_score_thresholds: Whether to calculate the best confidence score threshold per class.
"""

def __init__(
@@ -43,6 +44,7 @@ def __init__(
top_k_predictions: int = 100,
dist_sync_on_step: bool = False,
accumulate_on_cpu: bool = True,
calc_best_score_thresholds: bool = False,
):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.num_cls = num_cls
@@ -61,6 +63,9 @@ def __init__(
f"F1{self._get_range_str()}": True,
}
self.component_names = list(self.greater_component_is_better.keys())
self.calc_best_score_thresholds = calc_best_score_thresholds
if self.calc_best_score_thresholds:
self.component_names += [f"Best_score_threshold_cls_{i}" for i in range(self.num_cls)]
self.components = len(self.component_names)

self.post_prediction_callback = post_prediction_callback
@@ -115,18 +120,19 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
"""Compute the metrics for all the accumulated results.
:return: Metrics of interest
"""
mean_ap, mean_precision, mean_recall, mean_f1 = -1.0, -1.0, -1.0, -1.0
mean_ap, mean_precision, mean_recall, mean_f1, best_score_thresholds = -1.0, -1.0, -1.0, -1.0, None
accumulated_matching_info = getattr(self, f"matching_info{self._get_range_str()}")

if len(accumulated_matching_info):
matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*accumulated_matching_info))]

# shape (n_class, nb_iou_thresh)
ap, precision, recall, f1, unique_classes = compute_detection_metrics(
ap, precision, recall, f1, unique_classes, best_score_thresholds = compute_detection_metrics(
*matching_info_tensors,
recall_thresholds=self.recall_thresholds,
score_threshold=self.score_threshold,
device="cpu" if self.accumulate_on_cpu else self.device,
calc_best_score_thresholds=self.calc_best_score_thresholds,
)

# Precision, recall and f1 are computed for IoU threshold range, averaged over classes
@@ -136,12 +142,15 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
# MaP is averaged over IoU thresholds and over classes
mean_ap = ap.mean()

return {
output_dict = {
f"Precision{self._get_range_str()}": mean_precision,
f"Recall{self._get_range_str()}": mean_recall,
f"mAP{self._get_range_str()}": mean_ap,
f"F1{self._get_range_str()}": mean_f1,
}
if self.calc_best_score_thresholds:
output_dict.update(best_score_thresholds)
return output_dict

def _sync_dist(self, dist_sync_fn=None, process_group=None):
"""
@@ -278,7 +278,7 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
preds_scores = torch.cat([x[2].cpu() for x in predictions], dim=0)
n_targets = sum([x[3] for x in predictions])

cls_precision, _, cls_recall = compute_detection_metrics_per_cls(
cls_precision, _, cls_recall, _ = compute_detection_metrics_per_cls(
preds_matched=preds_matched,
preds_to_ignore=preds_to_ignore,
preds_scores=preds_scores,
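Taken together, the metric-side changes mean the feature is one constructor argument away. A minimal usage sketch, not from the PR (`my_nms_callback` is a placeholder — any configured post-prediction callback, e.g. your model's NMS callback, goes there):

```python
from super_gradients.training.metrics.detection_metrics import DetectionMetrics

metric = DetectionMetrics(
    num_cls=80,
    post_prediction_callback=my_nms_callback,  # placeholder: your model's NMS callback
    calc_best_score_thresholds=True,  # new flag added by this PR
)

# ... metric.update(...) over a validation set ...

results = metric.compute()
# Alongside Precision/Recall/mAP/F1, the dict now holds one
# "Best_score_threshold_cls_{i}" entry per class present in the targets.
print(results.get("Best_score_threshold_cls_0"))
```

Note that `component_names` registers a `Best_score_threshold_cls_{i}` entry for all `num_cls` classes, while `compute()` only fills in the classes that actually appear in the accumulated targets.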
49 changes: 43 additions & 6 deletions src/super_gradients/training/utils/detection_utils.py
@@ -1058,6 +1058,7 @@ def compute_detection_metrics(
device: str,
recall_thresholds: Optional[torch.Tensor] = None,
score_threshold: Optional[float] = 0.1,
calc_best_score_thresholds: bool = False,
) -> Tuple:
"""
Compute the precision, recall, mAP and F1 for every IoU threshold and for every class.
@@ -1073,10 +1074,14 @@
:param score_threshold: Minimum confidence score to consider a prediction for the computation of
precision, recall and f1 (not MaP)
:param device: Device
:param calc_best_score_thresholds: If True, the best confidence score threshold is computed for each class

:return:
:ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)
:unique_classes: Vector with all unique target classes
:best_score_thresholds: Dict mapping each class to its best confidence score threshold if
calc_best_score_thresholds is True, otherwise None

"""
preds_matched, preds_to_ignore = preds_matched.to(device), preds_to_ignore.to(device)
preds_scores, preds_cls, targets_cls = preds_scores.to(device), preds_cls.to(device), targets_cls.to(device)
@@ -1089,25 +1094,29 @@
ap = torch.zeros((n_class, nb_iou_thrs), device=device)
precision = torch.zeros((n_class, nb_iou_thrs), device=device)
recall = torch.zeros((n_class, nb_iou_thrs), device=device)
best_score_thresholds = dict() if calc_best_score_thresholds else None

for cls_i, cls in enumerate(unique_classes):
cls_preds_idx, cls_targets_idx = (preds_cls == cls), (targets_cls == cls)
cls_ap, cls_precision, cls_recall = compute_detection_metrics_per_cls(
cls_ap, cls_precision, cls_recall, cls_best_score_threshold = compute_detection_metrics_per_cls(
preds_matched=preds_matched[cls_preds_idx],
preds_to_ignore=preds_to_ignore[cls_preds_idx],
preds_scores=preds_scores[cls_preds_idx],
n_targets=cls_targets_idx.sum(),
recall_thresholds=recall_thresholds,
score_threshold=score_threshold,
device=device,
calc_best_score_threshold=calc_best_score_thresholds,
)
ap[cls_i, :] = cls_ap
precision[cls_i, :] = cls_precision
recall[cls_i, :] = cls_recall
if calc_best_score_thresholds:
best_score_thresholds[f"Best_score_threshold_cls_{int(cls)}"] = cls_best_score_threshold

f1 = 2 * precision * recall / (precision + recall + 1e-16)

return ap, precision, recall, f1, unique_classes
return ap, precision, recall, f1, unique_classes, best_score_thresholds


def compute_detection_metrics_per_cls(
@@ -1118,6 +1127,7 @@ def compute_detection_metrics_per_cls(
recall_thresholds: torch.Tensor,
score_threshold: float,
device: str,
calc_best_score_threshold: bool = False,
):
"""
Compute the precision, recall and mAP of a given class for every IoU threshold.
@@ -1134,16 +1144,20 @@
:param score_threshold: Minimum confidence score to consider a prediction for the computation of
precision and recall (not MaP)
:param device: Device
:param calc_best_score_threshold: If True, the best confidence score threshold is computed for this class

:return ap, precision, recall: Tensors of shape (nb_iou_thrs)
:return:
:ap, precision, recall: Tensors of shape (nb_iou_thrs)
:best_score_threshold: Scalar tensor if calc_best_score_threshold is True, otherwise None
"""
nb_iou_thrs = preds_matched.shape[-1]
best_score_threshold = torch.tensor(0.0, dtype=torch.float, device=device) if calc_best_score_threshold else None

tps = preds_matched
fps = torch.logical_and(torch.logical_not(preds_matched), torch.logical_not(preds_to_ignore))

if len(tps) == 0:
return torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device)
return torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device), best_score_threshold

# Sort by decreasing score
dtype = torch.uint8 if preds_scores.is_cuda and preds_scores.dtype is torch.bool else preds_scores.dtype
@@ -1167,7 +1181,8 @@

# We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
# Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=False)
# Note2: because of the negation, right=True is needed so that predictions scoring exactly score_threshold are kept
lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=True)

if lowest_score_above_threshold == 0: # Here score_threshold > preds_scores[0], so no pred is above the threshold
recall = torch.zeros(nb_iou_thrs, device=device)
Expand All @@ -1176,6 +1191,28 @@ def compute_detection_metrics_per_cls(
recall = rolling_recalls[lowest_score_above_threshold - 1]
precision = rolling_precisions[lowest_score_above_threshold - 1]

# ==================
# BEST CONFIDENCE SCORE THRESHOLD PER CLASS
if calc_best_score_threshold:
all_score_thresholds = torch.linspace(0, 1, 101, device=device)

# We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
# Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)

# When score_threshold > preds_scores[0], then no pred is above the threshold, so we pad with zeros
rolling_recalls_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_recalls), dim=0)
rolling_precisions_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_precisions), dim=0)

# shape = (n_score_thresholds, nb_iou_thrs)
recalls_per_threshold = torch.index_select(input=rolling_recalls_padded, dim=0, index=lowest_scores_above_thresholds)
precisions_per_threshold = torch.index_select(input=rolling_precisions_padded, dim=0, index=lowest_scores_above_thresholds)

# shape (101, nb_iou_thrs)
f1 = 2 * recalls_per_threshold * precisions_per_threshold / (recalls_per_threshold + precisions_per_threshold + 1e-16)
mean_f1 = torch.mean(f1, dim=1) # average over iou thresholds
best_score_threshold = all_score_thresholds[torch.argmax(mean_f1)]

# ==================
# AVERAGE PRECISION

@@ -1196,4 +1233,4 @@
# Average over the recall_thresholds
ap = sampled_precision_points.mean(0)

return ap, precision, recall
return ap, precision, recall, best_score_threshold
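The `right=True` fix above is easy to sanity-check on a toy tensor. A standalone sketch (not part of the PR) of the negation trick used by both `searchsorted` calls:

```python
import torch

# Decreasing scores, as produced by the sort in compute_detection_metrics_per_cls.
preds_scores = torch.tensor([0.9, 0.5, 0.5, 0.2])
score_threshold = 0.5

# -preds_scores is increasing, so torch.searchsorted applies. With right=True the
# returned index i satisfies preds_scores[i-1] >= threshold > preds_scores[i]:
# predictions scoring exactly at the threshold are still counted as "above" it.
print(torch.searchsorted(-preds_scores, -score_threshold, right=True))   # tensor(3)

# The old right=False behaved like preds_scores[i-1] > threshold >= preds_scores[i],
# silently dropping the two predictions that score exactly 0.5.
print(torch.searchsorted(-preds_scores, -score_threshold, right=False))  # tensor(1)
```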
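The per-class search itself is compact: sweep 101 candidate thresholds, look up the rolling precision/recall at each cut index, and keep the threshold whose F1 (averaged over IoU thresholds) is highest. A self-contained sketch with toy values (single IoU threshold, 2 TPs at scores 0.9/0.7, 2 FPs at 0.4/0.1, 3 targets; illustration only, not SuperGradients code):

```python
import torch

preds_scores = torch.tensor([0.9, 0.7, 0.4, 0.1])  # sorted decreasing
# rolling_*[i] = metric value when keeping the i+1 highest-scoring predictions
rolling_recalls = torch.tensor([[1 / 3], [2 / 3], [2 / 3], [2 / 3]])
rolling_precisions = torch.tensor([[1.0], [1.0], [2 / 3], [0.5]])
nb_iou_thrs = 1

all_score_thresholds = torch.linspace(0, 1, 101)
cuts = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)

# Pad with a zero row so that thresholds above the top score map to P = R = 0.
recalls_pad = torch.cat((torch.zeros(1, nb_iou_thrs), rolling_recalls), dim=0)
precisions_pad = torch.cat((torch.zeros(1, nb_iou_thrs), rolling_precisions), dim=0)

r = recalls_pad[cuts]      # shape (101, nb_iou_thrs)
p = precisions_pad[cuts]
f1 = 2 * r * p / (r + p + 1e-16)
best = all_score_thresholds[f1.mean(dim=1).argmax()]
print(best)  # tensor(0.4100): the cut just above the highest-scoring false positive
```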
49 changes: 49 additions & 0 deletions utils_script/find_detection_score_threshold.py
@@ -0,0 +1,49 @@
"""
Find the best confidence score threshold for each class in object detection tasks
Use this script when you have a trained model and want to analyze / optimize its performance
The thresholds can be used later when performing NMS
Usage is similar to src/super_gradients/examples/evaluate_from_recipe_example/evaluate_from_recipe.py

Notes:
This script does NOT run TRAINING, so make sure the recipe loads a PRETRAINED MODEL,
either from one of your checkpoints or from pretrained weights.

General use: python find_detection_score_threshold.py --config-name="DESIRED_RECIPE" architecture="DESIRED_ARCH"
checkpoint_params.pretrained_weights="DESIRED_DATASET"

Example: python find_detection_score_threshold.py --config-name=coco2017_yolox architecture=yolox_n
checkpoint_params.pretrained_weights=coco
"""

import hydra
import pkg_resources

from omegaconf import DictConfig

from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
from super_gradients.common.environment.cfg_utils import add_params_to_cfg
from super_gradients import Trainer, init_trainer


@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""), version_base="1.2")
def main(cfg: DictConfig) -> None:
add_params_to_cfg(cfg.training_hyperparams.valid_metrics_list[0].DetectionMetrics, params=["calc_best_score_thresholds=True"])
_, valid_metrics_dict = Trainer.evaluate_from_recipe(cfg)

class_names = COCO_DETECTION_CLASSES_LIST # change this line to use a different dataset
prefix = "Best_score_threshold_cls_"
best_thresholds = {int(k[len(prefix) :]): v for k, v in valid_metrics_dict.items() if k.startswith(prefix)}
assert len(best_thresholds) == len(class_names)
print("-----Best_score_thresholds-----")
max_class_name_len = max(len(class_name) for class_name in class_names)
for k, v in best_thresholds.items():
    print(f"{class_names[k]:<{max_class_name_len}} (class {k}):\t{v:.2f}")


def run():
init_trainer()
main()


if __name__ == "__main__":
run()
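As the script's docstring notes, the recovered thresholds are meant to be fed back into post-processing. A hypothetical sketch of per-class filtering with them; the tensor layout and the `filter_by_class_thresholds` helper are illustrative assumptions, not SuperGradients API:

```python
import torch


def filter_by_class_thresholds(boxes, scores, classes, thresholds):
    """Keep only detections whose score clears their class-specific threshold.

    boxes: (N, 4), scores: (N,), classes: (N,) integer class ids,
    thresholds: {class_id: best_score_threshold}, as printed by this script.
    """
    per_det_thr = torch.tensor([thresholds.get(int(c), 0.0) for c in classes])
    keep = scores >= per_det_thr  # ">=" mirrors the right=True semantics above
    return boxes[keep], scores[keep], classes[keep]


boxes = torch.rand(5, 4)
scores = torch.tensor([0.9, 0.3, 0.6, 0.2, 0.8])
classes = torch.tensor([0, 0, 1, 1, 2])
boxes, scores, classes = filter_by_class_thresholds(
    boxes, scores, classes, thresholds={0: 0.5, 1: 0.4, 2: 0.7}
)
```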