Skip to content

Commit

Permalink
Refactor nvinfer processing (#563)
Browse files Browse the repository at this point in the history
  • Loading branch information
denisvmedyantsev authored Nov 21, 2023
1 parent 6f18975 commit 3e0efce
Show file tree
Hide file tree
Showing 7 changed files with 771 additions and 702 deletions.
548 changes: 23 additions & 525 deletions savant/deepstream/buffer_processor.py

Large diffs are not rendered by default.

615 changes: 615 additions & 0 deletions savant/deepstream/nvinfer/processor.py

Large diffs are not rendered by default.

41 changes: 20 additions & 21 deletions savant/deepstream/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import time
from collections import defaultdict
from pathlib import Path
from queue import Queue
from threading import Lock, Thread
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -42,6 +41,7 @@
nvds_attr_meta_output_converter,
nvds_obj_meta_output_converter,
)
from savant.deepstream.nvinfer.processor import NvInferProcessor
from savant.deepstream.source_output import create_source_output
from savant.deepstream.utils import (
GST_NVEVENT_STREAM_EOS,
Expand All @@ -59,7 +59,7 @@
)
from savant.gstreamer import GLib, Gst # noqa:F401
from savant.gstreamer.pipeline import GstPipeline
from savant.gstreamer.utils import on_pad_event, pad_to_source_id
from savant.gstreamer.utils import add_buffer_probe, on_pad_event, pad_to_source_id
from savant.meta.constants import PRIMARY_OBJECT_KEY, UNTRACKED_OBJECT_ID
from savant.utils.fps_meter import FPSMeter
from savant.utils.platform import is_aarch64
Expand Down Expand Up @@ -94,9 +94,6 @@ def __init__(

self._max_parallel_streams = kwargs.get('max_parallel_streams', 64)

# model artifacts path
self._model_path = Path(kwargs['model_path'])

self._source_adding_lock = Lock()
self._sources = SourceInfoRegistry()

Expand Down Expand Up @@ -185,26 +182,31 @@ def _build_buffer_processor(
def add_element(
self,
element: PipelineElement,
with_probes: bool = False,
link: bool = True,
element_idx: Optional[Union[int, Tuple[int, int]]] = None,
) -> Gst.Element:
if isinstance(element, ModelElement):
if element.model.input.preprocess_object_image:
self._objects_preprocessing.add_preprocessing_function(
element_name=element.name,
preprocessing_func=element.model.input.preprocess_object_image,
)
if isinstance(element.model, (AttributeModel, ComplexModel)):
for attr in element.model.output.attributes:
if attr.internal:
self._internal_attrs.add((element.name, attr.name))
gst_element = super().add_element(
element=element,
with_probes=with_probes,
link=link,
element_idx=element_idx,
)

if isinstance(element, ModelElement):
if isinstance(element.model, (AttributeModel, ComplexModel)):
for attr in element.model.output.attributes:
if attr.internal:
self._internal_attrs.add((element.name, attr.name))
nvinfer = NvInferProcessor(
element,
self._objects_preprocessing,
self._frame_params,
self._video_pipeline,
)
if nvinfer.preproc is not None:
add_buffer_probe(gst_element.get_static_pad('sink'), nvinfer.preproc)
if nvinfer.postproc is not None:
add_buffer_probe(gst_element.get_static_pad('src'), nvinfer.postproc)

if element_idx is not None:
if isinstance(element, PyFuncElement):
gst_element.set_property('pipeline', self._video_pipeline)
Expand Down Expand Up @@ -876,10 +878,7 @@ def _create_muxer(self, live_source: bool) -> Gst.Element:
self._video_pipeline,
'prepare-input',
)
muxer_src_pad.add_probe(
Gst.PadProbeType.BUFFER,
self._buffer_processor.input_probe,
)
add_buffer_probe(muxer_src_pad, self._buffer_processor.prepare_input)

return self._muxer

Expand Down
117 changes: 12 additions & 105 deletions savant/gstreamer/buffer_processor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
"""Buffer processor for GStreamer pipeline."""
import inspect
from abc import ABC, abstractmethod
from queue import Queue
from types import FrameType
from typing import Iterator

from gi.repository import Gst

from savant.config.schema import PipelineElement
from savant.gstreamer.utils import gst_post_stream_failed_error
from savant.utils.fps_meter import FPSMeter
from savant.utils.logging import get_logger
from savant.utils.sink_factories import SinkMessage
Expand All @@ -28,6 +24,10 @@ def __init__(self, queue: Queue, fps_meter: FPSMeter):
self._queue = queue
self._fps_meter = fps_meter

@property
def logger(self):
return self._logger

# Buffer handlers
@abstractmethod
def prepare_input(self, buffer: Gst.Buffer):
Expand All @@ -37,107 +37,14 @@ def prepare_input(self, buffer: Gst.Buffer):
def prepare_output(self, buffer: Gst.Buffer, user_data) -> Iterator[SinkMessage]:
"""Pipeline output processor."""

def process_output(self, buffer: Gst.Buffer, user_data):
"""Pipeline output processor wrapper."""
for sink_message in self.prepare_output(buffer, user_data):
self._queue.put(sink_message)
# measure and logging FPS
if self._fps_meter():
self.logger.info(self._fps_meter.message)

@abstractmethod
def on_eos(self, user_data):
"""Pipeline EOS handler."""

@abstractmethod
def prepare_element_input(self, element: PipelineElement, buffer: Gst.Buffer):
"""Element input processor."""

@abstractmethod
def prepare_element_output(self, element: PipelineElement, buffer: Gst.Buffer):
"""Element output processor."""

# Pad probes to attach buffer handlers
def input_probe( # pylint: disable=unused-argument
self,
pad: Gst.Pad,
info: Gst.PadProbeInfo,
):
"""Attach pipeline input processor to pad."""

buffer = info.get_buffer()
try:
self.prepare_input(buffer)
except Exception as exc: # pylint: disable=broad-except
self._report_error(
pad,
f'Failed to prepare input for buffer with PTS {buffer.pts}: {exc}.',
inspect.currentframe(),
)

return Gst.PadProbeReturn.OK

def output_probe( # pylint: disable=unused-argument
self,
pad: Gst.Pad,
info: Gst.PadProbeInfo,
user_data,
):
"""Attach pipeline output processor to pad."""

buffer = info.get_buffer()
try:
for sink_message in self.prepare_output(buffer, user_data):
self._queue.put(sink_message)
# measure and logging FPS
if self._fps_meter():
self._logger.info(self._fps_meter.message)
except Exception as exc: # pylint: disable=broad-except
self._report_error(
pad,
f'Failed to prepare output for buffer with PTS {buffer.pts}: {exc}.',
inspect.currentframe(),
)

return Gst.PadProbeReturn.OK

def element_input_probe( # pylint: disable=unused-argument
self,
pad: Gst.Pad,
info: Gst.PadProbeInfo,
element: PipelineElement,
):
"""Attach element input processor to pad."""

buffer = info.get_buffer()
try:
self.prepare_element_input(element, buffer)
except Exception as exc: # pylint: disable=broad-except
self._report_error(
pad,
f'Failed to prepare "{element.name}" element input for buffer with PTS {buffer.pts}: {exc}.',
inspect.currentframe(),
)

return Gst.PadProbeReturn.OK

def element_output_probe( # pylint: disable=unused-argument
self,
pad: Gst.Pad,
info: Gst.PadProbeInfo,
element: PipelineElement,
):
"""Attach element output processor to pad."""

buffer = info.get_buffer()
try:
self.prepare_element_output(element, buffer)
except Exception as exc: # pylint: disable=broad-except
self._report_error(
pad,
f'Failed to prepare "{element.name}" element output for buffer with PTS {buffer.pts}: {exc}.',
inspect.currentframe(),
)

return Gst.PadProbeReturn.OK

def _report_error(self, pad: Gst.Pad, error: str, frame: FrameType):
self._logger.exception(error)
gst_post_stream_failed_error(
gst_element=pad.get_parent_element(),
frame=frame,
file_path=__file__,
text=error,
)
53 changes: 49 additions & 4 deletions savant/gstreamer/element_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""GStreamer pipeline elements factory."""

from typing import Union

from gi.repository import Gst # noqa:F401

from savant.config.schema import PipelineElement
from savant.base.model import AttributeModel, ComplexModel, ObjectModel
from savant.config.schema import ModelElement, PipelineElement


class CreateElementException(Exception):
Expand All @@ -25,6 +28,9 @@ def create(self, element: PipelineElement) -> Gst.Element:
if element.element == 'videotestsrc':
return self.create_videotestsrc(element)

if isinstance(element, ModelElement):
return self.create_model_element(element)

if isinstance(element, PipelineElement):
return self.create_element(element)

Expand All @@ -36,20 +42,54 @@ def create(self, element: PipelineElement) -> Gst.Element:
def create_element(element: PipelineElement) -> Gst.Element:
"""Creates Gst.Element.
:param element: pipeline element to create.
:param element: PipelineElement to create.
:raises CreateElementException: Unable to create element.
:return: Created Gst.Element
"""
gst_element = Gst.ElementFactory.make(element.element, element.name)
if not gst_element:
raise CreateElementException(f'Unable to create element {element}.')

# set element name from GstElement
if element.name is None:
element.name = gst_element.name

for prop_name, prop_value in element.properties.items():
if prop_value is not None:
gst_element.set_property(prop_name, prop_value)

return gst_element

@staticmethod
def create_model_element(element: ModelElement) -> Gst.Element:
"""Creates Gst.Element for ModelElement.
:param element: ModelElement to create.
:return: Created Gst.Element
"""

model: Union[AttributeModel, ComplexModel, ObjectModel] = element.model

if model.input.preprocess_object_meta:
model.input.preprocess_object_meta.load_user_code()
if model.input.preprocess_object_image:
model.input.preprocess_object_image.load_user_code()
if model.output.converter:
model.output.converter.load_user_code()
if isinstance(model, (ObjectModel, ComplexModel)):
for obj in model.output.objects:
if obj.selector:
obj.selector.load_user_code()

return GstElementFactory.create_element(element)

@staticmethod
def create_caps_filter(element: PipelineElement) -> Gst.Element:
"""Creates ``capsfilter`` Gst.Element.
:param element: Element to create.
:return: Created Gst.Element
"""
caps = None
if 'caps' in element.properties and isinstance(element.properties['caps'], str):
caps = Gst.Caps.from_string(element.properties['caps'])
Expand All @@ -61,10 +101,15 @@ def create_caps_filter(element: PipelineElement) -> Gst.Element:

@staticmethod
def create_videotestsrc(element: PipelineElement) -> Gst.Bin:
"""videotestsrc element as Gst.Bin with `pad-added`."""
src_decodebin = Gst.Bin.new(element.name)
"""Creates ``videotestsrc`` element as a Gst.Bin with ``pad-added``.
:param element: Element to create.
:return: Created Gst.Element
"""
src_element = GstElementFactory.create_element(element)

src_decodebin = Gst.Bin.new(element.name)

Gst.Bin.add(src_decodebin, src_element)

decodebin = GstElementFactory.create_element(PipelineElement('decodebin'))
Expand Down
Loading

0 comments on commit 3e0efce

Please sign in to comment.