From e0c00a142303fec3b53ab8610d5714cc63f86ab3 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Tue, 21 Jun 2022 10:54:16 +0100 Subject: [PATCH 1/2] Fix tree data associators to work with particle states --- stonesoup/dataassociator/tests/test_tree.py | 58 ++++++++++++++++++--- stonesoup/dataassociator/tree.py | 22 ++++++-- 2 files changed, 69 insertions(+), 11 deletions(-) diff --git a/stonesoup/dataassociator/tests/test_tree.py b/stonesoup/dataassociator/tests/test_tree.py index 85fc9a990..4cf81dda4 100644 --- a/stonesoup/dataassociator/tests/test_tree.py +++ b/stonesoup/dataassociator/tests/test_tree.py @@ -2,6 +2,8 @@ import pytest import numpy as np +from scipy.stats import multivariate_normal + try: import rtree except (ImportError, AttributeError, OSError): @@ -13,10 +15,12 @@ from ..probability import PDA, JPDA from ..tree import DetectionKDTreeMixIn, TPRTreeMixIn from ...models.measurement.nonlinear import CartesianToBearingRange -from ...types.array import CovarianceMatrix +from ...predictor.particle import ParticlePredictor +from ...types.array import CovarianceMatrix, StateVectors from ...types.detection import Detection, MissedDetection -from ...types.state import GaussianState +from ...types.state import GaussianState, ParticleState from ...types.track import Track +from ...updater.particle import ParticleUpdater class DetectionKDTreeNN(NearestNeighbour, DetectionKDTreeMixIn): @@ -150,8 +154,6 @@ def test_nearest_neighbour(nn_associator): def test_tpr_tree_management(nn_associator, updater): '''Test method for TPR insert, delete and update''' - if not isinstance(nn_associator, TPRTreeNN): - return timestamp = datetime.datetime.now() t1 = Track([GaussianState(np.array([[0, 0, 0, 0]]), np.diag([1, 0.1, 1, 0.1]), timestamp)]) @@ -201,8 +203,6 @@ def test_tpr_tree_management(nn_associator, updater): def test_tpr_tree_measurement_models(nn_associator, measurement_model): '''Test method for TPR insert, delete and update using non linear measurement model''' - if not isinstance(nn_associator, TPRTreeNN): - return timestamp = datetime.datetime.now() measurement_model_nl = CartesianToBearingRange( ndim_state=4, mapping=[0, 2], @@ -359,3 +359,49 @@ def test_no_tracks_probability(pda_associator): # Since no Tracks went in, there should be no associations assert not associations + + +def test_particle_tree(nn_associator): + timestamp = datetime.datetime.now() + p1 = multivariate_normal.rvs(np.array([0, 0, 0, 0]), + np.diag([1, 0.1, 1, 0.1]), + size=200) + p2 = multivariate_normal.rvs(np.array([3, 0, 3, 0]), + np.diag([1, 0.1, 1, 0.1]), + size=200) + t1 = Track([ParticleState(StateVectors(p1.T), timestamp, np.full(200, 1 / 200))]) + t2 = Track([ParticleState(StateVectors(p2.T), timestamp, np.full(200, 1 / 200))]) + d1 = Detection(np.array([[2, 2]]), timestamp) + d2 = Detection(np.array([[5, 5]]), timestamp) + + tracks = {t1, t2} + detections = {d1, d2} + + # Switch predictor/updater to Particle ones. + nn_associator.hypothesiser.predictor = ParticlePredictor( + nn_associator.hypothesiser.predictor.transition_model) + nn_associator.hypothesiser.updater = ParticleUpdater( + nn_associator.hypothesiser.updater.measurement_model) + if isinstance(nn_associator, DetectionKDTreeMixIn): + nn_associator.predictor = nn_associator.hypothesiser.predictor + nn_associator.updater = nn_associator.hypothesiser.updater + associations = nn_associator.associate(tracks, detections, timestamp) + + # There should be 2 associations + assert len(associations) == 2 + + # Each track should associate with a unique detection + associated_measurements = [hypothesis.measurement + for hypothesis in associations.values() + if hypothesis.measurement] + assert len(associated_measurements) == len(set(associated_measurements)) + + tracks = {} + associations = nn_associator.associate(tracks, detections, timestamp) + assert len(associations) == 0 + + tracks = {t1, t2} + detections = {} + associations = nn_associator.associate(tracks, detections, timestamp) + + assert len([hypothesis for hypothesis in associations.values() if not hypothesis]) == 2 diff --git a/stonesoup/dataassociator/tree.py b/stonesoup/dataassociator/tree.py index b1ee97f4b..a3bc392f0 100644 --- a/stonesoup/dataassociator/tree.py +++ b/stonesoup/dataassociator/tree.py @@ -68,13 +68,18 @@ def generate_hypotheses(self, tracks, detections, timestamp, **kwargs): prediction = self.predictor.predict(track.state, timestamp) meas_pred = self.updater.predict_measurement(prediction) + try: + meas_pred_state_vector = meas_pred.mean + except AttributeError: + meas_pred_state_vector = meas_pred.state_vector + if self.number_of_neighbours is None: indexes = tree.query_ball_point( - meas_pred.state_vector.ravel(), + meas_pred_state_vector.ravel(), r=self.max_distance) else: _, indexes = tree.query( - meas_pred.state_vector.ravel(), + meas_pred_state_vector.ravel(), k=self.number_of_neighbours, distance_upper_bound=self.max_distance) @@ -91,7 +96,14 @@ def generate_hypotheses(self, tracks, detections, timestamp, **kwargs): class TPRTreeMixIn(DataAssociator): """Detection TPR tree based mixin - Construct a TPR-tree. + Construct a TPR-tree to filter detections for generating hypotheses. This assumes + tracks move in constant velocity like model, using the mean and covariance to define + region to look for detections. + + Notes + ----- + This requires that track state has a mean (position and velocity) and covariance, which + is then approximated to a TPR node (position, velocity and time bounding box). """ measurement_model: MeasurementModel = Property( doc="Measurement model used within the TPR tree") @@ -125,10 +137,10 @@ def __init__(self, *args, **kwargs): self._coords = dict() def _track_tree_coordinates(self, track): - state_vector = track.state_vector[self.pos_mapping, :] + state_vector = track.mean[self.pos_mapping, :] state_delta = 3 * np.sqrt( np.diag(track.covar)[self.pos_mapping].reshape(-1, 1)) - vel_vector = track.state_vector[self.vel_mapping, :] + vel_vector = track.mean[self.vel_mapping, :] vel_delta = 3 * np.sqrt( np.diag(track.covar)[self.vel_mapping].reshape(-1, 1)) From bf4333d62c33d365352e5882115ffefe324c30b2 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Fri, 22 Jul 2022 09:30:57 +0100 Subject: [PATCH 2/2] Ignore import errors of rtree in code coverage --- stonesoup/dataassociator/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stonesoup/dataassociator/tree.py b/stonesoup/dataassociator/tree.py index a3bc392f0..3eb0dde71 100644 --- a/stonesoup/dataassociator/tree.py +++ b/stonesoup/dataassociator/tree.py @@ -9,7 +9,7 @@ from scipy.spatial import KDTree try: import rtree -except (ImportError, AttributeError, OSError) as err: +except (ImportError, AttributeError, OSError) as err: # pragma: no cover # AttributeError or OSError raised when libspatialindex missing or unable to load. import warnings warnings.warn(f"Failed to import 'rtree': {err!r}")