-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #625 from 0sm1um/ensemble_kalman
Ensemble Kalman Filter Algorithm and Data Type
- Loading branch information
Showing
11 changed files
with
596 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from .base import Property | ||
from ..models.transition import TransitionModel | ||
from .kalman import KalmanPredictor | ||
from ..types.array import StateVectors | ||
from ..types.state import State | ||
from ..types.prediction import Prediction | ||
|
||
|
||
class EnsemblePredictor(KalmanPredictor): | ||
r"""Ensemble Kalman Filter Predictor class | ||
The EnKF predicts the state by treating each column of the ensemble matrix | ||
as a state vector. The state is propagated through time by applying the | ||
transition function to each member (vector) of the ensemble. | ||
.. math:: | ||
\hat{X}_k = [f(x_1), f(x_2), ..., f(x_M)] | ||
""" | ||
transition_model: TransitionModel = Property(doc="The transition model to be used.") | ||
|
||
def predict(self, prior, timestamp=None, **kwargs): | ||
"""Ensemble Kalman Filter prediction step | ||
Parameters | ||
---------- | ||
prior : :class:`~.EnsembleState` | ||
A prior state object | ||
control_input : :class:`~.State`, optional | ||
The control input. It will only have an effect if | ||
:attr:`control_model` is not `None` (the default is `None`) | ||
timestamp: :class:`datetime.datetime`, optional | ||
A timestamp signifying when the prediction is performed | ||
(the default is `None`) | ||
Returns | ||
------- | ||
: :class:`~.EnsembleStatePrediction` | ||
The predicted state | ||
""" | ||
|
||
# Compute time_interval | ||
time_interval = self._predict_over_interval(prior, timestamp) | ||
# This block of code propagates each column through the transition model. | ||
pred_ensemble = StateVectors( | ||
[self.transition_model.function(State(state_vector=ensemble_member), | ||
noise=True, time_interval=time_interval) | ||
for ensemble_member in prior.ensemble.T]) | ||
|
||
return Prediction.from_state(prior, pred_ensemble, timestamp=timestamp, | ||
transition_model=self.transition_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# coding: utf-8 | ||
import datetime | ||
|
||
import numpy as np | ||
|
||
from ...models.transition.linear import ConstantVelocity | ||
from ...predictor.ensemble import EnsemblePredictor | ||
from ...types.state import GaussianState, EnsembleState | ||
from ...types.array import StateVector, CovarianceMatrix | ||
|
||
|
||
def test_ensemble(): | ||
# Initialise a transition model | ||
transition_model = ConstantVelocity(noise_diff_coeff=0) | ||
|
||
# Define time related variables | ||
timestamp = datetime.datetime(2021, 2, 27, 17, 27, 48) | ||
timediff = 1 # 1 second | ||
new_timestamp = timestamp + datetime.timedelta(seconds=timediff) | ||
time_interval = new_timestamp - timestamp | ||
|
||
# Define prior state | ||
mean = StateVector([[10], [10]]) | ||
covar = CovarianceMatrix(np.eye(2)) | ||
gaussian_state = GaussianState(mean, covar, timestamp) | ||
num_vectors = 50 | ||
prior_state = EnsembleState.from_gaussian_state(gaussian_state, num_vectors) | ||
prior_ensemble = prior_state.ensemble | ||
|
||
# Create Predictor object, run prediction | ||
predictor = EnsemblePredictor(transition_model) | ||
prediction = predictor.predict(prior_state, timestamp=new_timestamp) | ||
|
||
# Evaluate mean and covariance | ||
eval_ensemble = transition_model.matrix(timestamp=new_timestamp, | ||
time_interval=time_interval) @ prior_ensemble | ||
eval_mean = StateVector((np.average(eval_ensemble, axis=1))) | ||
eval_cov = np.cov(eval_ensemble) | ||
|
||
# Compare evaluated mean and covariance with predictor results | ||
assert np.allclose(prediction.mean, eval_mean) | ||
assert np.allclose(prediction.ensemble, eval_ensemble) | ||
assert np.allclose(prediction.covar, eval_cov) | ||
assert prediction.timestamp == new_timestamp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.