-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1080 from A-acuto/new_pf_proposal
Particle filter proposal implementation
- Loading branch information
Showing
8 changed files
with
330 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
Proposal | ||
=========== | ||
|
||
.. automodule:: stonesoup.proposal | ||
:no-members: | ||
|
||
.. automodule:: stonesoup.proposal.base | ||
:show-inheritance: | ||
|
||
Simple | ||
------ | ||
|
||
.. automodule:: stonesoup.proposal.simple | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from abc import abstractmethod | ||
from stonesoup.base import Base | ||
|
||
|
||
class Proposal(Base): | ||
|
||
@abstractmethod | ||
def rvs(self, *args, **kwargs): | ||
r"""Proposal noise/sample generation function | ||
Generates samples from the proposal. | ||
Parameters | ||
---------- | ||
state: :class:`~.State` | ||
The state to generate samples from. | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
from scipy.stats import multivariate_normal as mvn | ||
|
||
from stonesoup.base import Property | ||
from stonesoup.models.transition import TransitionModel | ||
from stonesoup.proposal.base import Proposal | ||
from stonesoup.types.array import StateVector, StateVectors | ||
from stonesoup.types.detection import Detection | ||
from stonesoup.types.state import State, GaussianState, SqrtGaussianState | ||
from stonesoup.types.prediction import Prediction | ||
from stonesoup.updater.base import Updater | ||
from stonesoup.predictor.base import Predictor | ||
from stonesoup.predictor.kalman import SqrtKalmanPredictor | ||
from stonesoup.types.hypothesis import SingleHypothesis | ||
|
||
|
||
class PriorAsProposal(Proposal): | ||
"""Proposal that uses the dynamics model as the importance density. | ||
This proposal uses the dynamics model to predict the next state, and then | ||
uses the predicted state as the prior for the measurement model. | ||
""" | ||
transition_model: TransitionModel = Property( | ||
doc="The transition model used to make the prediction") | ||
|
||
def rvs(self, prior: State, measurement=None, time_interval=None, | ||
**kwargs) -> Union[StateVector, StateVectors]: | ||
"""Generate samples from the proposal. | ||
Parameters | ||
---------- | ||
state: :class:`~.State` | ||
The state to generate samples from. | ||
Returns | ||
------- | ||
: :class:`~.ParticlePrediction` with samples drawn from the updated proposal | ||
""" | ||
|
||
if measurement is not None: | ||
timestamp = measurement.timestamp | ||
time_interval = measurement.timestamp - prior.timestamp | ||
else: | ||
timestamp = prior.timestamp + time_interval | ||
|
||
new_state_vector = self.transition_model.function(prior, | ||
time_interval=time_interval, | ||
**kwargs) | ||
return Prediction.from_state(prior, | ||
parent=prior, | ||
state_vector=new_state_vector, | ||
timestamp=timestamp, | ||
transition_model=self.transition_model, | ||
prior=prior) | ||
|
||
|
||
class KFasProposal(Proposal): | ||
"""This proposal uses the kalman filter prediction and update steps to | ||
generate new set of particles and weights | ||
""" | ||
predictor: Predictor = Property( | ||
doc="predictor to use the various values") | ||
updater: Updater = Property( | ||
doc="Updater used for update the values") | ||
|
||
def rvs(self, prior: State, measurement: Detection = None, time_interval=None, | ||
**kwargs): | ||
"""Generate samples from the proposal. | ||
Use the kalman filter predictor and updater to create a new distribution | ||
Parameters | ||
---------- | ||
state: :class:`~.State` | ||
The state to generate samples from. | ||
measurement: :class:`~.Detection` | ||
the measurement that is used in the update step of the kalman prediction, | ||
(the default is `None`) | ||
time_interval: :class:`datetime.time_delta` | ||
time interval of the prediction is needed to propagate the states | ||
Returns | ||
------- | ||
: :class:`~.ParticlePrediction` | ||
""" | ||
|
||
if measurement is not None: | ||
timestamp = measurement.timestamp | ||
time_interval = measurement.timestamp - prior.timestamp | ||
else: | ||
timestamp = prior.timestamp + time_interval | ||
|
||
if time_interval.total_seconds() == 0: | ||
return Prediction.from_state(prior, | ||
parent=prior, | ||
state_vector=prior.state_vector, | ||
timestamp=prior.timestamp, | ||
transition_model=self.predictor.transition_model, | ||
prior=prior) | ||
|
||
prior_cls = GaussianState # Default | ||
if isinstance(self.predictor, SqrtKalmanPredictor): | ||
prior_cls = SqrtGaussianState | ||
|
||
# Null covariance for the particles | ||
null_covar = np.zeros_like(prior.covar) | ||
|
||
predictions = [ | ||
self.predictor.predict( | ||
prior_cls(particle_sv, null_covar, prior.timestamp), | ||
timestamp=timestamp) | ||
for particle_sv in prior.state_vector] | ||
|
||
if measurement is not None: | ||
updates = [self.updater.update(SingleHypothesis(prediction, measurement)) | ||
for prediction in predictions] | ||
else: | ||
updates = predictions # keep the prediction | ||
|
||
# Draw the samples | ||
samples = np.array([state.state_vector.reshape(-1) + | ||
mvn.rvs(cov=state.covar).T | ||
for state in updates]) | ||
|
||
# Compute the log of q(x_k|x_{k-1}, y_k) | ||
post_log_weights = np.array([mvn.logpdf(sample - update.state_vector.reshape(-1), | ||
cov=update.covar) | ||
for sample, update in zip(samples, updates)]) | ||
|
||
pred_state = Prediction.from_state(prior, | ||
parent=prior, | ||
state_vector=StateVectors(samples.T), | ||
timestamp=timestamp, | ||
transition_model=self.predictor.transition_model, | ||
prior=prior) | ||
|
||
prior_log_weights = self.predictor.transition_model.logpdf(pred_state, prior, | ||
time_interval=time_interval) | ||
|
||
pred_state.log_weight = (pred_state.log_weight + prior_log_weights - post_log_weights) | ||
|
||
return pred_state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import itertools | ||
|
||
import datetime | ||
import numpy as np | ||
|
||
# Import the proposals | ||
from stonesoup.proposal.simple import PriorAsProposal, KFasProposal | ||
from stonesoup.models.transition.linear import ConstantVelocity | ||
from stonesoup.types.particle import Particle | ||
from stonesoup.types.prediction import ParticleStatePrediction | ||
from stonesoup.predictor.kalman import KalmanPredictor | ||
from stonesoup.updater.kalman import KalmanUpdater | ||
from stonesoup.types.state import ParticleState, GaussianState | ||
from stonesoup.predictor.particle import ParticlePredictor | ||
from stonesoup.types.detection import Detection | ||
from stonesoup.models.measurement.linear import LinearGaussian | ||
from stonesoup.types.hypothesis import SingleHypothesis | ||
|
||
|
||
def test_prior_proposal(): | ||
# test that the proposal as prior and basic PF implementation | ||
# yield same results, since they are driven by the transition model | ||
|
||
# Initialise a transition model | ||
cv = ConstantVelocity(noise_diff_coeff=0) | ||
|
||
# Define time related variables | ||
timestamp = datetime.datetime.now() | ||
timediff = 2 # 2 sec | ||
new_timestamp = timestamp + datetime.timedelta(seconds=timediff) | ||
time_interval = new_timestamp - timestamp | ||
|
||
num_particles = 9 # Number of particles | ||
|
||
# Define prior state | ||
prior_particles = [Particle(np.array([[i], [j]]), 1/num_particles) | ||
for i, j in itertools.product([10, 20, 30], [10, 20, 30])] | ||
prior = ParticleState(None, particle_list=prior_particles, timestamp=timestamp) | ||
|
||
# predictors prior and standard stone soup | ||
predictor_prior = ParticlePredictor(cv, | ||
proposal=PriorAsProposal(cv)) | ||
|
||
# Check that the predictor without prior specified works with the prior as | ||
# proposal | ||
predictor_base = ParticlePredictor(cv) | ||
|
||
# basic transition model evaluations | ||
eval_particles = [Particle(cv.matrix(timestamp=new_timestamp, | ||
time_interval=time_interval) | ||
@ particle.state_vector, | ||
1 / 9) | ||
for particle in prior_particles] | ||
eval_mean = np.mean(np.hstack([i.state_vector for i in eval_particles]), | ||
axis=1).reshape(2, 1) | ||
|
||
# construct the evaluation prediction | ||
eval_prediction = ParticleStatePrediction(None, new_timestamp, particle_list=eval_particles) | ||
|
||
prediction_base = predictor_base.predict(prior, timestamp=new_timestamp) | ||
prediction_prior = predictor_prior.predict(prior, timestamp=new_timestamp) | ||
|
||
assert np.all([eval_prediction.state_vector[:, i] == | ||
prediction_base.state_vector[:, i] for i in range(9)]) | ||
assert np.all([prediction_base.weight[i] == 1 / 9 for i in range(9)]) | ||
|
||
assert np.allclose(prediction_prior.mean, eval_mean) | ||
assert prediction_prior.timestamp == new_timestamp | ||
assert np.all([eval_prediction.state_vector[:, i] == | ||
prediction_prior.state_vector[:, i] for i in range(9)]) | ||
assert np.all([prediction_prior.weight[i] == 1 / 9 for i in range(9)]) | ||
|
||
|
||
def test_kf_proposal(): | ||
|
||
# Initialise a transition model | ||
cv = ConstantVelocity(noise_diff_coeff=1) | ||
|
||
# initialise the measurement model | ||
lg = LinearGaussian(ndim_state=2, | ||
mapping=[0], | ||
noise_covar=np.diag([1])) | ||
|
||
# Define time related variables | ||
timestamp = datetime.datetime.now() | ||
timediff = 2 # 2 sec | ||
new_timestamp = timestamp + datetime.timedelta(seconds=timediff) | ||
time_interval = new_timestamp - timestamp | ||
|
||
num_particles = 9 # Number of particles | ||
|
||
# Define prior state | ||
prior_particles = [Particle(np.array([[i], [j]]), 1/num_particles) | ||
for i, j in itertools.product([1, 2, 3], [1, 2, 3])] | ||
|
||
prior = ParticleState(None, particle_list=prior_particles, timestamp=timestamp) | ||
|
||
# null covariance for the predictions | ||
null_covar = np.zeros_like(prior.covar) | ||
prior_kf = GaussianState(prior.mean, null_covar, prior.timestamp) | ||
|
||
# Kalman filter components | ||
kf_predictor = KalmanPredictor(cv) | ||
kf_updater = KalmanUpdater(lg) | ||
|
||
# perform the kalman filter update | ||
prediction = kf_predictor.predict(prior_kf, timestamp=new_timestamp) | ||
|
||
# state prediction | ||
new_state = GaussianState(state_vector=cv.function(prior_kf, noise=True, | ||
time_interval=time_interval), | ||
covar=np.diag([1, 1]), | ||
timestamp=new_timestamp) | ||
|
||
detection = Detection(lg.function(new_state, | ||
noise=True), | ||
timestamp=new_timestamp, | ||
measurement_model=lg) | ||
|
||
eval_state = kf_updater.update(SingleHypothesis(prediction, detection)) | ||
|
||
proposal = KFasProposal(KalmanPredictor(cv), | ||
KalmanUpdater(lg)) | ||
# particle proposal | ||
particle_proposal = proposal.rvs(prior, measurement=detection, time_interval=time_interval) | ||
|
||
assert particle_proposal.state_vector.shape == prior.state_vector.shape | ||
assert np.allclose(particle_proposal.mean, eval_state.state_vector, rtol=1) | ||
assert particle_proposal.timestamp == new_timestamp |
Oops, something went wrong.