From 5358490f3a0fcec8f96a7d15523f9ba3aaccc203 Mon Sep 17 00:00:00 2001
From: gawebb <gawebb@dstl.gov.uk>
Date: Thu, 18 Mar 2021 12:33:52 +0000
Subject: [PATCH] Changed all the trackers into iterable trackers

This removes the tracks_gen() method on Tracker classes, but instances
are iterable the same as BufferedGenerator version.
---
 docs/demos/AIS_Solent_Tracker.py            |   2 +-
 docs/demos/OpenSky_Demo.py                  |   2 +-
 docs/examples/Metrics.py                    |   2 +-
 docs/examples/Sensor_Platform_Simulation.py |   2 +-
 stonesoup/reader/yaml.py                    |  33 +--
 stonesoup/tracker/base.py                   |  19 +-
 stonesoup/tracker/pointprocess.py           |  46 +++--
 stonesoup/tracker/simple.py                 | 216 +++++++++++---------
 stonesoup/writer/tests/conftest.py          |  33 +--
 9 files changed, 196 insertions(+), 159 deletions(-)

diff --git a/docs/demos/AIS_Solent_Tracker.py b/docs/demos/AIS_Solent_Tracker.py
index fc1956ecf..523dc6a15 100644
--- a/docs/demos/AIS_Solent_Tracker.py
+++ b/docs/demos/AIS_Solent_Tracker.py
@@ -161,7 +161,7 @@
 # :class:`set` we can simply update this with `current_tracks` at each timestep, not worrying about
 # duplicates.
 tracks = set()
-for step, (time, current_tracks) in enumerate(tracker.tracks_gen(), 1):
+for step, (time, current_tracks) in enumerate(tracker, 1):
     tracks.update(current_tracks)
     if not step % 10:
         print("Step: {} Time: {}".format(step, time))
diff --git a/docs/demos/OpenSky_Demo.py b/docs/demos/OpenSky_Demo.py
index 000869ce9..e5ee9f42d 100644
--- a/docs/demos/OpenSky_Demo.py
+++ b/docs/demos/OpenSky_Demo.py
@@ -238,7 +238,7 @@
 )
 
 tracks = set()
-for step, (time, current_tracks) in enumerate(kalman_tracker.tracks_gen(), 1):
+for step, (time, current_tracks) in enumerate(kalman_tracker, 1):
     tracks.update(current_tracks)
 
 # %%
diff --git a/docs/examples/Metrics.py b/docs/examples/Metrics.py
index 1c407144d..c8787843a 100644
--- a/docs/examples/Metrics.py
+++ b/docs/examples/Metrics.py
@@ -146,7 +146,7 @@
 # With this basic tracker built and metrics ready, we'll now run the tracker, adding the sets of
 # :class:`~.GroundTruthPath`, :class:`~.Detection` and :class:`~.Track` objects: to the metric
 # manager.
-for time, tracks in tracker.tracks_gen():
+for time, tracks in tracker:
     metric_manager.add_data(
         groundtruth_sim.groundtruth_paths, tracks, detection_sim.detections,
         overwrite=False,  # Don't overwrite, instead add above as additional data
diff --git a/docs/examples/Sensor_Platform_Simulation.py b/docs/examples/Sensor_Platform_Simulation.py
index 3db3bb983..b7ee341b8 100644
--- a/docs/examples/Sensor_Platform_Simulation.py
+++ b/docs/examples/Sensor_Platform_Simulation.py
@@ -346,7 +346,7 @@ def initiate(self, detections, timestamp, **kwargs):
 groundtruth_paths = {}  # Store for plotting later
 detections = []  # Store for plotting later
 
-for time, ctracks in kalman_tracker.tracks_gen():
+for time, ctracks in kalman_tracker:
     for track in ctracks:
         loc = (track.state_vector[0], track.state_vector[2])
         if track not in kalman_tracks:
diff --git a/stonesoup/reader/yaml.py b/stonesoup/reader/yaml.py
index 8574b8d3b..4f9fb18c7 100644
--- a/stonesoup/reader/yaml.py
+++ b/stonesoup/reader/yaml.py
@@ -74,16 +74,23 @@ class YAMLTrackReader(YAMLReader, Tracker):
     def data_gen(self):
         yield from super().data_gen()
 
-    @BufferedGenerator.generator_method
-    def tracks_gen(self):
-        tracks = dict()
-        for time, document in self.data_gen():
-            updated_tracks = set()
-            for track in document.get('tracks', set()):
-                if track.id in tracks:
-                    tracks[track.id].states = track.states
-                else:
-                    tracks[track.id] = track
-                updated_tracks.add(tracks[track.id])
-
-            yield time, updated_tracks
+    def __iter__(self):
+        self.data_iter = iter(self.data_gen())
+        self._tracks = dict()
+        return super().__iter__()
+
+    @property
+    def tracks(self):
+        return self._tracks
+
+    def __next__(self):
+        time, document = next(self.data_iter)
+        updated_tracks = set()
+        for track in document.get('tracks', set()):
+            if track.id in self.tracks:
+                self._tracks[track.id].states = track.states
+            else:
+                self._tracks[track.id] = track
+            updated_tracks.add(self.tracks[track.id])
+
+        return time, updated_tracks
diff --git a/stonesoup/tracker/base.py b/stonesoup/tracker/base.py
index 489b2177b..15bd59295 100644
--- a/stonesoup/tracker/base.py
+++ b/stonesoup/tracker/base.py
@@ -2,23 +2,24 @@
 from abc import abstractmethod
 
 from ..base import Base
-from ..buffered_generator import BufferedGenerator
 
 
-class Tracker(Base, BufferedGenerator):
+class Tracker(Base):
     """Tracker base class"""
 
     @property
+    @abstractmethod
     def tracks(self):
-        return self.current[1]
+        raise NotImplementedError
 
-    @abstractmethod
-    @BufferedGenerator.generator_method
-    def tracks_gen(self):
-        """Returns a generator of tracks for each time step.
+    def __iter__(self):
+        return self
 
-        Yields
-        ------
+    @abstractmethod
+    def __next__(self):
+        """
+        Returns
+        -------
         : :class:`datetime.datetime`
             Datetime of current time step
         : set of :class:`~.Track`
diff --git a/stonesoup/tracker/pointprocess.py b/stonesoup/tracker/pointprocess.py
index 7efd56055..a35fc884e 100644
--- a/stonesoup/tracker/pointprocess.py
+++ b/stonesoup/tracker/pointprocess.py
@@ -9,7 +9,6 @@
 from ..updater import Updater
 from ..hypothesiser.gaussianmixture import GaussianMixtureHypothesiser
 from ..mixturereducer.gaussianmixture import GaussianMixtureReducer
-from ..buffered_generator import BufferedGenerator
 
 
 class PointProcessMultiTargetTracker(Tracker):
@@ -46,6 +45,10 @@ def tracks(self):
             tracks.add(track)
         return tracks
 
+    def __iter__(self):
+        self.detector_iter = iter(self.detector)
+        return super().__iter__()
+
     def update_tracks(self):
         """
         Updates the tracks (:class:`Track`) associated with the filter.
@@ -74,27 +77,26 @@ def update_tracks(self):
                             self.extraction_threshold:
                         self.target_tracks[tag] = Track([component], id=tag)
 
-    @BufferedGenerator.generator_method
-    def tracks_gen(self):
-        for time, detections in self.detector:
-            # Add birth component
-            self.birth_component.timestamp = time
-            self.gaussian_mixture.append(self.birth_component)
-            # Perform GM Prediction and generate hypotheses
-            hypotheses = self.hypothesiser.hypothesise(
-                        self.gaussian_mixture.components,
-                        detections,
-                        time
-                        )
-            # Perform GM Update
-            self.gaussian_mixture = self.updater.update(hypotheses)
-            # Reduce mixture - Pruning and Merging
-            self.gaussian_mixture.components = \
-                self.reducer.reduce(self.gaussian_mixture.components)
-            # Update the tracks
-            self.update_tracks()
-            self.end_tracks()
-            yield time, self.tracks
+    def __next__(self):
+        time, detections = next(self.detector_iter)
+        # Add birth component
+        self.birth_component.timestamp = time
+        self.gaussian_mixture.append(self.birth_component)
+        # Perform GM Prediction and generate hypotheses
+        hypotheses = self.hypothesiser.hypothesise(
+                    self.gaussian_mixture.components,
+                    detections,
+                    time
+                    )
+        # Perform GM Update
+        self.gaussian_mixture = self.updater.update(hypotheses)
+        # Reduce mixture - Pruning and Merging
+        self.gaussian_mixture.components = \
+            self.reducer.reduce(self.gaussian_mixture.components)
+        # Update the tracks
+        self.update_tracks()
+        self.end_tracks()
+        return time, self.tracks
 
     def end_tracks(self):
         """
diff --git a/stonesoup/tracker/simple.py b/stonesoup/tracker/simple.py
index 2a6d8de3d..cb03a136c 100644
--- a/stonesoup/tracker/simple.py
+++ b/stonesoup/tracker/simple.py
@@ -12,7 +12,6 @@
 from ..types.prediction import GaussianStatePrediction
 from ..types.update import GaussianStateUpdate
 from ..functions import gm_reduce_single
-from stonesoup.buffered_generator import BufferedGenerator
 
 
 class SingleTargetTracker(Tracker):
@@ -45,29 +44,36 @@ class SingleTargetTracker(Tracker):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._track = None
 
-    @BufferedGenerator.generator_method
-    def tracks_gen(self):
-        track = None
-        for time, detections in self.detector:
-            if track is not None:
-                associations = self.data_associator.associate(
-                    {track}, detections, time)
-                if associations[track]:
-                    state_post = self.updater.update(associations[track])
-                    track.append(state_post)
-                else:
-                    track.append(
-                        associations[track].prediction)
+    @property
+    def tracks(self):
+        return {self._track} if self._track else set()
 
-            if track is None or self.deleter.delete_tracks({track}):
-                new_tracks = self.initiator.initiate(detections, time)
-                if new_tracks:
-                    track = new_tracks.pop()
-                else:
-                    track = None
+    def __iter__(self):
+        self.detector_iter = iter(self.detector)
+        return super().__iter__()
+
+    def __next__(self):
+        time, detections = next(self.detector_iter)
+        if self._track is not None:
+            associations = self.data_associator.associate(
+                self.tracks, detections, time)
+            if associations[self._track]:
+                state_post = self.updater.update(associations[self._track])
+                self._track.append(state_post)
+            else:
+                self._track.append(
+                    associations[self._track].prediction)
+
+        if self._track is None or self.deleter.delete_tracks(self.tracks):
+            new_tracks = self.initiator.initiate(detections, time)
+            if new_tracks:
+                self._track = new_tracks.pop()
+            else:
+                self._track = None
 
-            yield (time, {track}) if track is not None else (time, set())
+        return time, self.tracks
 
 
 class MultiTargetTracker(Tracker):
@@ -93,28 +99,35 @@ class MultiTargetTracker(Tracker):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._tracks = set()
 
-    @BufferedGenerator.generator_method
-    def tracks_gen(self):
-        tracks = set()
+    @property
+    def tracks(self):
+        return self._tracks
 
-        for time, detections in self.detector:
+    def __iter__(self):
+        self.detector_iter = iter(self.detector)
+        return super().__iter__()
 
-            associations = self.data_associator.associate(
-                tracks, detections, time)
-            associated_detections = set()
-            for track, hypothesis in associations.items():
-                if hypothesis:
-                    state_post = self.updater.update(hypothesis)
-                    track.append(state_post)
-                    associated_detections.add(hypothesis.measurement)
-                else:
-                    track.append(hypothesis.prediction)
+    def __next__(self):
+        time, detections = next(self.detector_iter)
+
+        associations = self.data_associator.associate(
+            self.tracks, detections, time)
+        associated_detections = set()
+        for track, hypothesis in associations.items():
+            if hypothesis:
+                state_post = self.updater.update(hypothesis)
+                track.append(state_post)
+                associated_detections.add(hypothesis.measurement)
+            else:
+                track.append(hypothesis.prediction)
 
-            tracks -= self.deleter.delete_tracks(tracks)
-            tracks |= self.initiator.initiate(detections - associated_detections, time)
+        self._tracks -= self.deleter.delete_tracks(self.tracks)
+        self._tracks |= self.initiator.initiate(
+            detections - associated_detections, time)
 
-            yield time, tracks
+        return time, self.tracks
 
 
 class MultiTargetMixtureTracker(Tracker):
@@ -142,64 +155,71 @@ class MultiTargetMixtureTracker(Tracker):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
-    @BufferedGenerator.generator_method
-    def tracks_gen(self):
-        tracks = set()
-
-        for time, detections in self.detector:
-
-            associations = self.data_associator.associate(
-                tracks, detections, time)
-            unassociated_detections = set(detections)
-            for track, multihypothesis in associations.items():
-
-                # calculate each Track's state as a Gaussian Mixture of
-                # its possible associations with each detection, then
-                # reduce the Mixture to a single Gaussian State
-                posterior_states = []
-                posterior_state_weights = []
-                for hypothesis in multihypothesis:
-                    if not hypothesis:
-                        posterior_states.append(hypothesis.prediction)
-                    else:
-                        posterior_states.append(
-                            self.updater.update(hypothesis))
-                    posterior_state_weights.append(
-                        hypothesis.probability)
-
-                means = StateVectors([state.state_vector for state in posterior_states])
-                covars = np.stack([state.covar for state in posterior_states], axis=2)
-                weights = np.asarray(posterior_state_weights)
-
-                post_mean, post_covar = gm_reduce_single(means, covars, weights)
-
-                missed_detection_weight = next(hyp.weight for hyp in multihypothesis if not hyp)
-
-                # Check if at least one reasonable measurement...
-                if any(hypothesis.weight > missed_detection_weight
-                       for hypothesis in multihypothesis):
-                    # ...and if so use update type
-                    track.append(GaussianStateUpdate(
-                        post_mean, post_covar,
-                        multihypothesis,
-                        multihypothesis[0].measurement.timestamp))
+        self._tracks = set()
+
+    @property
+    def tracks(self):
+        return self._tracks
+
+    def __iter__(self):
+        self.detector_iter = iter(self.detector)
+        return super().__iter__()
+
+    def __next__(self):
+        time, detections = next(self.detector_iter)
+
+        associations = self.data_associator.associate(
+            self.tracks, detections, time)
+        unassociated_detections = set(detections)
+        for track, multihypothesis in associations.items():
+
+            # calculate each Track's state as a Gaussian Mixture of
+            # its possible associations with each detection, then
+            # reduce the Mixture to a single Gaussian State
+            posterior_states = []
+            posterior_state_weights = []
+            for hypothesis in multihypothesis:
+                if not hypothesis:
+                    posterior_states.append(hypothesis.prediction)
                 else:
-                    # ...and if not, treat as a prediction
-                    track.append(GaussianStatePrediction(
-                        post_mean, post_covar,
-                        multihypothesis[0].prediction.timestamp))
-
-                # any detections in multihypothesis that had an
-                # association score (weight) lower than or equal to the
-                # association score of "MissedDetection" is considered
-                # unassociated - candidate for initiating a new Track
-                for hyp in multihypothesis:
-                    if hyp.weight > missed_detection_weight:
-                        if hyp.measurement in unassociated_detections:
-                            unassociated_detections.remove(hyp.measurement)
-
-            tracks -= self.deleter.delete_tracks(tracks)
-            tracks |= self.initiator.initiate(unassociated_detections, time)
-
-            yield time, tracks
+                    posterior_states.append(
+                        self.updater.update(hypothesis))
+                posterior_state_weights.append(
+                    hypothesis.probability)
+
+            means = StateVectors([state.state_vector for state in posterior_states])
+            covars = np.stack([state.covar for state in posterior_states], axis=2)
+            weights = np.asarray(posterior_state_weights)
+
+            post_mean, post_covar = gm_reduce_single(means, covars, weights)
+
+            missed_detection_weight = next(hyp.weight for hyp in multihypothesis if not hyp)
+
+            # Check if at least one reasonable measurement...
+            if any(hypothesis.weight > missed_detection_weight
+                   for hypothesis in multihypothesis):
+                # ...and if so use update type
+                track.append(GaussianStateUpdate(
+                    post_mean, post_covar,
+                    multihypothesis,
+                    multihypothesis[0].measurement.timestamp))
+            else:
+                # ...and if not, treat as a prediction
+                track.append(GaussianStatePrediction(
+                    post_mean, post_covar,
+                    multihypothesis[0].prediction.timestamp))
+
+            # any detections in multihypothesis that had an
+            # association score (weight) lower than or equal to the
+            # association score of "MissedDetection" is considered
+            # unassociated - candidate for initiating a new Track
+            for hyp in multihypothesis:
+                if hyp.weight > missed_detection_weight:
+                    if hyp.measurement in unassociated_detections:
+                        unassociated_detections.remove(hyp.measurement)
+
+        self._tracks -= self.deleter.delete_tracks(self.tracks)
+        self._tracks |= self.initiator.initiate(
+            unassociated_detections, time)
+
+        return time, self.tracks
diff --git a/stonesoup/writer/tests/conftest.py b/stonesoup/writer/tests/conftest.py
index dd9c99946..051282920 100644
--- a/stonesoup/writer/tests/conftest.py
+++ b/stonesoup/writer/tests/conftest.py
@@ -52,18 +52,25 @@ def groundtruth_paths_gen(self):
 @pytest.fixture()
 def tracker():
     class TestTracker(Tracker):
-        @BufferedGenerator.generator_method
-        def tracks_gen(self):
-            time = datetime.datetime(2018, 1, 1, 14)
+        @property
+        def tracks(self):
+            return self._tracks
+
+        def __iter__(self):
+            self.iter = iter(range(2))
+            self.time = datetime.datetime(2018, 1, 1, 13, 59)
+            return super().__iter__()
+
+        def __next__(self):
+            i = next(self.iter)
             state_vector = StateVector([[0]])
-            for i in range(2):
-                tracks = {
-                    Track(
-                        [State(
-                            state_vector + i + 10*j, timestamp=time)
-                            for j in range(i)],
-                        str(k))
-                    for k in range(i)}
-                yield time, tracks
-                time += datetime.timedelta(minutes=1)
+            self.time += datetime.timedelta(minutes=1)
+            self._tracks = {
+                Track(
+                    [State(
+                        state_vector + i + 10*j, timestamp=self.time)
+                        for j in range(i)],
+                    str(k))
+                for k in range(i)}
+            return self.time, self.tracks
     return TestTracker()