Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the switching term to GOSPA. #898

Merged
merged 9 commits into from
Jan 9, 2024
103 changes: 85 additions & 18 deletions stonesoup/metricgenerator/ospametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,47 @@
from ..types.metric import SingleTimeMetric, TimeRangeMetric


class _SwitchingLoss:
"""
Holds state assignment history and computes GOSPA switching term
See https://www.mathworks.com/help/fusion/ref/trackgospametric-system-object.html#d126e213697
"""
def __init__(self, loss_factor, p):
self.truth_associations = {}
self.switching_loss = None
self.loss_factor = loss_factor
self.p = p
self.switching_penalty = 0.5

def add_associations(self, truth_associations):
"""
Add a new set of association and update the switching loss.

Parameters
----------
truth_associations: dict(truth_track_id: measurement_track_id)
"""
self.switching_loss = 0

for truth_id, meas_id in truth_associations.items():
if truth_id not in self.truth_associations and meas_id is None:
continue
if truth_id not in self.truth_associations:
self.truth_associations[truth_id] = meas_id
elif self.truth_associations[truth_id] != meas_id:
self.switching_loss += self.switching_penalty
if meas_id is not None and self.truth_associations[truth_id] is not None:
self.switching_loss += self.switching_penalty

self.truth_associations[truth_id] = meas_id

def loss(self):
"""Compute loss based on last association."""
if self.switching_loss is None:
raise RuntimeError("Can't compute switching loss before any association are added.")
return self.loss_factor * self.switching_loss**(1/self.p)


class GOSPAMetric(MetricGenerator):
"""
Computes the Generalized Optimal SubPattern Assignment (GOSPA) metric
Expand All @@ -27,6 +68,7 @@ class GOSPAMetric(MetricGenerator):
"""
p: float = Property(doc="1<=p<infty, exponent.")
c: float = Property(doc="c>0, cutoff distance.")
switching_penalty: float = Property(doc="Penalty term for switching.", default=0.0)
measure: Measure = Property(
default=Euclidean(),
doc="Distance measure to use. Default :class:`~.measures.Euclidean()`")
Expand Down Expand Up @@ -60,12 +102,12 @@ def compute_metric(self, manager):

"""
return self.compute_over_time(
self.extract_states(manager.states_sets[self.tracks_key]),
self.extract_states(manager.states_sets[self.truths_key])
*self.extract_states(manager.states_sets[self.tracks_key], True),
*self.extract_states(manager.states_sets[self.truths_key], True)
)

@staticmethod
def extract_states(object_with_states):
def extract_states(object_with_states, return_ids=False):
"""
Extracts a list of states from a list of (or single) objects
containing states. This method is defined to handle :class:`~.StateMutableSequence`
Expand All @@ -75,25 +117,31 @@ def extract_states(object_with_states):
----------
object_with_states: object containing a list of states
Method of state extraction depends on the type of the object
return_ids: If we should return obj ids as well.

Returns
-------
: list of :class:`~.State`
"""

state_list = StateMutableSequence()
for element in list(object_with_states):
ids = []
for i, element in enumerate(list(object_with_states)):
if isinstance(element, StateMutableSequence):
state_list.extend(element.states)
ids.extend([i]*len(element.states))
elif isinstance(element, State):
state_list.append(element)
ids.extend([i])
else:
raise ValueError(
"{!r} has no state extraction method".format(element))

if return_ids:
return state_list, ids
return state_list

def compute_over_time(self, measured_states, truth_states):
def compute_over_time(self, measured_states, measured_state_ids, truth_states,
truth_state_ids):
"""
Compute the GOSPA metric at every timestep from a list of measured
states and truth states.
Expand All @@ -102,7 +150,9 @@ def compute_over_time(self, measured_states, truth_states):
----------

measured_states: List of states created by a filter
measured_state_ids: ids for which state belongs in
truth_states: List of truth states to compare against
truth_state_ids: ids for which truth state belongs in

Returns
-------
Expand All @@ -117,17 +167,28 @@ def compute_over_time(self, measured_states, truth_states):
state.timestamp
for state in chain(measured_states, truth_states)})

switching_metric = _SwitchingLoss(self.switching_penalty, self.p)
gospa_metrics = []

for timestamp in timestamps:
meas_points = [state
for state in measured_states
if state.timestamp == timestamp]
truth_points = [state
for state in truth_states
if state.timestamp == timestamp]
meas_mask = [state.timestamp == timestamp for state in measured_states]
meas_points = np.array(measured_states)[meas_mask]
meas_ids = np.array(measured_state_ids)[meas_mask]

truth_mask = [state.timestamp == timestamp for state in truth_states]
truth_points = np.array(truth_states)[truth_mask]
truth_ids = np.array(truth_state_ids)[truth_mask]

metric, truth_to_measured_assignment = self.compute_gospa_metric(
meas_points, truth_points)
truth_mapping = {
truth_id: meas_ids[meas_id] if meas_id != -1 else None
for truth_id, meas_id in zip(truth_ids, truth_to_measured_assignment)}

switching_metric.add_associations(truth_mapping)
metric.value['switching'] = switching_metric.loss()
metric.value['distance'] = np.power(metric.value['distance']**self.alpha +
metric.value['switching']**self.alpha,
1.0/self.alpha)
gospa_metrics.append(metric)

# If only one timestamp is present then return a SingleTimeMetric
Expand Down Expand Up @@ -331,6 +392,12 @@ def compute_gospa_metric(self, measured_states, truth_states):
truth_to_measured_assignment, measured_to_truth_assignment, _ =\
self.compute_assignments(cost_matrix,
10 * num_truth_states * num_measured_states)

opt_cost -= np.sum(measured_to_truth_assignment == unassigned_index) * dummy_cost
if self.alpha == 2:
gospa_metric['false'] -= \
np.sum(measured_to_truth_assignment == unassigned_index)*dummy_cost

# Now use assignments to compute bids
for i in range(num_truth_states):
if truth_to_measured_assignment[i] != unassigned_index:
Expand All @@ -348,15 +415,15 @@ def compute_gospa_metric(self, measured_states, truth_states):

gospa_metric['false'] -= \
dummy_cost*(cost_matrix[i, const_assign] == const_cmp)

if cost_matrix[i, const_assign] == const_cmp:
truth_to_measured_assignment[i] = unassigned_index

else:
opt_cost = opt_cost - dummy_cost
if self.alpha == 2:
gospa_metric['missed'] -= dummy_cost

opt_cost -= np.sum(measured_to_truth_assignment == unassigned_index) * dummy_cost
if self.alpha == 2:
gospa_metric['false'] -= \
np.sum(measured_to_truth_assignment == unassigned_index)*dummy_cost
gospa_metric['distance'] = np.power((-1. * opt_cost), 1 / self.p)
gospa_metric['localisation'] *= -1.
gospa_metric['missed'] *= -1.
Expand Down Expand Up @@ -389,7 +456,7 @@ class OSPAMetric(GOSPAMetric):
"from MultiManager",
default='ospa_generator')

def compute_over_time(self, measured_states, truth_states):
def compute_over_time(self, measured_states, meas_ids, truth_states, truth_ids):
"""Compute the OSPA metric at every timestep from a list of measured
states and truth states

Expand Down
110 changes: 109 additions & 1 deletion stonesoup/metricgenerator/tests/test_ospametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from ..manager import MultiManager
from ..ospametric import GOSPAMetric, OSPAMetric
from ..ospametric import GOSPAMetric, OSPAMetric, _SwitchingLoss
from ...types.detection import Detection
from ...types.groundtruth import GroundTruthPath, GroundTruthState
from ...types.state import State
Expand Down Expand Up @@ -310,6 +310,66 @@ def test_ospametric_computemetric(p):
assert second_association.generator == generator


def test_switching_gospametric_computemetric():
"""Test GOSPA compute metric."""
max_penalty = 2
switching_penalty = 3
p = 2
generator = GOSPAMetric(
c=max_penalty,
p=p,
switching_penalty=switching_penalty
)

time = datetime.datetime.now()
times = [time.now() + datetime.timedelta(seconds=i) for i in range(3)]
tracks = {Track(states=[State(state_vector=[[i]], timestamp=time) for i, time
in zip([1, 2, 2], times)]),
Track(states=[State(state_vector=[[i]], timestamp=time) for i, time
in zip([2, 1, 1], times)]),
Track(states=[State(state_vector=[[i]], timestamp=time) for i, time
in zip([3, 100, 3], times)])}

truths = {GroundTruthPath(states=[State(state_vector=[[i]], timestamp=time)
for i, time in zip([1, 1, 1], times)]),
GroundTruthPath(states=[State(state_vector=[[i]], timestamp=time)
for i, time in zip([2, 2, 2], times)]),
GroundTruthPath(states=[State(state_vector=[[i]], timestamp=time)
for i, time in zip([3, 3, 3], times)])}

manager = MultiManager([generator])
manager.add_data({'groundtruth_paths': truths, 'tracks': tracks})
main_metric = generator.compute_metric(manager)
first_association, second_association, third_association = main_metric.value

assert main_metric.time_range.start_timestamp == times[0]

assert first_association.value['distance'] == 0
assert first_association.value['localisation'] == 0
assert first_association.value['missed'] == 0
assert first_association.value['false'] == 0
assert first_association.value['switching'] == 0
assert first_association.timestamp == times[0]
assert first_association.generator == generator

assert abs(second_association.value['distance'] - np.power(
max_penalty**p + (2.5**(1/p)*switching_penalty)**p, 1./p)) < 1e-9
assert second_association.value['localisation'] == 0
assert second_association.value['missed'] == 1*max_penalty
assert second_association.value['false'] == 1*max_penalty
assert abs(second_association.value['switching'] - 2.5**(1/p)*switching_penalty) < 1e-9
assert second_association.timestamp == times[1]
assert second_association.generator == generator

assert abs(third_association.value['distance'] - 0.5**(1/p)*switching_penalty) < 1e-9
assert third_association.value['localisation'] == 0
assert third_association.value['missed'] == 0
assert third_association.value['false'] == 0
assert abs(third_association.value['switching'] - 0.5**(1/p)*switching_penalty) < 1e-9
assert third_association.timestamp == times[2]
assert third_association.generator == generator


@pytest.mark.parametrize(
'p,first_value,second_value',
((1, 2.4, 2.16), (2, 4.49444, 4.47571), (np.inf, 10, 10)),
Expand Down Expand Up @@ -351,3 +411,51 @@ def test_ospa_computemetric_cardinality_error(p, first_value, second_value):
assert second_association.value == pytest.approx(second_value)
assert second_association.timestamp == time + datetime.timedelta(seconds=1)
assert second_association.generator == generator


@pytest.mark.parametrize("associations, expected_losses", [
([
{0: 0, 1: 1, 2: 2},
{0: 0, 1: 1, 2: 2},
{0: 0, 1: 1, 2: None},
{0: 0, 1: 1, 2: 2},
{0: 1, 1: 0, 2: 2},
{1: 0, 0: 1, 2: 2},
{0: None, 1: 2, 2: 0},
], [0, 0, 0.5, 0.5, 2, 0, 2.5]),
([
{0: 0, 1: 1, 2: 2},
{0: None, 1: None, 2: None},
{0: 3, 1: None, 2: None},
], [0, 1.5, 0.5]),
([
{0: None, 1: None, 2: None},
{0: 0, 1: 1, 2: 2},
], [0, 0]),
([ # The first time we associate with the track it should not count for loss
{0: None, 1: None, 2: None},
{0: 0, 1: 1, 2: 2},
], [0, 0]),
([ # The first time we associate with the track it should not count for loss
{0: 0},
{0: 0, 1: 1},
{0: 0, 1: 1, 2: 3}
], [0, 0, 0]),
([ # We don't want loss if we just didn't see it
{0: 0, 1: 1},
{0: 0},
{0: 0, 1: 1},
{0: 0, 1: 2}
], [0, 0, 0, 1]),
])
def test_switching_loss(associations, expected_losses):
loss_factor = 1

switching_loss = _SwitchingLoss(loss_factor, 1)

with pytest.raises(RuntimeError) as _:
switching_loss.loss() # Should raise error if no associations have been added yet.

for association, expected_loss in zip(associations, expected_losses):
switching_loss.add_associations(association)
assert switching_loss.loss() == expected_loss