-
Notifications
You must be signed in to change notification settings - Fork 142
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #681 from ekhunter123/add_chernoff_updater
Add ChernoffUpdater class and track feeder
- Loading branch information
Showing
6 changed files
with
320 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# -*- coding: utf-8 -*- | ||
import pytest | ||
import numpy as np | ||
|
||
from ...types.state import GaussianState | ||
from ...types.track import Track | ||
from ..track import Tracks2GaussianDetectionFeeder | ||
|
||
t1 = Track(GaussianState([1, 1, 1, 1], np.diag([2, 2, 2, 2]), timestamp=2)) | ||
t2 = Track([GaussianState([1, 1, 1, 1], np.diag([2, 2, 2, 2]), timestamp=1), | ||
GaussianState([2, 1, 2, 1], np.diag([2, 2, 2, 2]), timestamp=2)]) | ||
t3 = Track([GaussianState([1, 1], np.diag([2, 2]), timestamp=0), | ||
GaussianState([2, 1], np.diag([2, 2]), timestamp=1), | ||
GaussianState([3, 1], np.diag([2, 2]), timestamp=2)]) | ||
t4 = Track(GaussianState([1, 0, 1, 0, 1, 0], np.diag([2, 2, 2, 2, 2, 2]), timestamp=2)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"tracks", | ||
[ | ||
([t1]), | ||
([t1, t2]), | ||
([t2, t3]), | ||
([t1, t2, t3, t4]) | ||
] | ||
) | ||
def test_Track2GaussianDetectionFeeder(tracks): | ||
|
||
# Make feeder and get detections | ||
reader = [(tracks[0][-1].timestamp, tracks)] | ||
feeder = Tracks2GaussianDetectionFeeder(reader=reader) | ||
time, detections = next(feeder.data_gen()) | ||
|
||
# Check that there are the right number of detections | ||
assert len(detections) == len(tracks) | ||
|
||
# Check that the correct state was taken from each track | ||
assert np.all([detections[i].timestamp == tracks[i][-1].timestamp for i in range(len(tracks))]) | ||
assert np.all([detections[i].timestamp == time for i in range(len(tracks))]) | ||
|
||
# Check that the dimension of each detection is correct | ||
assert np.all([len(detections[i].state_vector) == len(tracks[i][-1].state_vector) | ||
for i in range(len(tracks))]) | ||
assert np.all([len(detection.state_vector) == detection.measurement_model.ndim | ||
for detection in detections]) | ||
|
||
# Check that the detection has the correct mean and covariance | ||
for i in range(len(tracks)): | ||
assert np.all(detections[i].state_vector == tracks[i][-1].state_vector) | ||
assert np.all(detections[i].covar == tracks[i][-1].covar) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# -*- coding: utf-8 -*- | ||
import numpy as np | ||
|
||
from stonesoup.types.detection import GaussianDetection | ||
from stonesoup.feeder.base import DetectionFeeder | ||
from stonesoup.models.measurement.linear import LinearGaussian | ||
from ..buffered_generator import BufferedGenerator | ||
|
||
|
||
class Tracks2GaussianDetectionFeeder(DetectionFeeder): | ||
''' | ||
Feeder consumes Track objects and outputs GaussianDetection objects. | ||
At each time step, the :attr:`Reader` feeds in a set of live tracks. The feeder takes the most | ||
recent state from each of those tracks, and turn them into a set of | ||
:class:`~.GaussianDetection` objects. Each detection is given a :class:`~.LinearGaussian` | ||
measurement model whose covariance is equal to the state covariance. The feeder assumes that | ||
the tracks are all live, that is each track has a state at the most recent time step. | ||
''' | ||
@BufferedGenerator.generator_method | ||
def data_gen(self): | ||
for time, tracks in self.reader: | ||
detections = [] | ||
for track in tracks: | ||
dim = len(track.state.state_vector) | ||
detections.append( | ||
GaussianDetection.from_state( | ||
track.state, | ||
measurement_model=LinearGaussian( | ||
dim, list(range(dim)), np.asarray(track.covar)), | ||
target_type=GaussianDetection) | ||
) | ||
|
||
yield time, detections |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# -*- coding: utf-8 -*- | ||
import numpy as np | ||
|
||
from ..base import Property | ||
from .base import Updater | ||
from ..types.prediction import MeasurementPrediction | ||
from ..types.update import Update | ||
|
||
|
||
class ChernoffUpdater(Updater): | ||
r"""A class which performs state updates using the Chernoff fusion rule. In this context, | ||
measurements come in the form of states with a mean and covariance instead of just the | ||
traditional mean. The measurements are expected to come as :class:`~.GaussianDetection` | ||
objects. | ||
The Chernoff fusion rule is written as | ||
.. math:: | ||
p_{\omega}(x_{k}) = \frac{p_{1}(x_{k})^{\omega}p_{2}(x_{k})^{1-\omega}} | ||
{\int p_{1}(x)^{\omega}p_{2}(x)^{1-\omega} \mathrm{d} x} | ||
where :math:`\omega` is a weighting parameter in the range :math:`(0,1]`, which can be found | ||
using an optimization algorithm. | ||
In situations where :math:`p_1(x)` and :math:`p_2(x)` are multivariate Gaussian distributions, | ||
the above formula is equal to the Covariance Intersection Algorithm from Julier and Uhlmann. | ||
Let :math:`(a,A)` and :math:`(b,B)` be the means and covariances of the measurement and | ||
prediction respectively. The Covariance Intersection Algorithm was reformulated for use in | ||
Bayesian state estimation by Clark et al, yielding the update formulas for the covariance, | ||
mean, and innovation: | ||
.. math:: | ||
D &= \left ( \omega A^{-1} + (1-\omega)B^{-1} \right )\\ | ||
d &= D \left ( \omega A^{-1}a + (1-\omega)B^{-1}b \right )\\ | ||
V &= \frac{A}{1-\omega} + \frac{B}{\omega} | ||
In filters where gating is required, the gating region can be written using the innovation | ||
covariance matrix as: | ||
.. math:: | ||
\mathcal{V}(\gamma) = \left\{ (a,A) : (a-b)^T V^{-1} (a-b) \leq \gamma \right\} | ||
Note: If you have tracks that you would like to use as measurements for this updater, the | ||
:class:`~.Tracks2GaussianDetectionFeeder` class can be used to convert the tracks to the | ||
appropriate format. | ||
References | ||
---------- | ||
[1] Hurley, M. B., “An information theoretic justification for covariance intersection and its | ||
generalization,” in [Proceedings of the Fifth International Conference on Information Fusion. | ||
FUSION 2002.(IEEE Cat. No. 02EX5997) ], 1, 505–511, IEEE (2002). | ||
https://ieeexplore.ieee.org/document/1021196. | ||
[2] Julier, S., Uhlmann, J., and Durrant-Whyte, H. F., “A new method for the nonlinear | ||
transformation of means and covariances in filters and estimators,” IEEE Transactions on | ||
automatic control 45(3), 477–482 (2000). | ||
https://ieeexplore.ieee.org/abstract/document/847726/similar#similar. | ||
[3] Clark, D. E. and Campbell, M. A., “Integrating covariance intersection into Bayesian | ||
multi-target tracking filters,” preprint on TechRxiv. submitted to IEEE Transactions on | ||
Aerospace and Electronic Systems . | ||
""" | ||
|
||
omega: float = Property( | ||
default=0.5, | ||
doc="A weighting parameter in the range :math:`(0,1]`") | ||
|
||
def predict_measurement(self, predicted_state, measurement_model=None, **kwargs): | ||
''' | ||
This function predicts the measurement in situations where the predicted state consists | ||
of a covariance and state vector. | ||
''' | ||
|
||
measurement_model = self._check_measurement_model(measurement_model) | ||
|
||
# The innovation covariance uses the noise covariance from the measurement model | ||
state_covar_m = measurement_model.noise_covar | ||
innov_covar = 1/(1-self.omega)*state_covar_m + 1/self.omega*predicted_state.covar | ||
|
||
# The predicted measurement and measurement cross covariance can be taken from | ||
# the predicted state | ||
predicted_meas = predicted_state.state_vector | ||
meas_cross_cov = predicted_state.covar | ||
|
||
# Combine everything into a GaussianMeasurementPrediction object | ||
return MeasurementPrediction.from_state(predicted_state, predicted_meas, innov_covar, | ||
predicted_state.timestamp, | ||
cross_covar=meas_cross_cov) | ||
|
||
def update(self, hypothesis, force_symmetric_covariance=False, **kwargs): | ||
''' | ||
Given a hypothesis, calculate the posterior mean and covariance | ||
''' | ||
|
||
# Get the predicted state out of the hypothesis. These are 'B' and 'b', the | ||
# covariance and mean of the predicted Gaussian | ||
predicted_covar = hypothesis.prediction.covar | ||
predicted_mean = hypothesis.prediction.state_vector | ||
|
||
# Extract the vector and covariance from the measurement. These are 'A' and 'a', the | ||
# covariance and mean of the Gaussian measurement. | ||
measurement_covar = hypothesis.measurement.covar | ||
measurement_mean = hypothesis.measurement.state_vector | ||
|
||
# Predict the measurement if it is not already done | ||
if hypothesis.measurement_prediction is None: | ||
hypothesis.measurement_prediction = self.predict_measurement( | ||
hypothesis.prediction, | ||
measurement_model=hypothesis.measurement.measurement_model, | ||
**kwargs | ||
) | ||
|
||
# Calculate the updated mean and covariance from covariance intersection | ||
posterior_covariance = np.linalg.inv(self.omega*np.linalg.inv(measurement_covar) + | ||
(1-self.omega)*np.linalg.inv(predicted_covar)) | ||
posterior_mean = posterior_covariance @ (self.omega*np.linalg.inv(measurement_covar) | ||
@ measurement_mean + | ||
(1-self.omega)*np.linalg.inv(predicted_covar) | ||
@ predicted_mean) | ||
|
||
# Optionally force the posterior covariance to be a symmetric matrix | ||
if force_symmetric_covariance: | ||
posterior_covariance = \ | ||
(posterior_covariance + posterior_covariance.T)/2 | ||
|
||
# Return the updated state | ||
return Update.from_state(hypothesis.prediction, posterior_mean, posterior_covariance, | ||
hypothesis, hypothesis.measurement.timestamp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Test for updater.chernoff module""" | ||
import pytest | ||
import numpy as np | ||
|
||
from stonesoup.models.measurement.linear import LinearGaussian | ||
from stonesoup.types.detection import GaussianDetection | ||
from stonesoup.types.hypothesis import SingleHypothesis | ||
from stonesoup.types.prediction import ( | ||
GaussianStatePrediction, GaussianMeasurementPrediction) | ||
from stonesoup.types.state import GaussianState | ||
from stonesoup.updater.chernoff import ChernoffUpdater | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"UpdaterClass, measurement_model, prediction, measurement, omega", | ||
[ | ||
( # Chernoff Updater | ||
ChernoffUpdater, | ||
LinearGaussian(ndim_state=2, mapping=[0, 1], | ||
noise_covar=np.array([[0.04, 0.04]])), | ||
GaussianStatePrediction(np.array([[-6.45], [0.7]]), | ||
np.array([[4.1123, 0.0013], | ||
[0.0013, 0.0365]])), | ||
GaussianDetection(state_vector=np.array([[-6.23, 0.83]]), | ||
covar=np.diag([0.75, 1.2])), | ||
0.5 | ||
) | ||
], | ||
ids=["standard"] | ||
) | ||
def test_chernoff(UpdaterClass, measurement_model, prediction, measurement, omega): | ||
|
||
# Calculate evaluation variables | ||
innov_cov = 1/(1-omega)*measurement_model.noise_covar + 1/omega*prediction.covar | ||
eval_measurement_prediction = GaussianMeasurementPrediction( | ||
measurement_model.matrix() @ prediction.mean, | ||
innov_cov, | ||
cross_covar=prediction.covar @ measurement_model.matrix().T) | ||
|
||
posterior_cov = np.linalg.inv(omega*np.linalg.inv(measurement.covar) + | ||
(1-omega)*np.linalg.inv(prediction.covar)) | ||
posterior_mean = posterior_cov@(omega*np.linalg.inv(measurement.covar) @ | ||
measurement.state_vector + (1-omega) * | ||
np.linalg.inv(prediction.covar)@prediction.state_vector) | ||
eval_posterior = GaussianState( | ||
posterior_mean, | ||
posterior_cov) | ||
|
||
# Initialise a Chernoff updater | ||
updater = UpdaterClass(measurement_model=measurement_model, omega=omega) | ||
|
||
# Get and assert measurement prediction | ||
measurement_prediction = updater.predict_measurement(prediction) | ||
assert(np.allclose(measurement_prediction.mean, | ||
eval_measurement_prediction.mean, | ||
0, atol=1.e-14)) | ||
assert(np.allclose(measurement_prediction.covar, | ||
eval_measurement_prediction.covar, | ||
0, atol=1.e-14)) | ||
assert(np.allclose(measurement_prediction.cross_covar, | ||
eval_measurement_prediction.cross_covar, | ||
0, atol=1.e-14)) | ||
|
||
# Perform and assert state update (without measurement prediction) | ||
posterior = updater.update(SingleHypothesis( | ||
prediction=prediction, | ||
measurement=measurement)) | ||
assert(np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14)) | ||
assert(np.allclose(posterior.covar, eval_posterior.covar, 0, atol=1.e-14)) | ||
assert(np.array_equal(posterior.hypothesis.prediction, prediction)) | ||
assert (np.allclose( | ||
posterior.hypothesis.measurement_prediction.state_vector, | ||
measurement_prediction.state_vector, 0, atol=1.e-14)) | ||
assert (np.allclose(posterior.hypothesis.measurement_prediction.covar, | ||
measurement_prediction.covar, 0, atol=1.e-14)) | ||
assert(np.array_equal(posterior.hypothesis.measurement, measurement)) | ||
assert(posterior.timestamp == prediction.timestamp) | ||
|
||
# Perform and assert state update | ||
posterior = updater.update(SingleHypothesis( | ||
prediction=prediction, | ||
measurement=measurement, | ||
measurement_prediction=measurement_prediction)) | ||
assert(np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14)) | ||
assert(np.allclose(posterior.covar, eval_posterior.covar, 0, atol=1.e-14)) | ||
assert(np.array_equal(posterior.hypothesis.prediction, prediction)) | ||
assert (np.allclose( | ||
posterior.hypothesis.measurement_prediction.state_vector, | ||
measurement_prediction.state_vector, 0, atol=1.e-14)) | ||
assert (np.allclose(posterior.hypothesis.measurement_prediction.covar, | ||
measurement_prediction.covar, 0, atol=1.e-14)) | ||
assert(np.array_equal(posterior.hypothesis.measurement, measurement)) | ||
assert(posterior.timestamp == prediction.timestamp) |