Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trackers are Iterables instead of Generators #415

Merged
merged 1 commit into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/demos/AIS_Solent_Tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
# :class:`set` we can simply update this with `current_tracks` at each timestep, not worrying about
# duplicates.
tracks = set()
for step, (time, current_tracks) in enumerate(tracker.tracks_gen(), 1):
for step, (time, current_tracks) in enumerate(tracker, 1):
tracks.update(current_tracks)
if not step % 10:
print("Step: {} Time: {}".format(step, time))
Expand Down
2 changes: 1 addition & 1 deletion docs/demos/OpenSky_Demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@
)

tracks = set()
for step, (time, current_tracks) in enumerate(kalman_tracker.tracks_gen(), 1):
for step, (time, current_tracks) in enumerate(kalman_tracker, 1):
tracks.update(current_tracks)

# %%
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
# With this basic tracker built and metrics ready, we'll now run the tracker, adding the sets of
# :class:`~.GroundTruthPath`, :class:`~.Detection` and :class:`~.Track` objects: to the metric
# manager.
for time, tracks in tracker.tracks_gen():
for time, tracks in tracker:
metric_manager.add_data(
groundtruth_sim.groundtruth_paths, tracks, detection_sim.detections,
overwrite=False, # Don't overwrite, instead add above as additional data
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/Sensor_Platform_Simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def initiate(self, detections, timestamp, **kwargs):
groundtruth_paths = {} # Store for plotting later
detections = [] # Store for plotting later

for time, ctracks in kalman_tracker.tracks_gen():
for time, ctracks in kalman_tracker:
for track in ctracks:
loc = (track.state_vector[0], track.state_vector[2])
if track not in kalman_tracks:
Expand Down
33 changes: 20 additions & 13 deletions stonesoup/reader/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,23 @@ class YAMLTrackReader(YAMLReader, Tracker):
def data_gen(self):
yield from super().data_gen()

@BufferedGenerator.generator_method
def tracks_gen(self):
tracks = dict()
for time, document in self.data_gen():
updated_tracks = set()
for track in document.get('tracks', set()):
if track.id in tracks:
tracks[track.id].states = track.states
else:
tracks[track.id] = track
updated_tracks.add(tracks[track.id])

yield time, updated_tracks
def __iter__(self):
self.data_iter = iter(self.data_gen())
self._tracks = dict()
return super().__iter__()

@property
def tracks(self):
return self._tracks

def __next__(self):
time, document = next(self.data_iter)
updated_tracks = set()
for track in document.get('tracks', set()):
if track.id in self.tracks:
self._tracks[track.id].states = track.states
else:
self._tracks[track.id] = track
updated_tracks.add(self.tracks[track.id])

return time, updated_tracks
19 changes: 10 additions & 9 deletions stonesoup/tracker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@
from abc import abstractmethod

from ..base import Base
from ..buffered_generator import BufferedGenerator


class Tracker(Base, BufferedGenerator):
class Tracker(Base):
"""Tracker base class"""
sdhiscocks marked this conversation as resolved.
Show resolved Hide resolved

@property
@abstractmethod
def tracks(self):
return self.current[1]
raise NotImplementedError

@abstractmethod
@BufferedGenerator.generator_method
def tracks_gen(self):
"""Returns a generator of tracks for each time step.
def __iter__(self):
return self

Yields
------
@abstractmethod
def __next__(self):
"""
Returns
-------
: :class:`datetime.datetime`
Datetime of current time step
: set of :class:`~.Track`
Expand Down
46 changes: 24 additions & 22 deletions stonesoup/tracker/pointprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ..updater import Updater
from ..hypothesiser.gaussianmixture import GaussianMixtureHypothesiser
from ..mixturereducer.gaussianmixture import GaussianMixtureReducer
from ..buffered_generator import BufferedGenerator


class PointProcessMultiTargetTracker(Tracker):
Expand Down Expand Up @@ -46,6 +45,10 @@ def tracks(self):
tracks.add(track)
return tracks

def __iter__(self):
self.detector_iter = iter(self.detector)
return super().__iter__()

def update_tracks(self):
"""
Updates the tracks (:class:`Track`) associated with the filter.
Expand Down Expand Up @@ -74,27 +77,26 @@ def update_tracks(self):
self.extraction_threshold:
self.target_tracks[tag] = Track([component], id=tag)

@BufferedGenerator.generator_method
def tracks_gen(self):
for time, detections in self.detector:
# Add birth component
self.birth_component.timestamp = time
self.gaussian_mixture.append(self.birth_component)
# Perform GM Prediction and generate hypotheses
hypotheses = self.hypothesiser.hypothesise(
self.gaussian_mixture.components,
detections,
time
)
# Perform GM Update
self.gaussian_mixture = self.updater.update(hypotheses)
# Reduce mixture - Pruning and Merging
self.gaussian_mixture.components = \
self.reducer.reduce(self.gaussian_mixture.components)
# Update the tracks
self.update_tracks()
self.end_tracks()
yield time, self.tracks
def __next__(self):
time, detections = next(self.detector_iter)
# Add birth component
self.birth_component.timestamp = time
self.gaussian_mixture.append(self.birth_component)
# Perform GM Prediction and generate hypotheses
hypotheses = self.hypothesiser.hypothesise(
self.gaussian_mixture.components,
detections,
time
)
# Perform GM Update
self.gaussian_mixture = self.updater.update(hypotheses)
# Reduce mixture - Pruning and Merging
self.gaussian_mixture.components = \
self.reducer.reduce(self.gaussian_mixture.components)
# Update the tracks
self.update_tracks()
self.end_tracks()
return time, self.tracks

def end_tracks(self):
"""
Expand Down
Loading