Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regular detector obj default conf filter threshold #570

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions savant/deepstream/nvinfer/element_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,36 +348,48 @@ def process(self, msg, kwargs):

# model type-specific parameters
if issubclass(model_type, ObjectModel):
if not model_config.output.objects:
# try to load model object labels from file
label_file = model_config.get(
'label_file',
nvinfer_config['property'].get('labelfile-path')
if nvinfer_config
else None,
)
# model_config.output.objects is mandatory for object models
# but it may be autogenerated based on labelfile or num_detected_classes

label_file = model_config.get(
'label_file',
nvinfer_config['property'].get('labelfile-path')
if nvinfer_config
else None,
)
if model_config.output.objects:
# highest priority is using manually defined model_config.output.objects
if label_file:
label_file_path = model_path / label_file
if label_file_path.is_file():
with open(label_file_path.resolve(), encoding='utf8') as file_obj:
model_config.output.objects = [
NvInferObjectModelOutputObject(
class_id=class_id, label=label
)
for class_id, label in enumerate(
file_obj.read().splitlines()
)
]
model_config.output.num_detected_classes = len(
model_config.output.objects
)
logger.info(
'Model object labels have been loaded from "%s".',
label_file_path,
)
logger.warning(
'Model output objects labels are defined manually '
'and will be used instead of labels from "%s".',
label_file,
)

# generate labels (enumerate)
if not model_config.output.objects and model_config.output.num_detected_classes:
elif label_file:
# try to load model object labels from file
label_file_path = model_path / label_file
if label_file_path.is_file():
with open(label_file_path.resolve(), encoding='utf8') as file_obj:
model_config.output.objects = [
NvInferObjectModelOutputObject(class_id=class_id, label=label)
for class_id, label in enumerate(file_obj.read().splitlines())
]
if model_config.output.num_detected_classes:
abramov-oleg marked this conversation as resolved.
Show resolved Hide resolved
logger.warning(
'Ignoring manually set value for '
'(model_config.output.num_detected_classes) '
'because labelfile is used.'
)
model_config.output.num_detected_classes = len(
model_config.output.objects
)
logger.info(
'Model object labels have been loaded from "%s".',
label_file_path,
)
elif model_config.output.num_detected_classes:
# generate labels (enumerate)
model_config.output.objects = [
NvInferObjectModelOutputObject(class_id=class_id, label=str(class_id))
for class_id in range(model_config.output.num_detected_classes)
Expand Down
2 changes: 1 addition & 1 deletion savant/deepstream/nvinfer/file_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def from_model(
# setup class-attrs for object model (detector)
# set a high confidence threshold initially for all classes
# to filter out only the desired classes
config['class-attrs-all'] = {'pre-cluster-threshold': 1.1}
config['class-attrs-all'] = {'pre-cluster-threshold': 1e10}
# replace class-attrs parameters with selector kwargs
for obj in model_config.output.objects:
class_attrs = {}
Expand Down