diff --git a/src/super_gradients/scripts/__init__.py b/src/super_gradients/scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/super_gradients/scripts/find_detection_score_threshold.py b/src/super_gradients/scripts/find_detection_score_threshold.py
new file mode 100644
index 0000000000..d929b71fb7
--- /dev/null
+++ b/src/super_gradients/scripts/find_detection_score_threshold.py
@@ -0,0 +1,52 @@
+"""
+Find the best confidence score threshold for each class in object detection tasks.
+Use this script when you have a trained model and want to analyze / optimize its performance.
+The thresholds can be used later when performing NMS.
+Usage is similar to src/super_gradients/evaluate_from_recipe.py
+
+Notes:
+    This script does NOT run TRAINING, so make sure in the recipe that you load a PRETRAINED MODEL,
+    either from one of your checkpoints or from pretrained weights.
+
+General use: python -m super_gradients.scripts.find_detection_score_threshold --config-name="DESIRED_RECIPE" architecture="DESIRED_ARCH"
+             checkpoint_params.pretrained_weights="DESIRED_DATASET"
+
+Example: python -m super_gradients.scripts.find_detection_score_threshold --config-name=coco2017_yolox architecture=yolox_n
+         checkpoint_params.pretrained_weights=coco
+"""
+
+import hydra
+import pkg_resources
+from omegaconf import DictConfig
+
+from super_gradients.training.dataloaders import dataloaders
+from super_gradients.common.environment.cfg_utils import add_params_to_cfg
+from super_gradients import Trainer, init_trainer
+
+
+@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""), version_base="1.2")
+def main(cfg: DictConfig) -> None:
+    add_params_to_cfg(cfg.training_hyperparams.valid_metrics_list[0].DetectionMetrics, params=["calc_best_score_thresholds=True"])
+    _, valid_metrics_dict = Trainer.evaluate_from_recipe(cfg)
+
+    # INSTANTIATE DATA LOADERS
+    val_dataloader = dataloaders.get(name=cfg.val_dataloader, dataset_params={}, dataloader_params={"num_workers": 2})
+    class_names = val_dataloader.dataset.classes
+    prefix = "Best_score_threshold_cls_"
+    best_thresholds = {int(k[len(prefix) :]): v for k, v in valid_metrics_dict.items() if k.startswith(prefix)}
+    assert len(best_thresholds) == len(class_names)
+    print("-----Best_score_thresholds-----")
+    print(f"Best score threshold overall: {valid_metrics_dict['Best_score_threshold']:.2f}")
+    print("Best score thresholds per class:")
+    max_class_name = max(len(class_name) for class_name in class_names)
+    for k, v in best_thresholds.items():
+        print(f"{class_names[k]:<{max_class_name}} (class {k}):\t{v:.2f}")
+
+
+def run():
+    init_trainer()
+    main()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/src/super_gradients/training/metrics/detection_metrics.py b/src/super_gradients/training/metrics/detection_metrics.py
index 309a40f1d9..c51d1b1c1d 100755
--- a/src/super_gradients/training/metrics/detection_metrics.py
+++ b/src/super_gradients/training/metrics/detection_metrics.py
@@ -31,6 +31,10 @@ class DetectionMetrics(Metric):
    :param dist_sync_on_step:   Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
    :param accumulate_on_cpu:   Run on CPU regardless of device used in other parts. This is to avoid "CUDA out of memory" that might happen on GPU.
+   :param calc_best_score_thresholds: Whether to calculate the best score threshold overall and per class.
+                                      If True, the compute() function will return a metrics dictionary that includes
+                                      not only the average metrics calculated across all classes,
+                                      but also the optimal score threshold overall and for each individual class.
    """

    def __init__(
@@ -44,6 +48,7 @@ def __init__(
        top_k_predictions: int = 100,
        dist_sync_on_step: bool = False,
        accumulate_on_cpu: bool = True,
+        calc_best_score_thresholds: bool = False,
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.num_cls = num_cls
@@ -65,6 +70,10 @@ def __init__(
            f"F1{self._get_range_str()}": True,
        }
        self.component_names = list(self.greater_component_is_better.keys())
+        self.calc_best_score_thresholds = calc_best_score_thresholds
+        if self.calc_best_score_thresholds:
+            self.component_names.append("Best_score_threshold")
+            self.component_names += [f"Best_score_threshold_cls_{i}" for i in range(self.num_cls)]
        self.components = len(self.component_names)

        self.post_prediction_callback = post_prediction_callback
@@ -119,18 +128,19 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
        """Compute the metrics for all the accumulated results.
        :return: Metrics of interest
        """
-        mean_ap, mean_precision, mean_recall, mean_f1 = -1.0, -1.0, -1.0, -1.0
+        mean_ap, mean_precision, mean_recall, mean_f1, best_score_threshold, best_score_threshold_per_cls = -1.0, -1.0, -1.0, -1.0, -1.0, None
        accumulated_matching_info = getattr(self, f"matching_info{self._get_range_str()}")

        if len(accumulated_matching_info):
            matching_info_tensors = [torch.cat(x, 0) for x in list(zip(*accumulated_matching_info))]

            # shape (n_class, nb_iou_thresh)
-            ap, precision, recall, f1, unique_classes = compute_detection_metrics(
+            ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls = compute_detection_metrics(
                *matching_info_tensors,
                recall_thresholds=self.recall_thresholds,
                score_threshold=self.score_threshold,
                device="cpu" if self.accumulate_on_cpu else self.device,
+                calc_best_score_thresholds=self.calc_best_score_thresholds,
            )

            # Precision, recall and f1 are computed for IoU threshold range, averaged over classes
@@ -140,12 +150,16 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
            # MaP is averaged over IoU thresholds and over classes
            mean_ap = ap.mean()

-        return {
+        output_dict = {
            f"Precision{self._get_range_str()}": mean_precision,
            f"Recall{self._get_range_str()}": mean_recall,
            f"mAP{self._get_range_str()}": mean_ap,
            f"F1{self._get_range_str()}": mean_f1,
        }
+        if self.calc_best_score_thresholds:
+            output_dict["Best_score_threshold"] = best_score_threshold
+            output_dict.update(best_score_threshold_per_cls)
+        return output_dict

    def _sync_dist(self, dist_sync_fn=None, process_group=None):
        """
diff --git a/src/super_gradients/training/metrics/pose_estimation_metrics.py b/src/super_gradients/training/metrics/pose_estimation_metrics.py
index 66860e213f..9675a0dd1d 100644
--- a/src/super_gradients/training/metrics/pose_estimation_metrics.py
+++ b/src/super_gradients/training/metrics/pose_estimation_metrics.py
@@ -285,7 +285,7 @@ def compute(self) -> Dict[str, Union[float, torch.Tensor]]:
        preds_scores = torch.cat([x[2].cpu() for x in predictions], dim=0)
        n_targets = sum([x[3] for x in predictions])

-        cls_precision, _, cls_recall = compute_detection_metrics_per_cls(
+        cls_precision, _, cls_recall, _, _ = compute_detection_metrics_per_cls(
            preds_matched=preds_matched,
            preds_to_ignore=preds_to_ignore,
            preds_scores=preds_scores,
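Usage sketch (not part of the diff): the new calc_best_score_thresholds flag can also be used when instantiating DetectionMetrics directly, instead of going through the recipe script above. In the snippet below, model, val_loader and nms_callback are placeholders for your own network, validation dataloader and post-prediction (NMS) callback, and the update() call assumes the metric's usual (preds, target, device, inputs) signature; adapt both to your setup and SuperGradients version.

import torch
from super_gradients.training.metrics.detection_metrics import DetectionMetrics

# Placeholders: bring your own trained model, validation dataloader and NMS callback.
metric = DetectionMetrics(
    num_cls=80,                             # number of classes in your dataset
    post_prediction_callback=nms_callback,  # the same NMS callback your validation already uses
    calc_best_score_thresholds=True,        # flag added in this diff
)

model.eval()
with torch.no_grad():
    for images, targets in val_loader:
        preds = model(images)
        metric.update(preds, targets, device="cpu", inputs=images)  # assumed signature, see note above

results = metric.compute()
print(results["Best_score_threshold"])  # best overall confidence threshold
per_class = {k: v for k, v in results.items() if k.startswith("Best_score_threshold_cls_")}
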
diff --git a/src/super_gradients/training/utils/detection_utils.py b/src/super_gradients/training/utils/detection_utils.py
index b77e1a8e73..ae9728ccbf 100755
--- a/src/super_gradients/training/utils/detection_utils.py
+++ b/src/super_gradients/training/utils/detection_utils.py
@@ -1088,6 +1088,7 @@ def compute_detection_metrics(
    device: str,
    recall_thresholds: Optional[torch.Tensor] = None,
    score_threshold: Optional[float] = 0.1,
+    calc_best_score_thresholds: bool = False,
 ) -> Tuple:
    """
    Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.
@@ -1103,10 +1104,16 @@ def compute_detection_metrics(
    :param score_threshold:     Minimum confidence score to consider a prediction for the computation of
                                precision, recall and f1 (not MaP)
    :param device:              Device
+    :param calc_best_score_thresholds: If True, the best confidence score threshold is computed for each class

    :return:
        :ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)
        :unique_classes:            Vector with all unique target classes
+        :best_score_threshold:      torch.float with the best overall score threshold if calc_best_score_thresholds
+                                    is True, else None
+        :best_score_threshold_per_cls: dict that stores the best score threshold for each class, if
+                                       calc_best_score_thresholds is True, else None
+
    """
    preds_matched, preds_to_ignore = preds_matched.to(device), preds_to_ignore.to(device)
    preds_scores, preds_cls, targets_cls = preds_scores.to(device), preds_cls.to(device), targets_cls.to(device)
@@ -1120,9 +1127,14 @@ def compute_detection_metrics(
    precision = torch.zeros((n_class, nb_iou_thrs), device=device)
    recall = torch.zeros((n_class, nb_iou_thrs), device=device)

+    nb_score_thrs = 101
+    all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
+    f1_per_class_per_threshold = torch.zeros((n_class, nb_score_thrs), device=device) if calc_best_score_thresholds else None
+    best_score_threshold_per_cls = dict() if calc_best_score_thresholds else None
+
    for cls_i, cls in enumerate(unique_classes):
        cls_preds_idx, cls_targets_idx = (preds_cls == cls), (targets_cls == cls)
-        cls_ap, cls_precision, cls_recall = compute_detection_metrics_per_cls(
+        cls_ap, cls_precision, cls_recall, cls_f1_per_threshold, cls_best_score_threshold = compute_detection_metrics_per_cls(
            preds_matched=preds_matched[cls_preds_idx],
            preds_to_ignore=preds_to_ignore[cls_preds_idx],
            preds_scores=preds_scores[cls_preds_idx],
@@ -1130,14 +1142,24 @@ def compute_detection_metrics(
            recall_thresholds=recall_thresholds,
            score_threshold=score_threshold,
            device=device,
+            calc_best_score_thresholds=calc_best_score_thresholds,
+            nb_score_thrs=nb_score_thrs,
        )
        ap[cls_i, :] = cls_ap
        precision[cls_i, :] = cls_precision
        recall[cls_i, :] = cls_recall
+        if calc_best_score_thresholds:
+            f1_per_class_per_threshold[cls_i, :] = cls_f1_per_threshold
+            best_score_threshold_per_cls[f"Best_score_threshold_cls_{int(cls)}"] = cls_best_score_threshold

    f1 = 2 * precision * recall / (precision + recall + 1e-16)
+    if calc_best_score_thresholds:
+        mean_f1_across_classes = torch.mean(f1_per_class_per_threshold, dim=0)
+        best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_across_classes)]
+    else:
+        best_score_threshold = None

-    return ap, precision, recall, f1, unique_classes
+    return ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls


 def compute_detection_metrics_per_cls(
@@ -1148,6 +1170,8 @@ def compute_detection_metrics_per_cls(
    recall_thresholds: torch.Tensor,
    score_threshold: float,
    device: str,
+    calc_best_score_thresholds: bool = False,
+    nb_score_thrs: int = 101,
 ):
    """
    Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.
@@ -1164,16 +1188,30 @@ def compute_detection_metrics_per_cls(
    :param score_threshold:     Minimum confidence score to consider a prediction for the computation of
                                precision and recall (not MaP)
    :param device:              Device
+    :param calc_best_score_thresholds: If True, the best confidence score threshold is computed for this class
+    :param nb_score_thrs:       Number of score thresholds to consider when calc_best_score_thresholds is True

-    :return ap, precision, recall: Tensors of shape (nb_iou_thrs)
+    :return:
+        :ap, precision, recall: Tensors of shape (nb_iou_thrs)
+        :mean_f1_per_threshold: Tensor of shape (nb_score_thresholds) if calc_best_score_thresholds is True, else None
+        :best_score_threshold:  torch.float if calc_best_score_thresholds is True, else None
    """
    nb_iou_thrs = preds_matched.shape[-1]

+    mean_f1_per_threshold = torch.zeros(nb_score_thrs, device=device) if calc_best_score_thresholds else None
+    best_score_threshold = torch.tensor(0.0, dtype=torch.float, device=device) if calc_best_score_thresholds else None
+
    tps = preds_matched
    fps = torch.logical_and(torch.logical_not(preds_matched), torch.logical_not(preds_to_ignore))

    if len(tps) == 0:
-        return torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device), torch.zeros(nb_iou_thrs, device=device)
+        return (
+            torch.zeros(nb_iou_thrs, device=device),
+            torch.zeros(nb_iou_thrs, device=device),
+            torch.zeros(nb_iou_thrs, device=device),
+            mean_f1_per_threshold,
+            best_score_threshold,
+        )

    # Sort by decreasing score
    dtype = torch.uint8 if preds_scores.is_cuda and preds_scores.dtype is torch.bool else preds_scores.dtype
@@ -1197,7 +1235,8 @@

    # We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
    # Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
-    lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=False)
+    # Note2: right=True is needed because of the negation
+    lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=True)
    if lowest_score_above_threshold == 0:  # Here score_threshold > preds_scores[0], so no pred is above the threshold
        recall = torch.zeros(nb_iou_thrs, device=device)
        precision = torch.zeros(nb_iou_thrs, device=device)
@@ -1206,6 +1245,28 @@
    else:
        recall = rolling_recalls[lowest_score_above_threshold - 1]
        precision = rolling_precisions[lowest_score_above_threshold - 1]

+    # ==================
+    # BEST CONFIDENCE SCORE THRESHOLD PER CLASS
+    if calc_best_score_thresholds:
+        all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
+
+        # We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
+        # Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
+        lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)
+
+        # When score_threshold > preds_scores[0], then no pred is above the threshold, so we pad with zeros
+        rolling_recalls_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_recalls), dim=0)
+        rolling_precisions_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_precisions), dim=0)
+
+        # shape = (n_score_thresholds, nb_iou_thrs)
+
recalls_per_threshold = torch.index_select(input=rolling_recalls_padded, dim=0, index=lowest_scores_above_thresholds) + precisions_per_threshold = torch.index_select(input=rolling_precisions_padded, dim=0, index=lowest_scores_above_thresholds) + + # shape (n_score_thresholds, nb_iou_thrs) + f1_per_threshold = 2 * recalls_per_threshold * precisions_per_threshold / (recalls_per_threshold + precisions_per_threshold + 1e-16) + mean_f1_per_threshold = torch.mean(f1_per_threshold, dim=1) # average over iou thresholds + best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_per_threshold)] + # ================== # AVERAGE PRECISION @@ -1226,4 +1287,4 @@ def compute_detection_metrics_per_cls( # Average over the recall_thresholds ap = sampled_precision_points.mean(0) - return ap, precision, recall + return ap, precision, recall, mean_f1_per_threshold, best_score_threshold
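For intuition, the threshold sweep added to compute_detection_metrics_per_cls can be reproduced in isolation. The sketch below uses made-up rolling precision/recall values for a single class with four targets and a single IoU threshold; it only illustrates the searchsorted-on-negated-scores lookup and the zero-padding used above, not the library's API, so the numbers and variable names are assumptions for the example.

import torch

nb_score_thrs = 101
# Toy case: 4 targets; the three highest-scored predictions are true positives, the rest are false positives.
# Scores are already sorted in decreasing order, as in compute_detection_metrics_per_cls.
preds_scores = torch.tensor([0.95, 0.90, 0.60, 0.30, 0.20, 0.10])
# Rolling recall/precision after each prediction, shape (n_preds, nb_iou_thrs) with nb_iou_thrs = 1 here.
rolling_recalls = torch.tensor([[0.25], [0.50], [0.75], [0.75], [0.75], [0.75]])
rolling_precisions = torch.tensor([[1.0], [1.0], [1.0], [0.75], [0.60], [0.50]])

all_score_thresholds = torch.linspace(0, 1, nb_score_thrs)
# preds_scores is decreasing, so we search on its negation; right=True keeps predictions with score >= threshold.
lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)

# Pad with a zero row so that thresholds above the highest score map to recall = precision = 0.
recalls_padded = torch.cat((torch.zeros(1, 1), rolling_recalls), dim=0)
precisions_padded = torch.cat((torch.zeros(1, 1), rolling_precisions), dim=0)

recalls_per_thr = recalls_padded[lowest_scores_above_thresholds]        # shape (nb_score_thrs, nb_iou_thrs)
precisions_per_thr = precisions_padded[lowest_scores_above_thresholds]

f1_per_thr = 2 * recalls_per_thr * precisions_per_thr / (recalls_per_thr + precisions_per_thr + 1e-16)
mean_f1_per_thr = f1_per_thr.mean(dim=1)                                # average over IoU thresholds
best_threshold = all_score_thresholds[torch.argmax(mean_f1_per_thr)]
print(f"Best confidence threshold for this class: {best_threshold:.2f}")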