From 5358490f3a0fcec8f96a7d15523f9ba3aaccc203 Mon Sep 17 00:00:00 2001 From: gawebb <gawebb@dstl.gov.uk> Date: Thu, 18 Mar 2021 12:33:52 +0000 Subject: [PATCH] Changed all the trackers into iterable trackers This removes the tracks_gen() method on Tracker classes, but instances are iterable the same as BufferedGenerator version. --- docs/demos/AIS_Solent_Tracker.py | 2 +- docs/demos/OpenSky_Demo.py | 2 +- docs/examples/Metrics.py | 2 +- docs/examples/Sensor_Platform_Simulation.py | 2 +- stonesoup/reader/yaml.py | 33 +-- stonesoup/tracker/base.py | 19 +- stonesoup/tracker/pointprocess.py | 46 +++-- stonesoup/tracker/simple.py | 216 +++++++++++--------- stonesoup/writer/tests/conftest.py | 33 +-- 9 files changed, 196 insertions(+), 159 deletions(-) diff --git a/docs/demos/AIS_Solent_Tracker.py b/docs/demos/AIS_Solent_Tracker.py index fc1956ecf..523dc6a15 100644 --- a/docs/demos/AIS_Solent_Tracker.py +++ b/docs/demos/AIS_Solent_Tracker.py @@ -161,7 +161,7 @@ # :class:`set` we can simply update this with `current_tracks` at each timestep, not worrying about # duplicates. tracks = set() -for step, (time, current_tracks) in enumerate(tracker.tracks_gen(), 1): +for step, (time, current_tracks) in enumerate(tracker, 1): tracks.update(current_tracks) if not step % 10: print("Step: {} Time: {}".format(step, time)) diff --git a/docs/demos/OpenSky_Demo.py b/docs/demos/OpenSky_Demo.py index 000869ce9..e5ee9f42d 100644 --- a/docs/demos/OpenSky_Demo.py +++ b/docs/demos/OpenSky_Demo.py @@ -238,7 +238,7 @@ ) tracks = set() -for step, (time, current_tracks) in enumerate(kalman_tracker.tracks_gen(), 1): +for step, (time, current_tracks) in enumerate(kalman_tracker, 1): tracks.update(current_tracks) # %% diff --git a/docs/examples/Metrics.py b/docs/examples/Metrics.py index 1c407144d..c8787843a 100644 --- a/docs/examples/Metrics.py +++ b/docs/examples/Metrics.py @@ -146,7 +146,7 @@ # With this basic tracker built and metrics ready, we'll now run the tracker, adding the sets of # :class:`~.GroundTruthPath`, :class:`~.Detection` and :class:`~.Track` objects: to the metric # manager. -for time, tracks in tracker.tracks_gen(): +for time, tracks in tracker: metric_manager.add_data( groundtruth_sim.groundtruth_paths, tracks, detection_sim.detections, overwrite=False, # Don't overwrite, instead add above as additional data diff --git a/docs/examples/Sensor_Platform_Simulation.py b/docs/examples/Sensor_Platform_Simulation.py index 3db3bb983..b7ee341b8 100644 --- a/docs/examples/Sensor_Platform_Simulation.py +++ b/docs/examples/Sensor_Platform_Simulation.py @@ -346,7 +346,7 @@ def initiate(self, detections, timestamp, **kwargs): groundtruth_paths = {} # Store for plotting later detections = [] # Store for plotting later -for time, ctracks in kalman_tracker.tracks_gen(): +for time, ctracks in kalman_tracker: for track in ctracks: loc = (track.state_vector[0], track.state_vector[2]) if track not in kalman_tracks: diff --git a/stonesoup/reader/yaml.py b/stonesoup/reader/yaml.py index 8574b8d3b..4f9fb18c7 100644 --- a/stonesoup/reader/yaml.py +++ b/stonesoup/reader/yaml.py @@ -74,16 +74,23 @@ class YAMLTrackReader(YAMLReader, Tracker): def data_gen(self): yield from super().data_gen() - @BufferedGenerator.generator_method - def tracks_gen(self): - tracks = dict() - for time, document in self.data_gen(): - updated_tracks = set() - for track in document.get('tracks', set()): - if track.id in tracks: - tracks[track.id].states = track.states - else: - tracks[track.id] = track - updated_tracks.add(tracks[track.id]) - - yield time, updated_tracks + def __iter__(self): + self.data_iter = iter(self.data_gen()) + self._tracks = dict() + return super().__iter__() + + @property + def tracks(self): + return self._tracks + + def __next__(self): + time, document = next(self.data_iter) + updated_tracks = set() + for track in document.get('tracks', set()): + if track.id in self.tracks: + self._tracks[track.id].states = track.states + else: + self._tracks[track.id] = track + updated_tracks.add(self.tracks[track.id]) + + return time, updated_tracks diff --git a/stonesoup/tracker/base.py b/stonesoup/tracker/base.py index 489b2177b..15bd59295 100644 --- a/stonesoup/tracker/base.py +++ b/stonesoup/tracker/base.py @@ -2,23 +2,24 @@ from abc import abstractmethod from ..base import Base -from ..buffered_generator import BufferedGenerator -class Tracker(Base, BufferedGenerator): +class Tracker(Base): """Tracker base class""" @property + @abstractmethod def tracks(self): - return self.current[1] + raise NotImplementedError - @abstractmethod - @BufferedGenerator.generator_method - def tracks_gen(self): - """Returns a generator of tracks for each time step. + def __iter__(self): + return self - Yields - ------ + @abstractmethod + def __next__(self): + """ + Returns + ------- : :class:`datetime.datetime` Datetime of current time step : set of :class:`~.Track` diff --git a/stonesoup/tracker/pointprocess.py b/stonesoup/tracker/pointprocess.py index 7efd56055..a35fc884e 100644 --- a/stonesoup/tracker/pointprocess.py +++ b/stonesoup/tracker/pointprocess.py @@ -9,7 +9,6 @@ from ..updater import Updater from ..hypothesiser.gaussianmixture import GaussianMixtureHypothesiser from ..mixturereducer.gaussianmixture import GaussianMixtureReducer -from ..buffered_generator import BufferedGenerator class PointProcessMultiTargetTracker(Tracker): @@ -46,6 +45,10 @@ def tracks(self): tracks.add(track) return tracks + def __iter__(self): + self.detector_iter = iter(self.detector) + return super().__iter__() + def update_tracks(self): """ Updates the tracks (:class:`Track`) associated with the filter. @@ -74,27 +77,26 @@ def update_tracks(self): self.extraction_threshold: self.target_tracks[tag] = Track([component], id=tag) - @BufferedGenerator.generator_method - def tracks_gen(self): - for time, detections in self.detector: - # Add birth component - self.birth_component.timestamp = time - self.gaussian_mixture.append(self.birth_component) - # Perform GM Prediction and generate hypotheses - hypotheses = self.hypothesiser.hypothesise( - self.gaussian_mixture.components, - detections, - time - ) - # Perform GM Update - self.gaussian_mixture = self.updater.update(hypotheses) - # Reduce mixture - Pruning and Merging - self.gaussian_mixture.components = \ - self.reducer.reduce(self.gaussian_mixture.components) - # Update the tracks - self.update_tracks() - self.end_tracks() - yield time, self.tracks + def __next__(self): + time, detections = next(self.detector_iter) + # Add birth component + self.birth_component.timestamp = time + self.gaussian_mixture.append(self.birth_component) + # Perform GM Prediction and generate hypotheses + hypotheses = self.hypothesiser.hypothesise( + self.gaussian_mixture.components, + detections, + time + ) + # Perform GM Update + self.gaussian_mixture = self.updater.update(hypotheses) + # Reduce mixture - Pruning and Merging + self.gaussian_mixture.components = \ + self.reducer.reduce(self.gaussian_mixture.components) + # Update the tracks + self.update_tracks() + self.end_tracks() + return time, self.tracks def end_tracks(self): """ diff --git a/stonesoup/tracker/simple.py b/stonesoup/tracker/simple.py index 2a6d8de3d..cb03a136c 100644 --- a/stonesoup/tracker/simple.py +++ b/stonesoup/tracker/simple.py @@ -12,7 +12,6 @@ from ..types.prediction import GaussianStatePrediction from ..types.update import GaussianStateUpdate from ..functions import gm_reduce_single -from stonesoup.buffered_generator import BufferedGenerator class SingleTargetTracker(Tracker): @@ -45,29 +44,36 @@ class SingleTargetTracker(Tracker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._track = None - @BufferedGenerator.generator_method - def tracks_gen(self): - track = None - for time, detections in self.detector: - if track is not None: - associations = self.data_associator.associate( - {track}, detections, time) - if associations[track]: - state_post = self.updater.update(associations[track]) - track.append(state_post) - else: - track.append( - associations[track].prediction) + @property + def tracks(self): + return {self._track} if self._track else set() - if track is None or self.deleter.delete_tracks({track}): - new_tracks = self.initiator.initiate(detections, time) - if new_tracks: - track = new_tracks.pop() - else: - track = None + def __iter__(self): + self.detector_iter = iter(self.detector) + return super().__iter__() + + def __next__(self): + time, detections = next(self.detector_iter) + if self._track is not None: + associations = self.data_associator.associate( + self.tracks, detections, time) + if associations[self._track]: + state_post = self.updater.update(associations[self._track]) + self._track.append(state_post) + else: + self._track.append( + associations[self._track].prediction) + + if self._track is None or self.deleter.delete_tracks(self.tracks): + new_tracks = self.initiator.initiate(detections, time) + if new_tracks: + self._track = new_tracks.pop() + else: + self._track = None - yield (time, {track}) if track is not None else (time, set()) + return time, self.tracks class MultiTargetTracker(Tracker): @@ -93,28 +99,35 @@ class MultiTargetTracker(Tracker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._tracks = set() - @BufferedGenerator.generator_method - def tracks_gen(self): - tracks = set() + @property + def tracks(self): + return self._tracks - for time, detections in self.detector: + def __iter__(self): + self.detector_iter = iter(self.detector) + return super().__iter__() - associations = self.data_associator.associate( - tracks, detections, time) - associated_detections = set() - for track, hypothesis in associations.items(): - if hypothesis: - state_post = self.updater.update(hypothesis) - track.append(state_post) - associated_detections.add(hypothesis.measurement) - else: - track.append(hypothesis.prediction) + def __next__(self): + time, detections = next(self.detector_iter) + + associations = self.data_associator.associate( + self.tracks, detections, time) + associated_detections = set() + for track, hypothesis in associations.items(): + if hypothesis: + state_post = self.updater.update(hypothesis) + track.append(state_post) + associated_detections.add(hypothesis.measurement) + else: + track.append(hypothesis.prediction) - tracks -= self.deleter.delete_tracks(tracks) - tracks |= self.initiator.initiate(detections - associated_detections, time) + self._tracks -= self.deleter.delete_tracks(self.tracks) + self._tracks |= self.initiator.initiate( + detections - associated_detections, time) - yield time, tracks + return time, self.tracks class MultiTargetMixtureTracker(Tracker): @@ -142,64 +155,71 @@ class MultiTargetMixtureTracker(Tracker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - @BufferedGenerator.generator_method - def tracks_gen(self): - tracks = set() - - for time, detections in self.detector: - - associations = self.data_associator.associate( - tracks, detections, time) - unassociated_detections = set(detections) - for track, multihypothesis in associations.items(): - - # calculate each Track's state as a Gaussian Mixture of - # its possible associations with each detection, then - # reduce the Mixture to a single Gaussian State - posterior_states = [] - posterior_state_weights = [] - for hypothesis in multihypothesis: - if not hypothesis: - posterior_states.append(hypothesis.prediction) - else: - posterior_states.append( - self.updater.update(hypothesis)) - posterior_state_weights.append( - hypothesis.probability) - - means = StateVectors([state.state_vector for state in posterior_states]) - covars = np.stack([state.covar for state in posterior_states], axis=2) - weights = np.asarray(posterior_state_weights) - - post_mean, post_covar = gm_reduce_single(means, covars, weights) - - missed_detection_weight = next(hyp.weight for hyp in multihypothesis if not hyp) - - # Check if at least one reasonable measurement... - if any(hypothesis.weight > missed_detection_weight - for hypothesis in multihypothesis): - # ...and if so use update type - track.append(GaussianStateUpdate( - post_mean, post_covar, - multihypothesis, - multihypothesis[0].measurement.timestamp)) + self._tracks = set() + + @property + def tracks(self): + return self._tracks + + def __iter__(self): + self.detector_iter = iter(self.detector) + return super().__iter__() + + def __next__(self): + time, detections = next(self.detector_iter) + + associations = self.data_associator.associate( + self.tracks, detections, time) + unassociated_detections = set(detections) + for track, multihypothesis in associations.items(): + + # calculate each Track's state as a Gaussian Mixture of + # its possible associations with each detection, then + # reduce the Mixture to a single Gaussian State + posterior_states = [] + posterior_state_weights = [] + for hypothesis in multihypothesis: + if not hypothesis: + posterior_states.append(hypothesis.prediction) else: - # ...and if not, treat as a prediction - track.append(GaussianStatePrediction( - post_mean, post_covar, - multihypothesis[0].prediction.timestamp)) - - # any detections in multihypothesis that had an - # association score (weight) lower than or equal to the - # association score of "MissedDetection" is considered - # unassociated - candidate for initiating a new Track - for hyp in multihypothesis: - if hyp.weight > missed_detection_weight: - if hyp.measurement in unassociated_detections: - unassociated_detections.remove(hyp.measurement) - - tracks -= self.deleter.delete_tracks(tracks) - tracks |= self.initiator.initiate(unassociated_detections, time) - - yield time, tracks + posterior_states.append( + self.updater.update(hypothesis)) + posterior_state_weights.append( + hypothesis.probability) + + means = StateVectors([state.state_vector for state in posterior_states]) + covars = np.stack([state.covar for state in posterior_states], axis=2) + weights = np.asarray(posterior_state_weights) + + post_mean, post_covar = gm_reduce_single(means, covars, weights) + + missed_detection_weight = next(hyp.weight for hyp in multihypothesis if not hyp) + + # Check if at least one reasonable measurement... + if any(hypothesis.weight > missed_detection_weight + for hypothesis in multihypothesis): + # ...and if so use update type + track.append(GaussianStateUpdate( + post_mean, post_covar, + multihypothesis, + multihypothesis[0].measurement.timestamp)) + else: + # ...and if not, treat as a prediction + track.append(GaussianStatePrediction( + post_mean, post_covar, + multihypothesis[0].prediction.timestamp)) + + # any detections in multihypothesis that had an + # association score (weight) lower than or equal to the + # association score of "MissedDetection" is considered + # unassociated - candidate for initiating a new Track + for hyp in multihypothesis: + if hyp.weight > missed_detection_weight: + if hyp.measurement in unassociated_detections: + unassociated_detections.remove(hyp.measurement) + + self._tracks -= self.deleter.delete_tracks(self.tracks) + self._tracks |= self.initiator.initiate( + unassociated_detections, time) + + return time, self.tracks diff --git a/stonesoup/writer/tests/conftest.py b/stonesoup/writer/tests/conftest.py index dd9c99946..051282920 100644 --- a/stonesoup/writer/tests/conftest.py +++ b/stonesoup/writer/tests/conftest.py @@ -52,18 +52,25 @@ def groundtruth_paths_gen(self): @pytest.fixture() def tracker(): class TestTracker(Tracker): - @BufferedGenerator.generator_method - def tracks_gen(self): - time = datetime.datetime(2018, 1, 1, 14) + @property + def tracks(self): + return self._tracks + + def __iter__(self): + self.iter = iter(range(2)) + self.time = datetime.datetime(2018, 1, 1, 13, 59) + return super().__iter__() + + def __next__(self): + i = next(self.iter) state_vector = StateVector([[0]]) - for i in range(2): - tracks = { - Track( - [State( - state_vector + i + 10*j, timestamp=time) - for j in range(i)], - str(k)) - for k in range(i)} - yield time, tracks - time += datetime.timedelta(minutes=1) + self.time += datetime.timedelta(minutes=1) + self._tracks = { + Track( + [State( + state_vector + i + 10*j, timestamp=self.time) + for j in range(i)], + str(k)) + for k in range(i)} + return self.time, self.tracks return TestTracker()