Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tree data associators to work with particle states #662

Merged
merged 2 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 52 additions & 6 deletions stonesoup/dataassociator/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest
import numpy as np
from scipy.stats import multivariate_normal

try:
import rtree
except (ImportError, AttributeError, OSError):
Expand All @@ -13,10 +15,12 @@
from ..probability import PDA, JPDA
from ..tree import DetectionKDTreeMixIn, TPRTreeMixIn
from ...models.measurement.nonlinear import CartesianToBearingRange
from ...types.array import CovarianceMatrix
from ...predictor.particle import ParticlePredictor
from ...types.array import CovarianceMatrix, StateVectors
from ...types.detection import Detection, MissedDetection
from ...types.state import GaussianState
from ...types.state import GaussianState, ParticleState
from ...types.track import Track
from ...updater.particle import ParticleUpdater


class DetectionKDTreeNN(NearestNeighbour, DetectionKDTreeMixIn):
Expand Down Expand Up @@ -150,8 +154,6 @@ def test_nearest_neighbour(nn_associator):

def test_tpr_tree_management(nn_associator, updater):
'''Test method for TPR insert, delete and update'''
if not isinstance(nn_associator, TPRTreeNN):
return
timestamp = datetime.datetime.now()

t1 = Track([GaussianState(np.array([[0, 0, 0, 0]]), np.diag([1, 0.1, 1, 0.1]), timestamp)])
Expand Down Expand Up @@ -201,8 +203,6 @@ def test_tpr_tree_management(nn_associator, updater):

def test_tpr_tree_measurement_models(nn_associator, measurement_model):
'''Test method for TPR insert, delete and update using non linear measurement model'''
if not isinstance(nn_associator, TPRTreeNN):
return
timestamp = datetime.datetime.now()
measurement_model_nl = CartesianToBearingRange(
ndim_state=4, mapping=[0, 2],
Expand Down Expand Up @@ -359,3 +359,49 @@ def test_no_tracks_probability(pda_associator):

# Since no Tracks went in, there should be no associations
assert not associations


def test_particle_tree(nn_associator):
timestamp = datetime.datetime.now()
p1 = multivariate_normal.rvs(np.array([0, 0, 0, 0]),
np.diag([1, 0.1, 1, 0.1]),
size=200)
p2 = multivariate_normal.rvs(np.array([3, 0, 3, 0]),
np.diag([1, 0.1, 1, 0.1]),
size=200)
t1 = Track([ParticleState(StateVectors(p1.T), timestamp, np.full(200, 1 / 200))])
t2 = Track([ParticleState(StateVectors(p2.T), timestamp, np.full(200, 1 / 200))])
d1 = Detection(np.array([[2, 2]]), timestamp)
d2 = Detection(np.array([[5, 5]]), timestamp)

tracks = {t1, t2}
detections = {d1, d2}

# Switch predictor/updater to Particle ones.
nn_associator.hypothesiser.predictor = ParticlePredictor(
nn_associator.hypothesiser.predictor.transition_model)
nn_associator.hypothesiser.updater = ParticleUpdater(
nn_associator.hypothesiser.updater.measurement_model)
if isinstance(nn_associator, DetectionKDTreeMixIn):
nn_associator.predictor = nn_associator.hypothesiser.predictor
nn_associator.updater = nn_associator.hypothesiser.updater
associations = nn_associator.associate(tracks, detections, timestamp)

# There should be 2 associations
assert len(associations) == 2

# Each track should associate with a unique detection
associated_measurements = [hypothesis.measurement
for hypothesis in associations.values()
if hypothesis.measurement]
assert len(associated_measurements) == len(set(associated_measurements))

tracks = {}
associations = nn_associator.associate(tracks, detections, timestamp)
assert len(associations) == 0

tracks = {t1, t2}
detections = {}
associations = nn_associator.associate(tracks, detections, timestamp)

assert len([hypothesis for hypothesis in associations.values() if not hypothesis]) == 2
24 changes: 18 additions & 6 deletions stonesoup/dataassociator/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy.spatial import KDTree
try:
import rtree
except (ImportError, AttributeError, OSError) as err:
except (ImportError, AttributeError, OSError) as err: # pragma: no cover
# AttributeError or OSError raised when libspatialindex missing or unable to load.
import warnings
warnings.warn(f"Failed to import 'rtree': {err!r}")
Expand Down Expand Up @@ -68,13 +68,18 @@ def generate_hypotheses(self, tracks, detections, timestamp, **kwargs):
prediction = self.predictor.predict(track.state, timestamp)
meas_pred = self.updater.predict_measurement(prediction)

try:
meas_pred_state_vector = meas_pred.mean
except AttributeError:
meas_pred_state_vector = meas_pred.state_vector

if self.number_of_neighbours is None:
indexes = tree.query_ball_point(
meas_pred.state_vector.ravel(),
meas_pred_state_vector.ravel(),
r=self.max_distance)
else:
_, indexes = tree.query(
meas_pred.state_vector.ravel(),
meas_pred_state_vector.ravel(),
k=self.number_of_neighbours,
distance_upper_bound=self.max_distance)

Expand All @@ -91,7 +96,14 @@ def generate_hypotheses(self, tracks, detections, timestamp, **kwargs):
class TPRTreeMixIn(DataAssociator):
"""Detection TPR tree based mixin

Construct a TPR-tree.
Construct a TPR-tree to filter detections for generating hypotheses. This assumes
tracks move in constant velocity like model, using the mean and covariance to define
region to look for detections.

Notes
-----
This requires that track state has a mean (position and velocity) and covariance, which
is then approximated to a TPR node (position, velocity and time bounding box).
"""
measurement_model: MeasurementModel = Property(
doc="Measurement model used within the TPR tree")
Expand Down Expand Up @@ -125,10 +137,10 @@ def __init__(self, *args, **kwargs):
self._coords = dict()

def _track_tree_coordinates(self, track):
state_vector = track.state_vector[self.pos_mapping, :]
state_vector = track.mean[self.pos_mapping, :]
state_delta = 3 * np.sqrt(
np.diag(track.covar)[self.pos_mapping].reshape(-1, 1))
vel_vector = track.state_vector[self.vel_mapping, :]
vel_vector = track.mean[self.vel_mapping, :]
vel_delta = 3 * np.sqrt(
np.diag(track.covar)[self.vel_mapping].reshape(-1, 1))

Expand Down