Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the switching term to GOSPA. #898

Merged
merged 9 commits into from
Jan 9, 2024
Merged

Conversation

Carlson-J
Copy link
Contributor

Fixes issue described in #897

Adds the switching term to the GOSPA as outlined in #876

@Carlson-J Carlson-J marked this pull request as ready for review November 29, 2023 15:04
@Carlson-J Carlson-J requested a review from a team as a code owner November 29, 2023 15:04
@Carlson-J Carlson-J requested review from hpritchett-dstl and orosoman-dstl and removed request for a team November 29, 2023 15:04
Copy link
Member

@sdhiscocks sdhiscocks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Carlson-J. Overall looks good.

My main suggestion, is it'd probably be cleaner to use None and empty dict, rather than specific numbers, to represent unseen/not-associated state. Minor suggestion is making SwitchingLoss "private" (or possibly just part of GOSPA itself, rather than separate class). Diff of suggestion below:

diff --git a/stonesoup/metricgenerator/ospametric.py b/stonesoup/metricgenerator/ospametric.py
index 2ba6ff4e..737427e5 100644
--- a/stonesoup/metricgenerator/ospametric.py
+++ b/stonesoup/metricgenerator/ospametric.py
@@ -11,17 +11,14 @@ from ..types.time import TimeRange
 from ..types.metric import SingleTimeMetric, TimeRangeMetric
 
 
-class SwitchingLoss:
+class _SwitchingLoss:
     """
     Holds state assignment history and computes GOSPA switching term
     See https://www.mathworks.com/help/fusion/ref/trackgospametric-system-object.html#d126e213697
     """
-    def __init__(self, truth_ids, loss_factor, p):
-        self.not_associated = -1
-        self.unseen = -2
-        self.truth_associations = {i: self.unseen for i in truth_ids}
-        self.switching_loss = 0
-        self.has_associations = False
+    def __init__(self, loss_factor, p):
+        self.truth_associations = {}
+        self.switching_loss = None
         self.loss_factor = loss_factor
         self.p = p
         self.switching_penalty = 0.5
@@ -34,25 +31,23 @@ class SwitchingLoss:
         ----------
         truth_associations: dict(truth_track_id: measurement_track_id)
         """
-        self.has_associations = True
         self.switching_loss = 0
 
         for truth_id, meas_id in truth_associations.items():
-            if self.truth_associations[truth_id] == self.unseen and meas_id == self.not_associated:
+            if truth_id not in self.truth_associations and meas_id is None:
                 continue
-            if self.truth_associations[truth_id] == self.unseen:
+            if truth_id not in self.truth_associations:
                 self.truth_associations[truth_id] = meas_id
             elif self.truth_associations[truth_id] != meas_id:
                 self.switching_loss += self.switching_penalty
-                if (meas_id != self.not_associated
-                        and self.truth_associations[truth_id] != self.not_associated):
+                if meas_id is not None and self.truth_associations[truth_id] is not None:
                     self.switching_loss += self.switching_penalty
 
                 self.truth_associations[truth_id] = meas_id
 
     def loss(self):
         """Compute loss based on last association."""
-        if not self.has_associations:
+        if self.switching_loss is None:
             raise RuntimeError("Can't compute switching loss before any association are added.")
         return self.loss_factor * self.switching_loss**(1/self.p)
 
@@ -172,7 +167,7 @@ class GOSPAMetric(MetricGenerator):
             state.timestamp
             for state in chain(measured_states, truth_states)})
 
-        switching_metric = SwitchingLoss(truth_state_ids, self.switching_penalty, self.p)
+        switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
         gospa_metrics = []
         for timestamp in timestamps:
             meas_mask = [state.timestamp == timestamp for state in measured_states]
@@ -185,9 +180,9 @@ class GOSPAMetric(MetricGenerator):
 
             metric, truth_to_measured_assignment = self.compute_gospa_metric(
                     meas_points, truth_points)
-            truth_mapping = {}
-            for i, meas_id in enumerate(truth_to_measured_assignment):
-                truth_mapping[truth_ids[i]] = -1 if meas_id == -1 else meas_ids[meas_id]
+            truth_mapping = {
+                truth_id: meas_ids[meas_id] if meas_id != -1 else None
+                for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}
 
             switching_metric.add_associations(truth_mapping)
             metric.value['switching'] = switching_metric.loss()
diff --git a/stonesoup/metricgenerator/tests/test_ospametric.py b/stonesoup/metricgenerator/tests/test_ospametric.py
index 9d33c11e..58a91421 100644
--- a/stonesoup/metricgenerator/tests/test_ospametric.py
+++ b/stonesoup/metricgenerator/tests/test_ospametric.py
@@ -5,7 +5,7 @@ import numpy as np
 import pytest
 
 from ..manager import MultiManager
-from ..ospametric import GOSPAMetric, OSPAMetric, SwitchingLoss
+from ..ospametric import GOSPAMetric, OSPAMetric, _SwitchingLoss
 from ...types.detection import Detection
 from ...types.groundtruth import GroundTruthPath, GroundTruthState
 from ...types.state import State
@@ -416,23 +416,23 @@ def test_ospa_computemetric_cardinality_error(p, first_value, second_value):
     ([
         {0: 0, 1: 1, 2: 2},
         {0: 0, 1: 1, 2: 2},
-        {0: 0, 1: 1, 2: -1},
+        {0: 0, 1: 1, 2: None},
         {0: 0, 1: 1, 2: 2},
         {0: 1, 1: 0, 2: 2},
         {1: 0, 0: 1, 2: 2},
-        {0: -1, 1: 2, 2: 0},
+        {0: None, 1: 2, 2: 0},
     ],  [0, 0, 0.5, 0.5, 2, 0, 2.5]),
     ([
         {0: 0, 1: 1, 2: 2},
-        {0: -1, 1: -1, 2: -1},
-        {0: 3, 1: -1, 2: -1},
+        {0: None, 1: None, 2: None},
+        {0: 3, 1: None, 2: None},
     ], [0, 1.5, 0.5]),
     ([
-        {0: -1, 1: -1, 2: -1},
+        {0: None, 1: None, 2: None},
         {0: 0, 1: 1, 2: 2},
     ], [0, 0]),
     ([  # The first time we associate with the track it should not count for loss
-        {0: -1, 1: -1, 2: -1},
+        {0: None, 1: None, 2: None},
         {0: 0, 1: 1, 2: 2},
     ], [0, 0]),
     ([  # The first time we associate with the track it should not count for loss
@@ -451,7 +451,7 @@ def test_switching_loss(associations, expected_losses):
     loss_factor = 1
     truth_ids = list(range(3))
 
-    switching_loss = SwitchingLoss(truth_ids, loss_factor, 1)
+    switching_loss = _SwitchingLoss(loss_factor, 1)
 
     with pytest.raises(RuntimeError) as _:
         switching_loss.loss()   # Should raise error if no associations have been added yet.

@Carlson-J Carlson-J requested a review from sdhiscocks December 14, 2023 16:13
stonesoup/metricgenerator/ospametric.py Outdated Show resolved Hide resolved
@sdhiscocks sdhiscocks merged commit 06477ce into dstl:main Jan 9, 2024
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants