Skip to content

Commit

Permalink
Merge pull request #898 from Carlson-J/switching_GOSPA
Browse files Browse the repository at this point in the history
Added the switching term to GOSPA.
  • Loading branch information
sdhiscocks authored Jan 9, 2024
2 parents e25e0d2 + ebbf678 commit 06477ce
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 19 deletions.
103 changes: 85 additions & 18 deletions stonesoup/metricgenerator/ospametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,47 @@
from ..types.metric import SingleTimeMetric, TimeRangeMetric


class _SwitchingLoss:
"""
Holds state assignment history and computes GOSPA switching term
See https://www.mathworks.com/help/fusion/ref/trackgospametric-system-object.html#d126e213697
"""
def __init__(self, loss_factor, p):
self.truth_associations = {}
self.switching_loss = None
self.loss_factor = loss_factor
self.p = p
self.switching_penalty = 0.5

def add_associations(self, truth_associations):
"""
Add a new set of association and update the switching loss.
Parameters
----------
truth_associations: dict(truth_track_id: measurement_track_id)
"""
self.switching_loss = 0

for truth_id, meas_id in truth_associations.items():
if truth_id not in self.truth_associations and meas_id is None:
continue
if truth_id not in self.truth_associations:
self.truth_associations[truth_id] = meas_id
elif self.truth_associations[truth_id] != meas_id:
self.switching_loss += self.switching_penalty
if meas_id is not None and self.truth_associations[truth_id] is not None:
self.switching_loss += self.switching_penalty

self.truth_associations[truth_id] = meas_id

def loss(self):
"""Compute loss based on last association."""
if self.switching_loss is None:
raise RuntimeError("Can't compute switching loss before any association are added.")
return self.loss_factor * self.switching_loss**(1/self.p)


class GOSPAMetric(MetricGenerator):
"""
Computes the Generalized Optimal SubPattern Assignment (GOSPA) metric
Expand All @@ -27,6 +68,7 @@ class GOSPAMetric(MetricGenerator):
"""
p: float = Property(doc="1<=p<infty, exponent.")
c: float = Property(doc="c>0, cutoff distance.")
switching_penalty: float = Property(doc="Penalty term for switching.", default=0.0)
measure: Measure = Property(
default=Euclidean(),
doc="Distance measure to use. Default :class:`~.measures.Euclidean()`")
Expand Down Expand Up @@ -60,12 +102,12 @@ def compute_metric(self, manager):
"""
return self.compute_over_time(
self.extract_states(manager.states_sets[self.tracks_key]),
self.extract_states(manager.states_sets[self.truths_key])
*self.extract_states(manager.states_sets[self.tracks_key], True),
*self.extract_states(manager.states_sets[self.truths_key], True)
)

@staticmethod
def extract_states(object_with_states):
def extract_states(object_with_states, return_ids=False):
"""
Extracts a list of states from a list of (or single) objects
containing states. This method is defined to handle :class:`~.StateMutableSequence`
Expand All @@ -75,25 +117,31 @@ def extract_states(object_with_states):
----------
object_with_states: object containing a list of states
Method of state extraction depends on the type of the object
return_ids: If we should return obj ids as well.
Returns
-------
: list of :class:`~.State`
"""

state_list = StateMutableSequence()
for element in list(object_with_states):
ids = []
for i, element in enumerate(list(object_with_states)):
if isinstance(element, StateMutableSequence):
state_list.extend(element.states)
ids.extend([i]*len(element.states))
elif isinstance(element, State):
state_list.append(element)
ids.extend([i])
else:
raise ValueError(
"{!r} has no state extraction method".format(element))

if return_ids:
return state_list, ids
return state_list

def compute_over_time(self, measured_states, truth_states):
def compute_over_time(self, measured_states, measured_state_ids, truth_states,
truth_state_ids):
"""
Compute the GOSPA metric at every timestep from a list of measured
states and truth states.
Expand All @@ -102,7 +150,9 @@ def compute_over_time(self, measured_states, truth_states):
----------
measured_states: List of states created by a filter
measured_state_ids: ids for which state belongs in
truth_states: List of truth states to compare against
truth_state_ids: ids for which truth state belongs in
Returns
-------
Expand All @@ -117,17 +167,28 @@ def compute_over_time(self, measured_states, truth_states):
state.timestamp
for state in chain(measured_states, truth_states)})

switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
gospa_metrics = []

for timestamp in timestamps:
meas_points = [state
for state in measured_states
if state.timestamp == timestamp]
truth_points = [state
for state in truth_states
if state.timestamp == timestamp]
meas_mask = [state.timestamp == timestamp for state in measured_states]
meas_points = np.array(measured_states)[meas_mask]
meas_ids = np.array(measured_state_ids)[meas_mask]

truth_mask = [state.timestamp == timestamp for state in truth_states]
truth_points = np.array(truth_states)[truth_mask]
truth_ids = np.array(truth_state_ids)[truth_mask]

metric, truth_to_measured_assignment = self.compute_gospa_metric(
meas_points, truth_points)
truth_mapping = {
truth_id: meas_ids[meas_id] if meas_id != -1 else None
for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}

switching_metric.add_associations(truth_mapping)
metric.value['switching'] = switching_metric.loss()
metric.value['distance'] = np.power(metric.value['distance']**self.alpha +
metric.value['switching']**self.alpha,
1.0/self.alpha)
gospa_metrics.append(metric)

# If only one timestamp is present then return a SingleTimeMetric
Expand Down Expand Up @@ -331,6 +392,12 @@ def compute_gospa_metric(self, measured_states, truth_states):
truth_to_measured_assignment, measured_to_truth_assignment, _ =\
self.compute_assignments(cost_matrix,
10 * num_truth_states * num_measured_states)

opt_cost -= np.sum(measured_to_truth_assignment == unassigned_index) * dummy_cost
if self.alpha == 2:
gospa_metric['false'] -= \
np.sum(measured_to_truth_assignment == unassigned_index)*dummy_cost

# Now use assignments to compute bids
for i in range(num_truth_states):
if truth_to_measured_assignment[i] != unassigned_index:
Expand All @@ -348,15 +415,15 @@ def compute_gospa_metric(self, measured_states, truth_states):

gospa_metric['false'] -= \
dummy_cost*(cost_matrix[i, const_assign] == const_cmp)

if cost_matrix[i, const_assign] == const_cmp:
truth_to_measured_assignment[i] = unassigned_index

else:
opt_cost = opt_cost - dummy_cost
if self.alpha == 2:
gospa_metric['missed'] -= dummy_cost

opt_cost -= np.sum(measured_to_truth_assignment == unassigned_index) * dummy_cost
if self.alpha == 2:
gospa_metric['false'] -= \
np.sum(measured_to_truth_assignment == unassigned_index)*dummy_cost
gospa_metric['distance'] = np.power((-1. * opt_cost), 1 / self.p)
gospa_metric['localisation'] *= -1.
gospa_metric['missed'] *= -1.
Expand Down Expand Up @@ -389,7 +456,7 @@ class OSPAMetric(GOSPAMetric):
"from MultiManager",
default='ospa_generator')

def compute_over_time(self, measured_states, truth_states):
def compute_over_time(self, measured_states, meas_ids, truth_states, truth_ids):
"""Compute the OSPA metric at every timestep from a list of measured
states and truth states
Expand Down
110 changes: 109 additions & 1 deletion stonesoup/metricgenerator/tests/test_ospametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from ..manager import MultiManager
from ..ospametric import GOSPAMetric, OSPAMetric
from ..ospametric import GOSPAMetric, OSPAMetric, _SwitchingLoss
from ...types.detection import Detection
from ...types.groundtruth import GroundTruthPath, GroundTruthState
from ...types.state import State
Expand Down Expand Up @@ -310,6 +310,66 @@ def test_ospametric_computemetric(p):
assert second_association.generator == generator


def test_switching_gospametric_computemetric():
"""Test GOSPA compute metric."""
max_penalty = 2
switching_penalty = 3
p = 2
generator = GOSPAMetric(
c=max_penalty,
p=p,
switching_penalty=switching_penalty
)

time = datetime.datetime.now()
times = [time.now() + datetime.timedelta(seconds=i) for i in range(3)]
tracks = {Track(states=[State(state_vector=[[i]], timestamp=time) for i, time
in zip([1, 2, 2], times)]),
Track(states=[State(state_vector=[[i]], timestamp=time) for i, time
in zip([2, 1, 1], times)]),
Track(states=[State(state_vector=[[i]], timestamp=time) for i, time
in zip([3, 100, 3], times)])}

truths = {GroundTruthPath(states=[State(state_vector=[[i]], timestamp=time)
for i, time in zip([1, 1, 1], times)]),
GroundTruthPath(states=[State(state_vector=[[i]], timestamp=time)
for i, time in zip([2, 2, 2], times)]),
GroundTruthPath(states=[State(state_vector=[[i]], timestamp=time)
for i, time in zip([3, 3, 3], times)])}

manager = MultiManager([generator])
manager.add_data({'groundtruth_paths': truths, 'tracks': tracks})
main_metric = generator.compute_metric(manager)
first_association, second_association, third_association = main_metric.value

assert main_metric.time_range.start_timestamp == times[0]

assert first_association.value['distance'] == 0
assert first_association.value['localisation'] == 0
assert first_association.value['missed'] == 0
assert first_association.value['false'] == 0
assert first_association.value['switching'] == 0
assert first_association.timestamp == times[0]
assert first_association.generator == generator

assert abs(second_association.value['distance'] - np.power(
max_penalty**p + (2.5**(1/p)*switching_penalty)**p, 1./p)) < 1e-9
assert second_association.value['localisation'] == 0
assert second_association.value['missed'] == 1*max_penalty
assert second_association.value['false'] == 1*max_penalty
assert abs(second_association.value['switching'] - 2.5**(1/p)*switching_penalty) < 1e-9
assert second_association.timestamp == times[1]
assert second_association.generator == generator

assert abs(third_association.value['distance'] - 0.5**(1/p)*switching_penalty) < 1e-9
assert third_association.value['localisation'] == 0
assert third_association.value['missed'] == 0
assert third_association.value['false'] == 0
assert abs(third_association.value['switching'] - 0.5**(1/p)*switching_penalty) < 1e-9
assert third_association.timestamp == times[2]
assert third_association.generator == generator


@pytest.mark.parametrize(
'p,first_value,second_value',
((1, 2.4, 2.16), (2, 4.49444, 4.47571), (np.inf, 10, 10)),
Expand Down Expand Up @@ -351,3 +411,51 @@ def test_ospa_computemetric_cardinality_error(p, first_value, second_value):
assert second_association.value == pytest.approx(second_value)
assert second_association.timestamp == time + datetime.timedelta(seconds=1)
assert second_association.generator == generator


@pytest.mark.parametrize("associations, expected_losses", [
([
{0: 0, 1: 1, 2: 2},
{0: 0, 1: 1, 2: 2},
{0: 0, 1: 1, 2: None},
{0: 0, 1: 1, 2: 2},
{0: 1, 1: 0, 2: 2},
{1: 0, 0: 1, 2: 2},
{0: None, 1: 2, 2: 0},
], [0, 0, 0.5, 0.5, 2, 0, 2.5]),
([
{0: 0, 1: 1, 2: 2},
{0: None, 1: None, 2: None},
{0: 3, 1: None, 2: None},
], [0, 1.5, 0.5]),
([
{0: None, 1: None, 2: None},
{0: 0, 1: 1, 2: 2},
], [0, 0]),
([ # The first time we associate with the track it should not count for loss
{0: None, 1: None, 2: None},
{0: 0, 1: 1, 2: 2},
], [0, 0]),
([ # The first time we associate with the track it should not count for loss
{0: 0},
{0: 0, 1: 1},
{0: 0, 1: 1, 2: 3}
], [0, 0, 0]),
([ # We don't want loss if we just didn't see it
{0: 0, 1: 1},
{0: 0},
{0: 0, 1: 1},
{0: 0, 1: 2}
], [0, 0, 0, 1]),
])
def test_switching_loss(associations, expected_losses):
loss_factor = 1

switching_loss = _SwitchingLoss(loss_factor, 1)

with pytest.raises(RuntimeError) as _:
switching_loss.loss() # Should raise error if no associations have been added yet.

for association, expected_loss in zip(associations, expected_losses):
switching_loss.add_associations(association)
assert switching_loss.loss() == expected_loss

0 comments on commit 06477ce

Please sign in to comment.