From 4431340b566ca16a7690014175984cfb2da1305d Mon Sep 17 00:00:00 2001 From: Oleg Abramov Date: Tue, 28 Nov 2023 17:58:20 +0600 Subject: [PATCH] Fix regular detector obj default conf filter threshold (#570) * Increase the all-classes default conf threshold val * Add warnings to detector output objs config * Remove heuristic for num_detected_classes * Clear up comment --- savant/deepstream/nvinfer/element_config.py | 68 ++++++++++++--------- savant/deepstream/nvinfer/file_config.py | 2 +- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/savant/deepstream/nvinfer/element_config.py b/savant/deepstream/nvinfer/element_config.py index d4ddb5b7..458b28c0 100644 --- a/savant/deepstream/nvinfer/element_config.py +++ b/savant/deepstream/nvinfer/element_config.py @@ -348,36 +348,48 @@ def process(self, msg, kwargs): # model type-specific parameters if issubclass(model_type, ObjectModel): - if not model_config.output.objects: - # try to load model object labels from file - label_file = model_config.get( - 'label_file', - nvinfer_config['property'].get('labelfile-path') - if nvinfer_config - else None, - ) + # model_config.output.objects is mandatory for object models + # but it may be autogenerated based on labelfile or num_detected_classes + + label_file = model_config.get( + 'label_file', + nvinfer_config['property'].get('labelfile-path') + if nvinfer_config + else None, + ) + if model_config.output.objects: + # highest priority is using manually defined model_config.output.objects if label_file: - label_file_path = model_path / label_file - if label_file_path.is_file(): - with open(label_file_path.resolve(), encoding='utf8') as file_obj: - model_config.output.objects = [ - NvInferObjectModelOutputObject( - class_id=class_id, label=label - ) - for class_id, label in enumerate( - file_obj.read().splitlines() - ) - ] - model_config.output.num_detected_classes = len( - model_config.output.objects - ) - logger.info( - 'Model object labels have been loaded from "%s".', - label_file_path, - ) + logger.warning( + 'Model output objects labels are defined manually ' + 'and will be used instead of labels from "%s".', + label_file, + ) - # generate labels (enumerate) - if not model_config.output.objects and model_config.output.num_detected_classes: + elif label_file: + # try to load model object labels from file + label_file_path = model_path / label_file + if label_file_path.is_file(): + with open(label_file_path.resolve(), encoding='utf8') as file_obj: + model_config.output.objects = [ + NvInferObjectModelOutputObject(class_id=class_id, label=label) + for class_id, label in enumerate(file_obj.read().splitlines()) + ] + if model_config.output.num_detected_classes: + logger.warning( + 'Ignoring manually set value for ' + '(model_config.output.num_detected_classes) ' + 'because labelfile is used.' + ) + model_config.output.num_detected_classes = len( + model_config.output.objects + ) + logger.info( + 'Model object labels have been loaded from "%s".', + label_file_path, + ) + elif model_config.output.num_detected_classes: + # generate labels (enumerate) model_config.output.objects = [ NvInferObjectModelOutputObject(class_id=class_id, label=str(class_id)) for class_id in range(model_config.output.num_detected_classes) diff --git a/savant/deepstream/nvinfer/file_config.py b/savant/deepstream/nvinfer/file_config.py index 9a4678e5..ec6da24a 100644 --- a/savant/deepstream/nvinfer/file_config.py +++ b/savant/deepstream/nvinfer/file_config.py @@ -272,7 +272,7 @@ def from_model( # setup class-attrs for object model (detector) # set a high confidence threshold initially for all classes # to filter out only the desired classes - config['class-attrs-all'] = {'pre-cluster-threshold': 1.1} + config['class-attrs-all'] = {'pre-cluster-threshold': 1e10} # replace class-attrs parameters with selector kwargs for obj in model_config.output.objects: class_attrs = {}