Skip to content

Commit

Permalink
Merge pull request #882 from dstl/track2track_example_mods
Browse files Browse the repository at this point in the history
Modify track fusion example to use tee and plot longest track
  • Loading branch information
sdhiscocks authored Nov 7, 2023
2 parents 93d3b5e + e11e53c commit e6194b2
Showing 1 changed file with 56 additions and 89 deletions.
145 changes: 56 additions & 89 deletions docs/examples/Track2Track_Fusion_Example.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@
# * Create a GM-PHD tracker that will perform track fusion via covariance intersection using
# the :class:`ChernoffUpdater` class.
# * Create a metric manager to generate metrics for each of the four trackers
# * Set up the detection feeders. Each tracker will receive measurements using a custom
# :class:`DummyDetector` class. The track fusion tracker will also use the
# * Set up the detection feeders. The track fusion tracker will also use the
# :class:`Tracks2GaussianDetectionFeeder` class.
# * Run the simulation
# * Plot the resulting tracks and the metrics over time
Expand Down Expand Up @@ -205,13 +204,30 @@
# Now we can pass the platforms into a detection simulator. At each timestep, the simulator will
# return the detections from the `sensor1_platform`, then the detections from the
# `sensor2_platform`.
#
# As we'll be using the same simulation and detectors in multiple detectors, trackers and for
# plotting :func:`itertools.tee()` is used to create independent iterators to use in different
# components.

# %%
from itertools import tee
from stonesoup.feeder.multi import MultiDataFeeder
from stonesoup.simulator.platform import PlatformDetectionSimulator
radar_simulator = PlatformDetectionSimulator(
groundtruth=gt_simulator,
platforms=[sensor1_platform, sensor2_platform]

gt_sims = tee(gt_simulator, 2)
radar1_simulator = PlatformDetectionSimulator(
groundtruth=gt_sims[0],
platforms=[sensor1_platform]
)
radar1_plotting, radar1_simulator, s1_detector = tee(radar1_simulator, 3)

radar2_simulator = PlatformDetectionSimulator(
groundtruth=gt_sims[1],
platforms=[sensor2_platform]
)
radar2_plotting, radar2_simulator, s2_detector = tee(radar2_simulator, 3)

s1s2_detector = MultiDataFeeder([radar1_simulator, radar2_simulator])

# %%
# Let's briefly visualize the truths and measurements before we move on. Note that the final
Expand All @@ -224,20 +240,17 @@
from stonesoup.plotter import Plotter, Dimension

# Lists to hold the detections from each sensor and the path of the airborne radar
s1_detections = []
s2_detections = []
s1_detections = set()
s2_detections = set()
radar1_path = []

# Extract the generator function from a copy of the simulator
sim = deepcopy(radar_simulator)
g = sim.detections_gen()
truths = set()

# Iterate over the time steps, extracting the detections, truths, and airborne sensor path
for _ in range(num_steps):
s1_detections.append(next(g)[1])
s2_detections.append(next(g)[1])
radar1_path.append(sim.platforms[0].position)
truths = set(sim.groundtruth.groundtruth_paths)
for (time, s1ds), (_, s2ds) in zip(radar1_plotting, radar2_plotting):
s1_detections.update(s1ds)
s2_detections.update(s2ds)
radar1_path.append(sensor1_platform.position)
truths.update(gt_simulator.groundtruth_paths)

# Plot the truths and detections
plotter = Plotter(dimension=Dimension.THREE)
Expand Down Expand Up @@ -318,10 +331,11 @@
jpda_tracker = MultiTargetMixtureTracker(
initiator=initiator,
deleter=deleter,
detector=None,
detector=s1_detector,
data_associator=data_associator,
updater=jpda_updater
)
jpda_tracker, s1_tracker = tee(jpda_tracker, 2)

# %%
# GM-LCC Tracker
Expand Down Expand Up @@ -382,13 +396,14 @@

# Tracker
gmlcc_tracker = PointProcessMultiTargetTracker(
detector=None,
detector=s2_detector,
hypothesiser=deepcopy(hypothesiser),
updater=deepcopy(updater),
reducer=deepcopy(reducer),
birth_component=deepcopy(birth_component),
extraction_threshold=0.90,
)
gmlcc_tracker, s2_tracker = tee(gmlcc_tracker, 2)

# %%
# 4: Make GM-PHD Tracker For Measurement Fusion
Expand All @@ -406,7 +421,7 @@
)

meas_fusion_tracker = PointProcessMultiTargetTracker(
detector=None,
detector=s1s2_detector,
hypothesiser=deepcopy(hypothesiser),
updater=deepcopy(updater),
reducer=deepcopy(reducer),
Expand Down Expand Up @@ -478,8 +493,10 @@
)

# Make tracker
from stonesoup.feeder.track import Tracks2GaussianDetectionFeeder

track_fusion_tracker = PointProcessMultiTargetTracker(
detector=None,
detector=Tracks2GaussianDetectionFeeder(MultiDataFeeder([s1_tracker, s2_tracker])),
hypothesiser=hypothesiser,
updater=updater,
reducer=deepcopy(ch_reducer),
Expand Down Expand Up @@ -545,97 +562,52 @@
#
# The track fusion tracker will also use the :class:`~.Tracks2GaussianDetectionFeeder` class to
# feed the tracks as measurements. At each time step, the resultant live tracks from the JPDA and
# GM-LCC trackers will be put into a :class:`~.Tracks2GaussianDetectionFeeder` (using the
# :class:`~.DummyDetector` we write below). The feeder will take the most recent state from each
# GM-LCC trackers will be put into a :class:`~.Tracks2GaussianDetectionFeeder`. The feeder will
# take the most recent state from each
# track and turn it into a :class:`~.GaussianDetection` object. The set of detection objects will
# be returned and passed into the tracker.

# %%
from stonesoup.feeder.track import Tracks2GaussianDetectionFeeder
from stonesoup.buffered_generator import BufferedGenerator
from stonesoup.reader.base import DetectionReader


class DummyDetector(DetectionReader):
def __init__(self, *args, **kwargs):
self.current = kwargs['current']

@BufferedGenerator.generator_method
def detections_gen(self):
yield self.current


# %%
# 8: Run Simulation
# -----------------

# %%

sensor1_detections, sensor2_detections = [], []
jpda_tracks, gmlcc_tracks = set(), set()
meas_fusion_tracks, track_fusion_tracks = set(), set()

metric_manager.add_data({'detections': set()})

sim_generator = radar_simulator.detections_gen()
meas_fusion_tracker_iter = iter(meas_fusion_tracker)
track_fusion_tracker_iter = iter(track_fusion_tracker)

for t in range(num_steps):

# Run JPDA tracker from sensor 1
s1d = next(sim_generator)
sensor1_detections.extend(s1d[1]) # hold in list for plotting
# Pass the detections into a DummyDetector and set it up as an iterable
jpda_tracker.detector = DummyDetector(current=s1d)
jpda_tracker.__iter__()
# Run the tracker and store the resulting tracks
_, sensor1_tracks = next(jpda_tracker)
jpda_tracks.update(sensor1_tracks)

# Run GM-LCC tracker from sensor 2
s2d = next(sim_generator)
sensor2_detections.extend(s2d[1]) # hold in list for plotting
# Pass the detections into a DummyDetector and set it up as an iterable
gmlcc_tracker.detector = DummyDetector(current=s2d)
gmlcc_tracker.__iter__()
# Run the tracker and store results
time, sensor2_tracks = next(gmlcc_tracker)
_, sensor2_tracks = next(gmlcc_tracker)
gmlcc_tracks.update(sensor2_tracks)

# Run the GM-PHD for measurement fusion. This one gets called twice, once for each set of
# detections. This ensures there is only one detection per target.
for detections in [s1d, s2d]:
meas_fusion_tracker.detector = DummyDetector(current=detections)
meas_fusion_tracker.__iter__()
_, tracks = next(meas_fusion_tracker)
meas_fusion_tracks.update(tracks)

# Run the GM-PHD for track fusion. Similar to the measurement fusion, this tracker gets run
# twice, once for each set of tracks.
for tracks_as_meas in [sensor1_tracks, sensor2_tracks]:
dummy_detector = DummyDetector(current=[time, tracks_as_meas])
track_fusion_tracker.detector = Tracks2GaussianDetectionFeeder(dummy_detector)
track_fusion_tracker.__iter__()
_, tracks = next(track_fusion_tracker)
for _ in (0, 1):
_, tracks = next(meas_fusion_tracker_iter)
meas_fusion_tracks.update(tracks)
_, tracks = next(track_fusion_tracker_iter)
track_fusion_tracks.update(tracks)

# ----------------------------------------------------------------------

# Add ground truth data and measurements to metric manager
truths = radar_simulator.groundtruth.current
detections = s1d[1] | s2d[1]
metric_manager.add_data({'truths': truths[1], 'detections': detections}, overwrite=False)

# Ensure that all tracks have been extracted from the trackers
jpda_tracks.update(jpda_tracker.tracks)
gmlcc_tracks.update(gmlcc_tracker.tracks)
meas_fusion_tracks.update(meas_fusion_tracker.tracks)
track_fusion_tracks.update(track_fusion_tracker.tracks)
detections = s1_detections | s2_detections
metric_manager.add_data({'truths': truths, 'detections': detections}, overwrite=False)

# Remove tracks that have just one state in them as they were probably from clutter
jpda_tracks = set([track for track in jpda_tracks if len(track) > 1])
gmlcc_tracks = set([track for track in gmlcc_tracks if len(track) > 1])
meas_fusion_tracks = set([track for track in meas_fusion_tracks if len(track) > 1])
track_fusion_tracks = set([track for track in track_fusion_tracks if len(track) > 1])
jpda_tracks = {track for track in jpda_tracks if len(track) > 1}
gmlcc_tracks = {track for track in gmlcc_tracks if len(track) > 1}
meas_fusion_tracks = {track for track in meas_fusion_tracks if len(track) > 1}
track_fusion_tracks = {track for track in track_fusion_tracks if len(track) > 1}

# Add track data to the metric manager
metric_manager.add_data({'jpda_tracks': jpda_tracks,
Expand All @@ -650,20 +622,15 @@ def detections_gen(self):
# Next, we will plot all of the resulting tracks and measurements. This will be done in two plots.
# The first plot will show all of the data, and the second plot will show a closer view of one
# resultant track.
#
# These plots are done in 2D to make them more readable. We invite the reader to explore the plot
# interactively using the following line in an active Jupyter session.
#
# %matplotlib widget

# %%
plotter1, plotter2 = Plotter(), Plotter()
for plotter in [plotter1, plotter2]:
plotter.plot_ground_truths(set(radar_simulator.groundtruth.groundtruth_paths), [0, 2],
plotter.plot_ground_truths(truths, [0, 2],
color='black')
plotter.plot_measurements(sensor1_detections, [0, 2], color='orange', marker='*',
plotter.plot_measurements(s1_detections, [0, 2], color='orange', marker='*',
measurements_label='Measurements - Airborne Radar')
plotter.plot_measurements(sensor2_detections, [0, 2], color='blue', marker='*',
plotter.plot_measurements(s2_detections, [0, 2], color='blue', marker='*',
measurements_label='Measurements - Ground Radar')
plotter.plot_tracks(jpda_tracks, [0, 2], color='red',
track_label='Tracks - Airborne Radar (JPDAF)')
Expand All @@ -686,7 +653,7 @@ def detections_gen(self):

plotter1.fig.show()

track = track_fusion_tracks.pop()
track = sorted(track_fusion_tracks, key=len)[-1] # Longest track
x_min = min([state.state_vector[0] for state in track])
x_max = max([state.state_vector[0] for state in track])
y_min = min([state.state_vector[2] for state in track])
Expand Down

0 comments on commit e6194b2

Please sign in to comment.