-
Notifications
You must be signed in to change notification settings - Fork 141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CLEAR MOT Associator #1017
CLEAR MOT Associator #1017
Conversation
Hey @kopytjuk Thanks for the PR. I think this would be a great addition 👍 |
Hey @sdhiscocks, thanks for the reply - the PR is ready for review! |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1017 +/- ##
==========================================
+ Coverage 93.60% 93.65% +0.05%
==========================================
Files 202 203 +1
Lines 12990 13104 +114
Branches 2651 2674 +23
==========================================
+ Hits 12159 12273 +114
Misses 588 588
Partials 243 243
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution @kopytjuk
Main suggestion I have, is around having one main additional helper look-up, as a dictionary for time to ID to state mapping.
This reduces a fair amount of extra code for extracting states by time, and speeds up the code a bit by avoiding searching.
diff --git a/stonesoup/dataassociator/clearmot.py b/stonesoup/dataassociator/clearmot.py
index 6d37473e..be1264f7 100644
--- a/stonesoup/dataassociator/clearmot.py
+++ b/stonesoup/dataassociator/clearmot.py
@@ -1,6 +1,7 @@
import datetime
import itertools
import sys
+from collections import defaultdict
from itertools import chain
from typing import (
Any,
@@ -9,7 +10,6 @@ from typing import (
Iterable,
List,
MutableSequence,
- Optional,
Set,
Tuple,
)
@@ -22,12 +22,12 @@ from ..base import Property
from ..measures import Euclidean, Measure
from ..types.association import AssociationSet, TimeRangeAssociation
from ..types.groundtruth import GroundTruthPath
-from ..types.state import State, StateMutableSequence
+from ..types.state import State
from ..types.time import TimeRange
from ..types.track import Track
from .base import TwoTrackToTrackAssociator
-StatesFromIdLookup = Dict[str, MutableSequence[State]]
+StatesFromTimeIdLookup = Dict[datetime.datetime, Dict[str, State]]
class ClearMotAssociator(TwoTrackToTrackAssociator):
@@ -79,11 +79,25 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
"""
# helper look-ups
- truth_states_by_id: StatesFromIdLookup = {truth.id: truth.states for truth in truth_set}
- track_states_by_id: StatesFromIdLookup = {track.id: track.states for track in tracks_set}
+ truth_tracks_by_id = {truth.id: truth for truth in truth_set}
+ estim_tracks_by_id = {track.id: track for track in tracks_set}
+
+ timestamps = set()
+
+ truth_states_by_time_id: StatesFromTimeIdLookup = defaultdict(dict)
+ for truth in truth_set:
+ for state in truth.last_timestamp_generator():
+ truth_states_by_time_id[state.timestamp][truth.id] = state
+ timestamps.add(state.timestamp)
+
+ track_states_by_time_id: StatesFromTimeIdLookup = defaultdict(dict)
+ for track in tracks_set:
+ for state in track.last_timestamp_generator():
+ track_states_by_time_id[state.timestamp][track.id] = state
+ timestamps.add(state.timestamp)
# Make a sorted list of all the unique timestamps used
- timestamps = self.determine_unique_timestamps(tracks_set, truth_set)
+ timestamps = sorted(timestamps)
# we use this to collect match sets over time
matches_over_time = []
@@ -93,16 +107,15 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
for current_time in timestamps:
- truth_ids_at_current_time, track_ids_at_current_time = \
- self._get_truth_and_track_ids_from_timestamp(truth_states_by_id,
- track_states_by_id,
- current_time)
+ truth_ids_at_current_time = OrderedSet(truth_states_by_time_id[current_time])
+ track_ids_at_current_time = OrderedSet(track_states_by_time_id[current_time])
+ truth_states_by_id = truth_states_by_time_id[current_time]
+ track_states_by_id = track_states_by_time_id[current_time]
matches_current = \
self._initialize_matches_from_previous_timestep(truth_states_by_id,
track_states_by_id,
matches_previous,
- current_time,
truth_ids_at_current_time,
track_ids_at_current_time,
)
@@ -116,7 +129,6 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
matches_from_unassigned = self._match_unassigned_tracks(truth_states_by_id,
track_states_by_id,
- current_time,
truth_ids_at_current_time,
track_ids_at_current_time)
matches_current |= matches_from_unassigned
@@ -125,33 +137,18 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
matches_previous = matches_current
associations = self._create_associations_from_matches_over_time(
- tracks_set, truth_set, timestamps, matches_over_time)
+ estim_tracks_by_id, truth_tracks_by_id, timestamps, matches_over_time)
return AssociationSet(associations)
- def _get_truth_and_track_ids_from_timestamp(self,
- truth_states_by_id: StatesFromIdLookup,
- track_states_by_id: StatesFromIdLookup,
- timestamp: datetime.datetime):
- truth_ids_at_current_time = [truth_id for (truth_id, truth_states)
- in truth_states_by_id.items()
- if get_state_at_time(truth_states, timestamp)]
- track_ids_at_current_time = [track_id for (track_id, track_states)
- in track_states_by_id.items()
- if get_state_at_time(track_states, timestamp)]
-
- return truth_ids_at_current_time, track_ids_at_current_time
-
- def _create_associations_from_matches_over_time(self, tracks_set: Set[Track],
- truth_set: Set[GroundTruthPath],
+ def _create_associations_from_matches_over_time(self,
+ estim_tracks_by_id: Dict[str, Track],
+ truth_tracks_by_id: Dict[str, GroundTruthPath],
timestamps: MutableSequence[datetime.datetime],
matches_over_time: List[Set[Tuple[str, str]]]):
unique_matches = {
match for matches_timestamp in matches_over_time for match in matches_timestamp}
- truth_tracks_by_id = {truth.id: truth for truth in truth_set}
- estim_tracks_by_id = {track.id: track for track in tracks_set}
-
associations = set()
for match in unique_matches:
timesteps_where_match_exists = list()
@@ -172,12 +169,11 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
return associations
def _initialize_matches_from_previous_timestep(self,
- truth_states_by_id: StatesFromIdLookup,
- track_states_by_id: StatesFromIdLookup,
+ truth_states_by_id: Dict[str, State],
+ track_states_by_id: Dict[str, State],
matches_previous: Set[Tuple[str, str]],
- current_time: datetime.datetime,
- truth_ids_at_current_time: Set[str],
- track_ids_at_current_time: Set[str]) \
+ truth_ids_at_current_time: OrderedSet[str],
+ track_ids_at_current_time: OrderedSet[str]) \
-> Set[Tuple[str, str]]:
"""Checks if matches from the previous timestep are still valid by their distance and
adds them to the returned set of matches.
@@ -192,18 +188,10 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
# assication threshold - if true, we keep it and add it to current set,
# if not we do not maintain the match
for (track_id, truth_id) in matches_previous:
- # get
- truth_states = truth_states_by_id[truth_id]
- truth_state_current = get_state_at_time(truth_states, current_time)
-
- if not truth_state_current:
- continue
-
- track_states = track_states_by_id[track_id]
- track_state_current = get_state_at_time(track_states, current_time)
-
- # if hypothesis is not available anymore
- if not track_state_current:
+ try:
+ truth_state_current = truth_states_by_id[truth_id]
+ track_state_current = track_states_by_id[track_id]
+ except KeyError:
continue
distance = self.measure(track_state_current, truth_state_current)
@@ -216,11 +204,11 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
track_ids_at_current_time.remove(track_id)
return matches_current
- def _match_unassigned_tracks(self, truth_states_by_id: StatesFromIdLookup,
- track_states_by_id: StatesFromIdLookup,
- current_time: datetime.datetime,
- truth_ids_at_current_time: Set[str],
- track_ids_at_current_time: Set[str]) -> Set[Tuple[str, str]]:
+ def _match_unassigned_tracks(self, truth_states_by_id: Dict[str, State],
+ track_states_by_id: Dict[str, State],
+ truth_ids_at_current_time: OrderedSet[str],
+ track_ids_at_current_time: OrderedSet[str]
+ ) -> Set[Tuple[str, str]]:
"""Match unassigned tracks using Munkers algorithm and distance threshold.
"""
num_truth_unassigned = len(truth_ids_at_current_time)
@@ -233,10 +221,8 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
for j in range(num_tracks_unassigned):
truth_id, track_id = truth_ids_at_current_time[i], track_ids_at_current_time[j]
- truth_states = truth_states_by_id[truth_id]
- track_states = track_states_by_id[track_id]
- truth_state_current = get_state_at_time(truth_states, current_time)
- track_state_current = get_state_at_time(track_states, current_time)
+ truth_state_current = truth_states_by_id[truth_id]
+ track_state_current = track_states_by_id[track_id]
distance = self.measure(track_state_current, truth_state_current)
cost_matrix[i, j] = distance
@@ -248,62 +234,6 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
matches.add((track_ids_at_current_time[j], truth_ids_at_current_time[i]))
return matches
- def determine_unique_timestamps(self, tracks_set: Set[Track], truth_set: Set[GroundTruthPath])\
- -> List[datetime.datetime]:
-
- track_states = extract_states(tracks_set)
- truth_states = extract_states(truth_set)
- timestamps = sorted({
- state.timestamp
- for state in chain(track_states, truth_states)})
- return timestamps
-
-
-def extract_states(object_with_states, return_ids=False) -> List[State]:
- """
- NOTE: copy of stonesoup/metricgenerator/ospametric.py
-
- Extracts a list of states from a list of (or single) objects
- containing states. This method is defined to handle :class:`~.StateMutableSequence`
- and :class:`~.State` types.
-
- Parameters
- ----------
- object_with_states: object containing a list of states
- Method of state extraction depends on the type of the object
- return_ids: If we should return obj ids as well.
-
- Returns
- -------
- : list of :class:`~.State`
- """
-
- state_list = StateMutableSequence()
- ids = []
- for i, element in enumerate(list(object_with_states)):
- if isinstance(element, StateMutableSequence):
- state_list.extend(element.states)
- ids.extend([i]*len(element.states))
- elif isinstance(element, State):
- state_list.append(element)
- ids.extend([i])
- else:
- raise ValueError(
- "{!r} has no state extraction method".format(element))
- if return_ids:
- return state_list, ids
- return state_list
-
-
-def get_state_at_time(state_sequence: MutableSequence[State],
- timestamp: datetime.datetime) -> Optional[State]:
- """Returns a state instance from a sequence of states for a given timestamp.
- Returns None if no data available."""
- try:
- return Track(state_sequence)[timestamp]
- except IndexError:
- return None
-
def get_strictly_monotonously_increasing_intervals(arr: MutableSequence[int])\
-> List[Tuple[int, int]]:
Hey @sdhiscocks, thanks for your suggestion - I added that - it makes the code definitely more clear and more performant as well! I changed parts of the code and tried to hide implementation details within private methods. I think the code is ready to review now. Building up on the associations, I would like to implement the MOTA & MOTP metrics in my next PR :) |
Thanks again for the contribution. Just noted one minor change required to include new module in docs: diff --git a/docs/source/stonesoup.dataassociator.rst b/docs/source/stonesoup.dataassociator.rst
index dd9d52be..6e2d0034 100644
--- a/docs/source/stonesoup.dataassociator.rst
+++ b/docs/source/stonesoup.dataassociator.rst
@@ -31,6 +31,12 @@ Track-to-track Association
.. automodule:: stonesoup.dataassociator.tracktotrack
:show-inheritance:
+CLEAR MOT Association
+---------------------
+
+.. automodule:: stonesoup.dataassociator.clearmot
+ :show-inheritance:
+
Trees
--------------------------
|
Implements the Classification of Events, Activities, and Relationships (CLEAR) Multi-Object Tracking (MOT) association scheme. This assigns a single track trajectory to a single GT trajectory.
Image from [1]
In one of the next PRs I would like to implement MOTP and MOTA metrics.
[1] Bernardin, K., Stiefelhagen, R. Evaluating Multiple Object Tracking Performance: The CLEAR MOT Metrics. J Image Video Proc 2008, 246309 (2008). https://doi.org/10.1155/2008/246309