Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLEAR MOT Associator #1017

Merged
merged 19 commits into from
Jul 18, 2024
Merged

CLEAR MOT Associator #1017

merged 19 commits into from
Jul 18, 2024

Conversation

kopytjuk
Copy link
Contributor

@kopytjuk kopytjuk commented May 16, 2024

Implements the Classification of Events, Activities, and Relationships (CLEAR) Multi-Object Tracking (MOT) association scheme. This assigns a single track trajectory to a single GT trajectory.

image

Image from [1]

In one of the next PRs I would like to implement MOTP and MOTA metrics.

[1] Bernardin, K., Stiefelhagen, R. Evaluating Multiple Object Tracking Performance: The CLEAR MOT Metrics. J Image Video Proc 2008, 246309 (2008). https://doi.org/10.1155/2008/246309

@kopytjuk kopytjuk requested a review from a team as a code owner May 16, 2024 16:18
@kopytjuk kopytjuk requested review from nperree-dstl and spike-dstl and removed request for a team May 16, 2024 16:19
@kopytjuk kopytjuk marked this pull request as draft May 16, 2024 16:19
@sdhiscocks
Copy link
Member

Hey @kopytjuk Thanks for the PR. I think this would be a great addition 👍

@kopytjuk kopytjuk marked this pull request as ready for review May 26, 2024 10:59
@kopytjuk
Copy link
Contributor Author

Hey @kopytjuk Thanks for the PR. I think this would be a great addition 👍

Hey @sdhiscocks, thanks for the reply - the PR is ready for review!

Copy link

codecov bot commented May 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.65%. Comparing base (ab6af2b) to head (d7ab2c1).
Report is 137 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1017      +/-   ##
==========================================
+ Coverage   93.60%   93.65%   +0.05%     
==========================================
  Files         202      203       +1     
  Lines       12990    13104     +114     
  Branches     2651     2674      +23     
==========================================
+ Hits        12159    12273     +114     
  Misses        588      588              
  Partials      243      243              
Flag Coverage Δ
integration 65.91% <27.04%> (-0.37%) ⬇️
unittests 89.30% <99.18%> (+0.08%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@sdhiscocks sdhiscocks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution @kopytjuk

Main suggestion I have, is around having one main additional helper look-up, as a dictionary for time to ID to state mapping.

This reduces a fair amount of extra code for extracting states by time, and speeds up the code a bit by avoiding searching.

diff --git a/stonesoup/dataassociator/clearmot.py b/stonesoup/dataassociator/clearmot.py
index 6d37473e..be1264f7 100644
--- a/stonesoup/dataassociator/clearmot.py
+++ b/stonesoup/dataassociator/clearmot.py
@@ -1,6 +1,7 @@
 import datetime
 import itertools
 import sys
+from collections import defaultdict
 from itertools import chain
 from typing import (
     Any,
@@ -9,7 +10,6 @@ from typing import (
     Iterable,
     List,
     MutableSequence,
-    Optional,
     Set,
     Tuple,
 )
@@ -22,12 +22,12 @@ from ..base import Property
 from ..measures import Euclidean, Measure
 from ..types.association import AssociationSet, TimeRangeAssociation
 from ..types.groundtruth import GroundTruthPath
-from ..types.state import State, StateMutableSequence
+from ..types.state import State
 from ..types.time import TimeRange
 from ..types.track import Track
 from .base import TwoTrackToTrackAssociator
 
-StatesFromIdLookup = Dict[str, MutableSequence[State]]
+StatesFromTimeIdLookup = Dict[datetime.datetime, Dict[str, State]]
 
 
 class ClearMotAssociator(TwoTrackToTrackAssociator):
@@ -79,11 +79,25 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
         """
 
         # helper look-ups
-        truth_states_by_id: StatesFromIdLookup = {truth.id: truth.states for truth in truth_set}
-        track_states_by_id: StatesFromIdLookup = {track.id: track.states for track in tracks_set}
+        truth_tracks_by_id = {truth.id: truth for truth in truth_set}
+        estim_tracks_by_id = {track.id: track for track in tracks_set}
+
+        timestamps = set()
+
+        truth_states_by_time_id: StatesFromTimeIdLookup = defaultdict(dict)
+        for truth in truth_set:
+            for state in truth.last_timestamp_generator():
+                truth_states_by_time_id[state.timestamp][truth.id] = state
+                timestamps.add(state.timestamp)
+
+        track_states_by_time_id: StatesFromTimeIdLookup = defaultdict(dict)
+        for track in tracks_set:
+            for state in track.last_timestamp_generator():
+                track_states_by_time_id[state.timestamp][track.id] = state
+                timestamps.add(state.timestamp)
 
         # Make a sorted list of all the unique timestamps used
-        timestamps = self.determine_unique_timestamps(tracks_set, truth_set)
+        timestamps = sorted(timestamps)
 
         # we use this to collect match sets over time
         matches_over_time = []
@@ -93,16 +107,15 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
 
         for current_time in timestamps:
 
-            truth_ids_at_current_time, track_ids_at_current_time = \
-                self._get_truth_and_track_ids_from_timestamp(truth_states_by_id,
-                                                             track_states_by_id,
-                                                             current_time)
+            truth_ids_at_current_time = OrderedSet(truth_states_by_time_id[current_time])
+            track_ids_at_current_time = OrderedSet(track_states_by_time_id[current_time])
+            truth_states_by_id = truth_states_by_time_id[current_time]
+            track_states_by_id = track_states_by_time_id[current_time]
 
             matches_current = \
                 self._initialize_matches_from_previous_timestep(truth_states_by_id,
                                                                 track_states_by_id,
                                                                 matches_previous,
-                                                                current_time,
                                                                 truth_ids_at_current_time,
                                                                 track_ids_at_current_time,
                                                                 )
@@ -116,7 +129,6 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
 
             matches_from_unassigned = self._match_unassigned_tracks(truth_states_by_id,
                                                                     track_states_by_id,
-                                                                    current_time,
                                                                     truth_ids_at_current_time,
                                                                     track_ids_at_current_time)
             matches_current |= matches_from_unassigned
@@ -125,33 +137,18 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
             matches_previous = matches_current
 
         associations = self._create_associations_from_matches_over_time(
-            tracks_set, truth_set, timestamps, matches_over_time)
+            estim_tracks_by_id, truth_tracks_by_id, timestamps, matches_over_time)
 
         return AssociationSet(associations)
 
-    def _get_truth_and_track_ids_from_timestamp(self,
-                                                truth_states_by_id: StatesFromIdLookup,
-                                                track_states_by_id: StatesFromIdLookup,
-                                                timestamp: datetime.datetime):
-        truth_ids_at_current_time = [truth_id for (truth_id, truth_states)
-                                     in truth_states_by_id.items()
-                                     if get_state_at_time(truth_states, timestamp)]
-        track_ids_at_current_time = [track_id for (track_id, track_states)
-                                     in track_states_by_id.items()
-                                     if get_state_at_time(track_states, timestamp)]
-
-        return truth_ids_at_current_time, track_ids_at_current_time
-
-    def _create_associations_from_matches_over_time(self, tracks_set: Set[Track],
-                                                    truth_set: Set[GroundTruthPath],
+    def _create_associations_from_matches_over_time(self,
+                                                    estim_tracks_by_id: Dict[str, Track],
+                                                    truth_tracks_by_id: Dict[str, GroundTruthPath],
                                                     timestamps: MutableSequence[datetime.datetime],
                                                     matches_over_time: List[Set[Tuple[str, str]]]):
         unique_matches = {
             match for matches_timestamp in matches_over_time for match in matches_timestamp}
 
-        truth_tracks_by_id = {truth.id: truth for truth in truth_set}
-        estim_tracks_by_id = {track.id: track for track in tracks_set}
-
         associations = set()
         for match in unique_matches:
             timesteps_where_match_exists = list()
@@ -172,12 +169,11 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
         return associations
 
     def _initialize_matches_from_previous_timestep(self,
-                                                   truth_states_by_id: StatesFromIdLookup,
-                                                   track_states_by_id: StatesFromIdLookup,
+                                                   truth_states_by_id: Dict[str, State],
+                                                   track_states_by_id: Dict[str, State],
                                                    matches_previous: Set[Tuple[str, str]],
-                                                   current_time: datetime.datetime,
-                                                   truth_ids_at_current_time: Set[str],
-                                                   track_ids_at_current_time: Set[str]) \
+                                                   truth_ids_at_current_time: OrderedSet[str],
+                                                   track_ids_at_current_time: OrderedSet[str]) \
             -> Set[Tuple[str, str]]:
         """Checks if matches from the previous timestep are still valid by their distance and
         adds them to the returned set of matches.
@@ -192,18 +188,10 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
         # assication threshold - if true, we keep it and add it to current set,
         # if not we do not maintain the match
         for (track_id, truth_id) in matches_previous:
-            # get
-            truth_states = truth_states_by_id[truth_id]
-            truth_state_current = get_state_at_time(truth_states, current_time)
-
-            if not truth_state_current:
-                continue
-
-            track_states = track_states_by_id[track_id]
-            track_state_current = get_state_at_time(track_states, current_time)
-
-            # if hypothesis is not available anymore
-            if not track_state_current:
+            try:
+                truth_state_current = truth_states_by_id[truth_id]
+                track_state_current = track_states_by_id[track_id]
+            except KeyError:
                 continue
 
             distance = self.measure(track_state_current, truth_state_current)
@@ -216,11 +204,11 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
                 track_ids_at_current_time.remove(track_id)
         return matches_current
 
-    def _match_unassigned_tracks(self, truth_states_by_id: StatesFromIdLookup,
-                                 track_states_by_id: StatesFromIdLookup,
-                                 current_time: datetime.datetime,
-                                 truth_ids_at_current_time: Set[str],
-                                 track_ids_at_current_time: Set[str]) -> Set[Tuple[str, str]]:
+    def _match_unassigned_tracks(self, truth_states_by_id: Dict[str, State],
+                                 track_states_by_id: Dict[str, State],
+                                 truth_ids_at_current_time: OrderedSet[str],
+                                 track_ids_at_current_time: OrderedSet[str]
+                                 ) -> Set[Tuple[str, str]]:
         """Match unassigned tracks using Munkers algorithm and distance threshold.
         """
         num_truth_unassigned = len(truth_ids_at_current_time)
@@ -233,10 +221,8 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
             for j in range(num_tracks_unassigned):
                 truth_id, track_id = truth_ids_at_current_time[i], track_ids_at_current_time[j]
 
-                truth_states = truth_states_by_id[truth_id]
-                track_states = track_states_by_id[track_id]
-                truth_state_current = get_state_at_time(truth_states, current_time)
-                track_state_current = get_state_at_time(track_states, current_time)
+                truth_state_current = truth_states_by_id[truth_id]
+                track_state_current = track_states_by_id[track_id]
                 distance = self.measure(track_state_current, truth_state_current)
                 cost_matrix[i, j] = distance
 
@@ -248,62 +234,6 @@ class ClearMotAssociator(TwoTrackToTrackAssociator):
                 matches.add((track_ids_at_current_time[j], truth_ids_at_current_time[i]))
         return matches
 
-    def determine_unique_timestamps(self, tracks_set: Set[Track], truth_set: Set[GroundTruthPath])\
-            -> List[datetime.datetime]:
-
-        track_states = extract_states(tracks_set)
-        truth_states = extract_states(truth_set)
-        timestamps = sorted({
-            state.timestamp
-            for state in chain(track_states, truth_states)})
-        return timestamps
-
-
-def extract_states(object_with_states, return_ids=False) -> List[State]:
-    """
-    NOTE: copy of stonesoup/metricgenerator/ospametric.py
-
-    Extracts a list of states from a list of (or single) objects
-    containing states. This method is defined to handle :class:`~.StateMutableSequence`
-    and :class:`~.State` types.
-
-    Parameters
-    ----------
-    object_with_states: object containing a list of states
-        Method of state extraction depends on the type of the object
-    return_ids: If we should return obj ids as well.
-
-    Returns
-    -------
-    : list of :class:`~.State`
-    """
-
-    state_list = StateMutableSequence()
-    ids = []
-    for i, element in enumerate(list(object_with_states)):
-        if isinstance(element, StateMutableSequence):
-            state_list.extend(element.states)
-            ids.extend([i]*len(element.states))
-        elif isinstance(element, State):
-            state_list.append(element)
-            ids.extend([i])
-        else:
-            raise ValueError(
-                "{!r} has no state extraction method".format(element))
-    if return_ids:
-        return state_list, ids
-    return state_list
-
-
-def get_state_at_time(state_sequence: MutableSequence[State],
-                      timestamp: datetime.datetime) -> Optional[State]:
-    """Returns a state instance from a sequence of states for a given timestamp.
-    Returns None if no data available."""
-    try:
-        return Track(state_sequence)[timestamp]
-    except IndexError:
-        return None
-
 
 def get_strictly_monotonously_increasing_intervals(arr: MutableSequence[int])\
         -> List[Tuple[int, int]]:

@kopytjuk
Copy link
Contributor Author

kopytjuk commented Jul 1, 2024

Hey @sdhiscocks, thanks for your suggestion - I added that - it makes the code definitely more clear and more performant as well!

I changed parts of the code and tried to hide implementation details within private methods. I think the code is ready to review now. Building up on the associations, I would like to implement the MOTA & MOTP metrics in my next PR :)

@kopytjuk kopytjuk requested a review from sdhiscocks July 1, 2024 18:18
@sdhiscocks
Copy link
Member

Thanks again for the contribution.

Just noted one minor change required to include new module in docs:

diff --git a/docs/source/stonesoup.dataassociator.rst b/docs/source/stonesoup.dataassociator.rst
index dd9d52be..6e2d0034 100644
--- a/docs/source/stonesoup.dataassociator.rst
+++ b/docs/source/stonesoup.dataassociator.rst
@@ -31,6 +31,12 @@ Track-to-track Association
 .. automodule:: stonesoup.dataassociator.tracktotrack
     :show-inheritance:
 
+CLEAR MOT Association
+---------------------
+
+.. automodule:: stonesoup.dataassociator.clearmot
+    :show-inheritance:
+
 Trees
 --------------------------
 

@sdhiscocks sdhiscocks merged commit e4e50ad into dstl:main Jul 18, 2024
10 checks passed
@kopytjuk kopytjuk deleted the clear-mot-associator branch July 20, 2024 07:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants