From ddd3ddad7ea08392451564725628b076ff9a6431 Mon Sep 17 00:00:00 2001 From: Pavel Tomskikh Date: Sat, 25 Nov 2023 01:06:53 +0700 Subject: [PATCH] Extend PyFunc API to allow adding custom metrics to exporter (#574) * #567 extend PyFunc with custom metrics * #567 add an example of PyFunc with custom metrics * #567 add docs for PyFunc metrics * #567 move MetricsRegistry to a separate module --- docs/source/reference/api/index.rst | 1 + docs/source/reference/api/metrics.rst | 12 ++ gst_plugins/python/pyfunc.py | 12 ++ requirements/savant-rs.txt | 2 +- samples/pass_through_processing/README.md | 3 + .../docker-compose.x86.yml | 2 + samples/pass_through_processing/module.yml | 6 + .../py_func_metrics_example.py | 65 +++++++ savant/deepstream/pipeline.py | 2 + savant/deepstream/pyfunc.py | 26 +++ savant/metrics/__init__.py | 8 + savant/metrics/base.py | 7 + savant/metrics/metric.py | 134 +++++++++++++ savant/metrics/prometheus.py | 181 ++++++++++-------- savant/metrics/registry.py | 38 ++++ 15 files changed, 416 insertions(+), 83 deletions(-) create mode 100644 docs/source/reference/api/metrics.rst create mode 100644 samples/pass_through_processing/py_func_metrics_example.py create mode 100644 savant/metrics/metric.py create mode 100644 savant/metrics/registry.py diff --git a/docs/source/reference/api/index.rst b/docs/source/reference/api/index.rst index 464ca79e..e2488f0b 100644 --- a/docs/source/reference/api/index.rst +++ b/docs/source/reference/api/index.rst @@ -17,4 +17,5 @@ API Reference param_storage utils libs + metrics client diff --git a/docs/source/reference/api/metrics.rst b/docs/source/reference/api/metrics.rst new file mode 100644 index 00000000..e16fd2d7 --- /dev/null +++ b/docs/source/reference/api/metrics.rst @@ -0,0 +1,12 @@ +Metrics +======= + +.. currentmodule:: savant.metrics + +.. 
autosummary:: + :toctree: generated + :nosignatures: + :template: autosummary/class.rst + + Counter + Gauge diff --git a/gst_plugins/python/pyfunc.py b/gst_plugins/python/pyfunc.py index cce8d901..a1e50e7f 100644 --- a/gst_plugins/python/pyfunc.py +++ b/gst_plugins/python/pyfunc.py @@ -15,6 +15,7 @@ gst_post_stream_failed_error, gst_post_stream_failed_warning, ) +from savant.metrics.base import BaseMetricsExporter from savant.utils.logging import LoggerMixin # RGBA format is required to access the frame (pyds.get_nvds_buf_surface) @@ -75,6 +76,12 @@ class GstPluginPyFunc(LoggerMixin, GstBase.BaseTransform): 'VideoPipeline object from savant-rs.', GObject.ParamFlags.READWRITE, ), + 'metrics-exporter': ( + object, + 'Metrics exporter.', + 'Metrics exporter.', + GObject.ParamFlags.READWRITE, + ), 'stream-pool-size': ( int, 'Max stream pool size', @@ -103,6 +110,7 @@ def __init__(self): self.class_name: Optional[str] = None self.kwargs: Optional[str] = None self.video_pipeline: Optional[VideoPipeline] = None + self.metrics_exporter: Optional[BaseMetricsExporter] = None self.dev_mode: bool = False self.max_stream_pool_size: int = 1 # pyfunc object @@ -121,6 +129,8 @@ def do_get_property(self, prop: GObject.GParamSpec) -> Any: return self.kwargs if prop.name == 'pipeline': return self.video_pipeline + if prop.name == 'metrics-exporter': + return self.metrics_exporter if prop.name == 'stream-pool-size': return self.max_stream_pool_size if prop.name == 'dev-mode': @@ -141,6 +151,8 @@ def do_set_property(self, prop: GObject.GParamSpec, value: Any): self.kwargs = value elif prop.name == 'pipeline': self.video_pipeline = value + elif prop.name == 'metrics-exporter': + self.metrics_exporter = value elif prop.name == 'stream-pool-size': self.max_stream_pool_size = value elif prop.name == 'dev-mode': diff --git a/requirements/savant-rs.txt b/requirements/savant-rs.txt index 95a3b021..18ce0a19 100644 --- a/requirements/savant-rs.txt +++ b/requirements/savant-rs.txt @@ -1 +1 @@ 
-savant-rs==0.1.83 +savant-rs==0.1.84 diff --git a/samples/pass_through_processing/README.md b/samples/pass_through_processing/README.md index 816bbd4c..d33a205d 100644 --- a/samples/pass_through_processing/README.md +++ b/samples/pass_through_processing/README.md @@ -50,6 +50,9 @@ docker compose -f samples/pass_through_processing/docker-compose.l4t.yml up # for pre-configured Grafana dashboard visit # http://127.0.0.1:3000/d/a4c1f484-75c9-4375-a04d-ab5d50578239/module-performance-metrics?orgId=1&refresh=5s +# for the tracker metrics visit +# http://127.0.0.1:8000/metrics + # Ctrl+C to stop running the compose bundle ``` diff --git a/samples/pass_through_processing/docker-compose.x86.yml b/samples/pass_through_processing/docker-compose.x86.yml index a35f5fa1..59351010 100644 --- a/samples/pass_through_processing/docker-compose.x86.yml +++ b/samples/pass_through_processing/docker-compose.x86.yml @@ -50,6 +50,8 @@ services: module-tracker: image: ghcr.io/insight-platform/savant-deepstream:latest restart: unless-stopped + ports: + - "8000:8000" volumes: - zmq_sockets:/tmp/zmq-sockets - ../../models/peoplenet_detector:/models diff --git a/samples/pass_through_processing/module.yml b/samples/pass_through_processing/module.yml index 9d11da12..688888ef 100644 --- a/samples/pass_through_processing/module.yml +++ b/samples/pass_through_processing/module.yml @@ -83,5 +83,11 @@ pipeline: tracker-width: 960 # 640 # must be a multiple of 32 tracker-height: 544 # 384 display-tracking-id: 0 + # pyfunc metrics example + - element: pyfunc + # specify the pyfunc's python module + module: samples.pass_through_processing.py_func_metrics_example + # specify the pyfunc's python class from the module + class_name: PyFuncMetricsExample # sink definition is skipped, zeromq sink is used by default to connect with sink adapters diff --git a/samples/pass_through_processing/py_func_metrics_example.py b/samples/pass_through_processing/py_func_metrics_example.py new file mode 100644 index 
00000000..763aca37 --- /dev/null +++ b/samples/pass_through_processing/py_func_metrics_example.py @@ -0,0 +1,65 @@ +"""Example of how to use metrics in PyFunc.""" +from savant.deepstream.meta.frame import NvDsFrameMeta +from savant.deepstream.pyfunc import NvDsPyFuncPlugin +from savant.gstreamer import Gst +from savant.metrics import Counter, Gauge + + +class PyFuncMetricsExample(NvDsPyFuncPlugin): + """Example of how to use metrics in PyFunc. + + Metrics values example: + + .. code-block:: text + + # HELP frames_per_source_total Number of processed frames per source + # TYPE frames_per_source_total counter + frames_per_source_total{module_stage="tracker",source_id="city-traffic"} 748.0 1700803467794 + # HELP total_queue_length The total queue length for the pipeline + # TYPE total_queue_length gauge + total_queue_length{module_stage="tracker",source_id="city-traffic"} 36.0 1700803467794 + + Note: the "module_stage" label is configured in docker-compose file and added to all metrics. + """ + + # Called when the new source is added + def on_source_add(self, source_id: str): + # Check if the metric is not registered yet + if 'frames_per_source' not in self.metrics: + # Register the counter metric + self.metrics['frames_per_source'] = Counter( + name='frames_per_source', + description='Number of processed frames per source', + # Labels are optional, by default there are no labels + labelnames=('source_id',), + ) + self.logger.info('Registered metric: %s', 'frames_per_source') + if 'total_queue_length' not in self.metrics: + # Register the gauge metric + self.metrics['total_queue_length'] = Gauge( + name='total_queue_length', + description='The total queue length for the pipeline', + # There are no labels for this metric + ) + self.logger.info('Registered metric: %s', 'total_queue_length') + + def process_frame(self, buffer: Gst.Buffer, frame_meta: NvDsFrameMeta): + # Count the frame for this source + self.metrics['frames_per_source'].inc( + # 1, # Default increment 
value + # Labels should be a tuple and must match the labelnames + labels=(frame_meta.source_id,), + ) + try: + last_runtime_metric = self.get_runtime_metrics(1)[0] + queue_length = sum( + stage.queue_length for stage in last_runtime_metric.stage_stats + ) + except IndexError: + queue_length = 0 + + # Set the total queue length for this source + self.metrics['total_queue_length'].set( + queue_length, # The new gauge value + # There are no labels for this metric + ) diff --git a/savant/deepstream/pipeline.py b/savant/deepstream/pipeline.py index c7cc955a..3bde1909 100644 --- a/savant/deepstream/pipeline.py +++ b/savant/deepstream/pipeline.py @@ -210,6 +210,8 @@ def add_element( if isinstance(element, PyFuncElement): gst_element.set_property('pipeline', self._video_pipeline) gst_element.set_property('stream-pool-size', self._batch_size) + if self._metrics_exporter is not None: + gst_element.set_property('metrics-exporter', self._metrics_exporter) # TODO: add stage names to element config? 
if isinstance(element_idx, int): stage = self._element_stages[element_idx] diff --git a/savant/deepstream/pyfunc.py b/savant/deepstream/pyfunc.py index 9292a26e..808867fd 100644 --- a/savant/deepstream/pyfunc.py +++ b/savant/deepstream/pyfunc.py @@ -21,6 +21,8 @@ nvds_frame_meta_iterator, ) from savant.gstreamer import Gst # noqa: F401 +from savant.metrics.base import BaseMetricsExporter +from savant.metrics.registry import MetricsRegistry from savant.utils.source_info import SourceInfoRegistry @@ -35,6 +37,8 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self._sources = SourceInfoRegistry() self._video_pipeline: Optional[VideoPipeline] = None + self._metrics_exporter: Optional[BaseMetricsExporter] = None + self._metrics_registry: Optional[MetricsRegistry] = None self._last_nvevent_seqnum: Dict[int, Dict[int, int]] = { event_type: {} for event_type in [ @@ -49,6 +53,8 @@ def __init__(self, **kwargs): def on_start(self) -> bool: """Do on plugin start.""" self._video_pipeline = self.gst_element.get_property('pipeline') + self._metrics_exporter = self.gst_element.get_property('metrics-exporter') + self._metrics_registry = MetricsRegistry(self._metrics_exporter) # the prop is set to pipeline batch size during init self._stream_pool_size = self.gst_element.get_property('stream-pool-size') return True @@ -183,3 +189,23 @@ def get_runtime_metrics(self, n: int): """Get last runtime metrics.""" return self._video_pipeline.get_stat_records(n) + + @property + def metrics(self) -> MetricsRegistry: + """Get metrics registry. + + Usage example: + + .. code-block:: python + + from savant.metrics import Counter + self.metrics['frames_per_source'] = Counter( + name='frames_per_source', + description='Number of processed frames per source', + labelnames=('source_id',), + ) + ... 
+ self.metrics['frames_per_source'].inc(labels=('camera-1',)) + """ + + return self._metrics_registry diff --git a/savant/metrics/__init__.py b/savant/metrics/__init__.py index d9c378d8..f011cd91 100644 --- a/savant/metrics/__init__.py +++ b/savant/metrics/__init__.py @@ -4,6 +4,7 @@ from savant.config.schema import MetricsParameters from savant.metrics.base import BaseMetricsExporter +from savant.metrics.metric import Counter, Gauge from savant.metrics.prometheus import PrometheusMetricsExporter @@ -20,3 +21,10 @@ def build_metrics_exporter( return PrometheusMetricsExporter(pipeline, params.provider_params) raise ValueError(f'Unknown metrics provider: {params.provider}') + + +__all__ = [ + 'Counter', + 'Gauge', + 'build_metrics_exporter', +] diff --git a/savant/metrics/base.py b/savant/metrics/base.py index 558bbfe6..186fc0c3 100644 --- a/savant/metrics/base.py +++ b/savant/metrics/base.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +from savant.metrics.metric import Metric + class BaseMetricsExporter(ABC): """Base class for metrics exporters.""" @@ -11,3 +13,8 @@ def start(self): @abstractmethod def stop(self): """Stop metrics exporter.""" + + @abstractmethod + def register_metric(self, metric: Metric): + """Register metric.""" + pass diff --git a/savant/metrics/metric.py b/savant/metrics/metric.py new file mode 100644 index 00000000..988a0e12 --- /dev/null +++ b/savant/metrics/metric.py @@ -0,0 +1,134 @@ +import time +from typing import Dict, Optional, Tuple + + +class Metric: + """Base class for metrics. + + :param name: Metric name. + :param description: Metric description. + :param labelnames: Metric label names. + """ + + def __init__( + self, + name: str, + description: str = '', + labelnames: Tuple[str, ...] 
= (), + ): + self._name = name + self._description = description or name + self._labelnames = labelnames + self._values: Dict[Tuple[str, ...], Tuple[float, float]] = {} + + @property + def name(self) -> str: + """Metric name.""" + return self._name + + @property + def description(self) -> str: + """Metric description.""" + return self._description + + @property + def labelnames(self) -> Tuple[str, ...]: + """Metric label names.""" + return self._labelnames + + @property + def values(self) -> Dict[Tuple[str, ...], Tuple[float, float]]: + """Metric values. + + :return: Dictionary: labels -> (value, timestamp). + """ + return self._values + + +class Counter(Metric): + """Counter metric. + + Usage example: + + .. code-block:: python + + counter = Counter( + name='frames_per_source', + description='Number of processed frames per source', + labelnames=('source_id',), + ) + counter.inc(labels=('camera-1',)) + """ + + def inc( + self, + amount=1, + labels: Tuple[str, ...] = (), + timestamp: Optional[float] = None, + ): + """Increment counter by amount. + + :param amount: Increment amount. + :param labels: Labels values. + :param timestamp: Metric timestamp. + """ + + assert len(labels) == len(self._labelnames), 'Labels must match label names' + assert amount > 0, 'Counter increment amount must be positive' + last_value = self._values.get(labels, (0, 0))[0] + if timestamp is None: + timestamp = time.time() + self._values[labels] = last_value + amount, timestamp + + def set( + self, + value, + labels: Tuple[str, ...] = (), + timestamp: Optional[float] = None, + ): + """Set counter to specific value. + + :param value: Counter value. Must be non-decreasing. + :param labels: Labels values. + :param timestamp: Metric timestamp. 
+ """ + + assert len(labels) == len(self._labelnames), 'Labels must match label names' + last_value = self._values.get(labels, (0, 0))[0] + assert value >= last_value, 'Counter value must be non-decreasing' + if timestamp is None: + timestamp = time.time() + self._values[labels] = value, timestamp + + +class Gauge(Metric): + """Gauge metric. + + Usage example: + + .. code-block:: python + + gauge = Gauge( + name='total_queue_length', + description='The total queue length for the pipeline', + ) + gauge.set(123) + """ + + def set( + self, + value, + labels: Tuple[str, ...] = (), + timestamp: Optional[float] = None, + ): + """Set gauge to specific value. + + :param value: Gauge value. + :param labels: Labels values. + :param timestamp: Metric timestamp. + """ + + assert len(labels) == len(self._labelnames), 'Labels must match label names' + if timestamp is None: + timestamp = time.time() + self._values[labels] = value, timestamp diff --git a/savant/metrics/prometheus.py b/savant/metrics/prometheus.py index dc09e166..15eab54d 100644 --- a/savant/metrics/prometheus.py +++ b/savant/metrics/prometheus.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List -from prometheus_client import start_http_server +from prometheus_client import CollectorRegistry, start_http_server from prometheus_client.metrics_core import CounterMetricFamily, GaugeMetricFamily -from prometheus_client.registry import REGISTRY, Collector +from prometheus_client.registry import Collector from savant_rs.pipeline2 import ( FrameProcessingStatRecord, FrameProcessingStatRecordType, @@ -10,6 +10,7 @@ ) from savant.metrics.base import BaseMetricsExporter +from savant.metrics.metric import Counter, Gauge, Metric from savant.utils.logging import get_logger logger = get_logger(__name__) @@ -25,20 +26,25 @@ class PrometheusMetricsExporter(BaseMetricsExporter): def __init__(self, pipeline: VideoPipeline, params: Dict[str, Any]): self._port = params['port'] labels = 
params.get('labels') or {} + self._registry = CollectorRegistry() self._metrics_collector = ModuleMetricsCollector(pipeline, labels) def start(self): logger.debug('Starting Prometheus metrics exporter on port %s', self._port) - start_http_server(self._port) + start_http_server(self._port, registry=self._registry) logger.debug('Registering metrics collector') - REGISTRY.register(self._metrics_collector) + self._registry.register(self._metrics_collector) logger.info('Started Prometheus metrics exporter on port %s', self._port) def stop(self): logger.debug('Unregistering metrics collector') - REGISTRY.unregister(self._metrics_collector) + self._registry.unregister(self._metrics_collector) # TODO: stop the server + def register_metric(self, metric: Metric): + logger.debug('Registering metric %s of type %s', metric.name, type(metric)) + self._metrics_collector.register_metric(metric) + class ModuleMetricsCollector(Collector): """Collector for module metrics with timestamps.""" @@ -52,17 +58,43 @@ def __init__( self._pipeline = pipeline extra_labels = sorted(extra_labels.items()) - extra_label_names = [name for name, _ in extra_labels] - self._label_names = ['record_type'] + extra_label_names - self._stage_label_names = ['record_type', 'stage_name'] + extra_label_names + self._extra_label_names = tuple(name for name, _ in extra_labels) self._extra_label_values = tuple(value for _, value in extra_labels) - self._frame_counter: Dict[Tuple[str, ...], Tuple[int, float]] = {} - self._object_counter: Dict[Tuple[str, ...], Tuple[int, float]] = {} - self._stage_queue_length: Dict[Tuple[str, ...], Tuple[int, float]] = {} - self._stage_frame_counter: Dict[Tuple[str, ...], Tuple[int, float]] = {} - self._stage_object_counter: Dict[Tuple[str, ...], Tuple[int, float]] = {} - self._stage_batch_counter: Dict[Tuple[str, ...], Tuple[int, float]] = {} + label_names = ('record_type',) + stage_label_names = ('record_type', 'stage_name') + self._metrics = { + 'frame_counter': Counter( + 
'frame_counter', + 'Number of frames passed through the module', + label_names, + ), + 'object_counter': Counter( + 'object_counter', + 'Number of objects passed through the module', + label_names, + ), + 'stage_queue_length': Gauge( + 'stage_queue_length', + 'Queue length in the stage', + stage_label_names, + ), + 'stage_frame_counter': Counter( + 'stage_frame_counter', + 'Number of frames passed through the stage', + stage_label_names, + ), + 'stage_object_counter': Counter( + 'stage_object_counter', + 'Number of objects passed through the stage', + stage_label_names, + ), + 'stage_batch_counter': Counter( + 'stage_batch_counter', + 'Number of frame batches passed through the stage', + stage_label_names, + ), + } def update_metrics(self, record: FrameProcessingStatRecord): """Update metrics values.""" @@ -76,14 +108,22 @@ def update_metrics(self, record: FrameProcessingStatRecord): record_type_str = _record_type_to_string(record.record_type) ts = record.ts / 1000 labels = (record_type_str,) - self._frame_counter[labels] = record.frame_no, ts - self._object_counter[labels] = record.object_counter, ts + self._metrics['frame_counter'].set(record.frame_no, labels, ts) + self._metrics['object_counter'].set(record.object_counter, labels, ts) for stage in record.stage_stats: stage_labels = record_type_str, stage.stage_name - self._stage_queue_length[stage_labels] = (stage.queue_length, ts) - self._stage_frame_counter[stage_labels] = (stage.frame_counter, ts) - self._stage_object_counter[stage_labels] = (stage.object_counter, ts) - self._stage_batch_counter[stage_labels] = (stage.batch_counter, ts) + self._metrics['stage_queue_length'].set( + stage.queue_length, stage_labels, ts + ) + self._metrics['stage_frame_counter'].set( + stage.frame_counter, stage_labels, ts + ) + self._metrics['stage_object_counter'].set( + stage.object_counter, stage_labels, ts + ) + self._metrics['stage_batch_counter'].set( + stage.batch_counter, stage_labels, ts + ) def collect(self): """Build 
and collect all metrics.""" @@ -97,65 +137,35 @@ def build_all_metrics(self): """Build all metrics for Promethus to collect.""" logger.trace('Building metrics') - yield self.build_metric( - 'frame_counter', - 'Number of frames passed through the module', - self._label_names, - self._frame_counter, - CounterMetricFamily, - ) - yield self.build_metric( - 'object_counter', - 'Number of objects passed through the module', - self._label_names, - self._object_counter, - CounterMetricFamily, - ) - yield self.build_metric( - 'stage_queue_length', - 'Queue length in the stage', - self._stage_label_names, - self._stage_queue_length, - GaugeMetricFamily, - ) - yield self.build_metric( - 'stage_frame_counter', - 'Number of frames passed through the stage', - self._stage_label_names, - self._stage_frame_counter, - CounterMetricFamily, - ) - yield self.build_metric( - 'stage_object_counter', - 'Number of objects passed through the stage', - self._stage_label_names, - self._stage_object_counter, - CounterMetricFamily, - ) - yield self.build_metric( - 'stage_batch_counter', - 'Number of frame batches passed through the stage', - self._stage_label_names, - self._stage_batch_counter, - CounterMetricFamily, - ) + for metric in self._metrics.values(): + yield self.build_metric(metric) - def build_metric( - self, - name: str, - documentation: str, - label_names: List[str], - values: Dict[Tuple[str, ...], Tuple[int, float]], - metric_class, - ): + def build_metric(self, metric: Metric): """Build metric for Prometheus to collect.""" - logger.trace('Building metric %s', name) - counter = metric_class(name, documentation, labels=label_names) - for labels, (value, ts) in values.items(): - counter.add_metric(labels + self._extra_label_values, value, timestamp=ts) + logger.trace('Building metric %s', metric.name) + if isinstance(metric, Counter): + metric_class = CounterMetricFamily + elif isinstance(metric, Gauge): + metric_class = GaugeMetricFamily + else: + raise ValueError( + f'Failed to 
build metric {metric.name}: unsupported metric type {type(metric)}' + ) + prom_metric = metric_class( + name=metric.name, + documentation=metric.description, + labels=metric.labelnames + self._extra_label_names, + ) + for labels, (value, ts) in metric.values.items(): + logger.trace('Building metric %s for labels %s', metric.name, labels) + prom_metric.add_metric( + labels=labels + self._extra_label_values, + value=value, + timestamp=ts, + ) - return counter + return prom_metric def get_last_records(self) -> List[FrameProcessingStatRecord]: """Get last metrics records from the pipeline. @@ -164,13 +174,10 @@ def get_last_records(self) -> List[FrameProcessingStatRecord]: timestamp-based records. """ - last_record_id = self._last_record_id frame_based_record = None timestamp_based_record = None - for record in self._pipeline.get_stat_records(100): # TODO: use last_record_id - if record.id <= self._last_record_id: - continue + for record in self._pipeline.get_stat_records_newer_than(self._last_record_id): if record.record_type == FrameProcessingStatRecordType.Frame: if frame_based_record is None or frame_based_record.id < record.id: frame_based_record = record @@ -180,8 +187,7 @@ def get_last_records(self) -> List[FrameProcessingStatRecord]: or timestamp_based_record.id < record.id ): timestamp_based_record = record - last_record_id = max(last_record_id, record.id) - self._last_record_id = last_record_id + self._last_record_id = max(self._last_record_id, record.id) records = [] if frame_based_record is not None: records.append(frame_based_record) @@ -190,6 +196,17 @@ def get_last_records(self) -> List[FrameProcessingStatRecord]: return records + def register_metric(self, metric: Metric): + if not isinstance(metric, (Counter, Gauge)): + raise ValueError( + f'Failed to register metric {metric.name}: unsupported metric type {type(metric)}' + ) + if metric.name in self._metrics: + raise ValueError( + f'Failed to register metric {metric.name}: metric already exists' + ) + 
self._metrics[metric.name] = metric
+
 
 def _record_type_to_string(record_type: FrameProcessingStatRecordType) -> str:
     # Cannot use dict since FrameProcessingStatRecordType is not hashable
diff --git a/savant/metrics/registry.py b/savant/metrics/registry.py
new file mode 100644
index 00000000..8acbf8fa
--- /dev/null
+++ b/savant/metrics/registry.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from savant.metrics.base import BaseMetricsExporter
+from savant.metrics.metric import Metric
+from savant.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class MetricsRegistry:
+    """Metrics registry.
+
+    Provides a dict-like interface for registering and updating metrics.
+    """
+
+    def __init__(self, exporter: Optional[BaseMetricsExporter]):
+        self._exporter = exporter
+        self._metrics = {}
+
+    def __setitem__(self, key, value: Metric):
+        if not isinstance(value, Metric):
+            raise ValueError('Value must be a Metric instance')
+        if key in self._metrics:
+            raise KeyError(f'Key {key!r} already exists')
+        if self._exporter is not None:
+            self._exporter.register_metric(value)
+        else:
+            logger.warning(
+                'Metric exporter not configured. Ignoring metric %s.',
+                value.name,
+            )
+        self._metrics[key] = value
+
+    def __getitem__(self, key):
+        return self._metrics[key]
+
+    def __contains__(self, key):
+        return key in self._metrics