Skip to content

Commit

Permalink
Fix regular detector obj default conf filter threshold (#570)
Browse files Browse the repository at this point in the history
* Increase the all-classes default conf threshold val

* Add warnings to detector output objs config

* Remove heuristic for num_detected_classes

* Clear up comment
  • Loading branch information
abramov-oleg authored Nov 28, 2023
1 parent 13bb632 commit 4431340
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
68 changes: 40 additions & 28 deletions savant/deepstream/nvinfer/element_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,36 +348,48 @@ def process(self, msg, kwargs):

# model type-specific parameters
if issubclass(model_type, ObjectModel):
if not model_config.output.objects:
# try to load model object labels from file
label_file = model_config.get(
'label_file',
nvinfer_config['property'].get('labelfile-path')
if nvinfer_config
else None,
)
# model_config.output.objects is mandatory for object models
# but it may be autogenerated based on labelfile or num_detected_classes

label_file = model_config.get(
'label_file',
nvinfer_config['property'].get('labelfile-path')
if nvinfer_config
else None,
)
if model_config.output.objects:
# highest priority is using manually defined model_config.output.objects
if label_file:
label_file_path = model_path / label_file
if label_file_path.is_file():
with open(label_file_path.resolve(), encoding='utf8') as file_obj:
model_config.output.objects = [
NvInferObjectModelOutputObject(
class_id=class_id, label=label
)
for class_id, label in enumerate(
file_obj.read().splitlines()
)
]
model_config.output.num_detected_classes = len(
model_config.output.objects
)
logger.info(
'Model object labels have been loaded from "%s".',
label_file_path,
)
logger.warning(
'Model output objects labels are defined manually '
'and will be used instead of labels from "%s".',
label_file,
)

# generate labels (enumerate)
if not model_config.output.objects and model_config.output.num_detected_classes:
elif label_file:
# try to load model object labels from file
label_file_path = model_path / label_file
if label_file_path.is_file():
with open(label_file_path.resolve(), encoding='utf8') as file_obj:
model_config.output.objects = [
NvInferObjectModelOutputObject(class_id=class_id, label=label)
for class_id, label in enumerate(file_obj.read().splitlines())
]
if model_config.output.num_detected_classes:
logger.warning(
'Ignoring manually set value for '
'(model_config.output.num_detected_classes) '
'because labelfile is used.'
)
model_config.output.num_detected_classes = len(
model_config.output.objects
)
logger.info(
'Model object labels have been loaded from "%s".',
label_file_path,
)
elif model_config.output.num_detected_classes:
# generate labels (enumerate)
model_config.output.objects = [
NvInferObjectModelOutputObject(class_id=class_id, label=str(class_id))
for class_id in range(model_config.output.num_detected_classes)
Expand Down
2 changes: 1 addition & 1 deletion savant/deepstream/nvinfer/file_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def from_model(
# setup class-attrs for object model (detector)
# set a high confidence threshold initially for all classes
# to filter out only the desired classes
config['class-attrs-all'] = {'pre-cluster-threshold': 1.1}
config['class-attrs-all'] = {'pre-cluster-threshold': 1e10}
# replace class-attrs parameters with selector kwargs
for obj in model_config.output.objects:
class_attrs = {}
Expand Down

0 comments on commit 4431340

Please sign in to comment.