-
Notifications
You must be signed in to change notification settings - Fork 141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added the switching term to GOSPA. #898
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Carlson-J. Overall looks good.
My main suggestion, is it'd probably be cleaner to use None
and empty dict, rather than specific numbers, to represent unseen/not-associated state. Minor suggestion is making SwitchingLoss
"private" (or possibly just part of GOSPA itself, rather than separate class). Diff of suggestion below:
diff --git a/stonesoup/metricgenerator/ospametric.py b/stonesoup/metricgenerator/ospametric.py
index 2ba6ff4e..737427e5 100644
--- a/stonesoup/metricgenerator/ospametric.py
+++ b/stonesoup/metricgenerator/ospametric.py
@@ -11,17 +11,14 @@ from ..types.time import TimeRange
from ..types.metric import SingleTimeMetric, TimeRangeMetric
-class SwitchingLoss:
+class _SwitchingLoss:
"""
Holds state assignment history and computes GOSPA switching term
See https://www.mathworks.com/help/fusion/ref/trackgospametric-system-object.html#d126e213697
"""
- def __init__(self, truth_ids, loss_factor, p):
- self.not_associated = -1
- self.unseen = -2
- self.truth_associations = {i: self.unseen for i in truth_ids}
- self.switching_loss = 0
- self.has_associations = False
+ def __init__(self, loss_factor, p):
+ self.truth_associations = {}
+ self.switching_loss = None
self.loss_factor = loss_factor
self.p = p
self.switching_penalty = 0.5
@@ -34,25 +31,23 @@ class SwitchingLoss:
----------
truth_associations: dict(truth_track_id: measurement_track_id)
"""
- self.has_associations = True
self.switching_loss = 0
for truth_id, meas_id in truth_associations.items():
- if self.truth_associations[truth_id] == self.unseen and meas_id == self.not_associated:
+ if truth_id not in self.truth_associations and meas_id is None:
continue
- if self.truth_associations[truth_id] == self.unseen:
+ if truth_id not in self.truth_associations:
self.truth_associations[truth_id] = meas_id
elif self.truth_associations[truth_id] != meas_id:
self.switching_loss += self.switching_penalty
- if (meas_id != self.not_associated
- and self.truth_associations[truth_id] != self.not_associated):
+ if meas_id is not None and self.truth_associations[truth_id] is not None:
self.switching_loss += self.switching_penalty
self.truth_associations[truth_id] = meas_id
def loss(self):
"""Compute loss based on last association."""
- if not self.has_associations:
+ if self.switching_loss is None:
raise RuntimeError("Can't compute switching loss before any association are added.")
return self.loss_factor * self.switching_loss**(1/self.p)
@@ -172,7 +167,7 @@ class GOSPAMetric(MetricGenerator):
state.timestamp
for state in chain(measured_states, truth_states)})
- switching_metric = SwitchingLoss(truth_state_ids, self.switching_penalty, self.p)
+ switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
gospa_metrics = []
for timestamp in timestamps:
meas_mask = [state.timestamp == timestamp for state in measured_states]
@@ -185,9 +180,9 @@ class GOSPAMetric(MetricGenerator):
metric, truth_to_measured_assignment = self.compute_gospa_metric(
meas_points, truth_points)
- truth_mapping = {}
- for i, meas_id in enumerate(truth_to_measured_assignment):
- truth_mapping[truth_ids[i]] = -1 if meas_id == -1 else meas_ids[meas_id]
+ truth_mapping = {
+ truth_id: meas_ids[meas_id] if meas_id != -1 else None
+ for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}
switching_metric.add_associations(truth_mapping)
metric.value['switching'] = switching_metric.loss()
diff --git a/stonesoup/metricgenerator/tests/test_ospametric.py b/stonesoup/metricgenerator/tests/test_ospametric.py
index 9d33c11e..58a91421 100644
--- a/stonesoup/metricgenerator/tests/test_ospametric.py
+++ b/stonesoup/metricgenerator/tests/test_ospametric.py
@@ -5,7 +5,7 @@ import numpy as np
import pytest
from ..manager import MultiManager
-from ..ospametric import GOSPAMetric, OSPAMetric, SwitchingLoss
+from ..ospametric import GOSPAMetric, OSPAMetric, _SwitchingLoss
from ...types.detection import Detection
from ...types.groundtruth import GroundTruthPath, GroundTruthState
from ...types.state import State
@@ -416,23 +416,23 @@ def test_ospa_computemetric_cardinality_error(p, first_value, second_value):
([
{0: 0, 1: 1, 2: 2},
{0: 0, 1: 1, 2: 2},
- {0: 0, 1: 1, 2: -1},
+ {0: 0, 1: 1, 2: None},
{0: 0, 1: 1, 2: 2},
{0: 1, 1: 0, 2: 2},
{1: 0, 0: 1, 2: 2},
- {0: -1, 1: 2, 2: 0},
+ {0: None, 1: 2, 2: 0},
], [0, 0, 0.5, 0.5, 2, 0, 2.5]),
([
{0: 0, 1: 1, 2: 2},
- {0: -1, 1: -1, 2: -1},
- {0: 3, 1: -1, 2: -1},
+ {0: None, 1: None, 2: None},
+ {0: 3, 1: None, 2: None},
], [0, 1.5, 0.5]),
([
- {0: -1, 1: -1, 2: -1},
+ {0: None, 1: None, 2: None},
{0: 0, 1: 1, 2: 2},
], [0, 0]),
([ # The first time we associate with the track it should not count for loss
- {0: -1, 1: -1, 2: -1},
+ {0: None, 1: None, 2: None},
{0: 0, 1: 1, 2: 2},
], [0, 0]),
([ # The first time we associate with the track it should not count for loss
@@ -451,7 +451,7 @@ def test_switching_loss(associations, expected_losses):
loss_factor = 1
truth_ids = list(range(3))
- switching_loss = SwitchingLoss(truth_ids, loss_factor, 1)
+ switching_loss = _SwitchingLoss(loss_factor, 1)
with pytest.raises(RuntimeError) as _:
switching_loss.loss() # Should raise error if no associations have been added yet.
Fixes issue described in #897
Adds the switching term to the GOSPA as outlined in #876