Skip to content

Commit

Permalink
Merge pull request #625 from 0sm1um/ensemble_kalman
Browse files Browse the repository at this point in the history
Ensemble Kalman Filter Algorithm and Data Type
  • Loading branch information
sdhiscocks authored May 17, 2022
2 parents f6fae9f + 85be6fa commit c9c5e8a
Show file tree
Hide file tree
Showing 11 changed files with 596 additions and 8 deletions.
5 changes: 5 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ MIT License
© Crown Copyright 2017-2022 Defence Science and Technology Laboratory UK
© Crown Copyright 2018-2022 Defence Research and Development Canada / Recherche et développement pour la défense Canada
© Copyright 2018-2022 University of Liverpool UK
© Copyright 2020-2022 John Hiles

Portions of this work were created while John Hiles was employed by Wright
State University under Contract No. FA8650-18-2-1645 to the Air Force Research
Laboratory, and the University has released its rights to this Student work.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
6 changes: 6 additions & 0 deletions docs/source/stonesoup.predictor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Particle
.. automodule:: stonesoup.predictor.particle
:show-inheritance:

Ensemble
--------

.. automodule:: stonesoup.predictor.ensemble
:show-inheritance:

Categorical
-----------

Expand Down
6 changes: 6 additions & 0 deletions docs/source/stonesoup.updater.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Particle
.. automodule:: stonesoup.updater.particle
:show-inheritance:

Ensemble
--------

.. automodule:: stonesoup.updater.ensemble
:show-inheritance:

Information
-----------

Expand Down
52 changes: 52 additions & 0 deletions stonesoup/predictor/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from .base import Property
from ..models.transition import TransitionModel
from .kalman import KalmanPredictor
from ..types.array import StateVectors
from ..types.state import State
from ..types.prediction import Prediction


class EnsemblePredictor(KalmanPredictor):
r"""Ensemble Kalman Filter Predictor class
The EnKF predicts the state by treating each column of the ensemble matrix
as a state vector. The state is propagated through time by applying the
transition function to each member (vector) of the ensemble.
.. math::
\hat{X}_k = [f(x_1), f(x_2), ..., f(x_M)]
"""
transition_model: TransitionModel = Property(doc="The transition model to be used.")

def predict(self, prior, timestamp=None, **kwargs):
"""Ensemble Kalman Filter prediction step
Parameters
----------
prior : :class:`~.EnsembleState`
A prior state object
control_input : :class:`~.State`, optional
The control input. It will only have an effect if
:attr:`control_model` is not `None` (the default is `None`)
timestamp: :class:`datetime.datetime`, optional
A timestamp signifying when the prediction is performed
(the default is `None`)
Returns
-------
: :class:`~.EnsembleStatePrediction`
The predicted state
"""

# Compute time_interval
time_interval = self._predict_over_interval(prior, timestamp)
# This block of code propagates each column through the transition model.
pred_ensemble = StateVectors(
[self.transition_model.function(State(state_vector=ensemble_member),
noise=True, time_interval=time_interval)
for ensemble_member in prior.ensemble.T])

return Prediction.from_state(prior, pred_ensemble, timestamp=timestamp,
transition_model=self.transition_model)
44 changes: 44 additions & 0 deletions stonesoup/predictor/tests/test_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# coding: utf-8
import datetime

import numpy as np

from ...models.transition.linear import ConstantVelocity
from ...predictor.ensemble import EnsemblePredictor
from ...types.state import GaussianState, EnsembleState
from ...types.array import StateVector, CovarianceMatrix


def test_ensemble():
# Initialise a transition model
transition_model = ConstantVelocity(noise_diff_coeff=0)

# Define time related variables
timestamp = datetime.datetime(2021, 2, 27, 17, 27, 48)
timediff = 1 # 1 second
new_timestamp = timestamp + datetime.timedelta(seconds=timediff)
time_interval = new_timestamp - timestamp

# Define prior state
mean = StateVector([[10], [10]])
covar = CovarianceMatrix(np.eye(2))
gaussian_state = GaussianState(mean, covar, timestamp)
num_vectors = 50
prior_state = EnsembleState.from_gaussian_state(gaussian_state, num_vectors)
prior_ensemble = prior_state.ensemble

# Create Predictor object, run prediction
predictor = EnsemblePredictor(transition_model)
prediction = predictor.predict(prior_state, timestamp=new_timestamp)

# Evaluate mean and covariance
eval_ensemble = transition_model.matrix(timestamp=new_timestamp,
time_interval=time_interval) @ prior_ensemble
eval_mean = StateVector((np.average(eval_ensemble, axis=1)))
eval_cov = np.cov(eval_ensemble)

# Compare evaluated mean and covariance with predictor results
assert np.allclose(prediction.mean, eval_mean)
assert np.allclose(prediction.ensemble, eval_ensemble)
assert np.allclose(prediction.covar, eval_cov)
assert prediction.timestamp == new_timestamp
19 changes: 17 additions & 2 deletions stonesoup/types/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from .array import CovarianceMatrix
from .base import Type
from .state import (State, GaussianState, ParticleState, SqrtGaussianState, InformationState,
TaggedWeightedGaussianState, WeightedGaussianState, CategoricalState)
from .state import (State, GaussianState, ParticleState, EnsembleState,
SqrtGaussianState, InformationState, TaggedWeightedGaussianState,
WeightedGaussianState, CategoricalState)
from ..base import Property
from ..models.transition.base import TransitionModel
from ..types.state import CreatableFromState, CompositeState
Expand Down Expand Up @@ -118,6 +119,20 @@ class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState):
"""


class EnsembleStatePrediction(Prediction, EnsembleState):
"""EnsembleStatePrediction type
This is a simple Ensemble measurement prediction object.
"""


class EnsembleMeasurementPrediction(MeasurementPrediction, EnsembleState):
"""EnsembleMeasurementPrediction type
This is a simple Ensemble measurement prediction object.
"""


class CategoricalStatePrediction(Prediction, CategoricalState):
"""Categorical state prediction type"""

Expand Down
122 changes: 121 additions & 1 deletion stonesoup/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from .array import StateVector, CovarianceMatrix, PrecisionMatrix
from .array import StateVector, StateVectors, CovarianceMatrix, PrecisionMatrix
from .base import Type
from .numeric import Probability
from .particle import Particles
Expand Down Expand Up @@ -477,6 +477,126 @@ def covar(self):
State.register(ParticleState) # noqa: E305


class EnsembleState(Type):
r"""Ensemble State type
This is an Ensemble state object which describes the system state as a
ensemble of state vectors for use in Ensemble based filters.
This approach is functionally identical to the Particle state type except
it doesn't use any weighting for any of the "particles" or ensemble members.
All "particles" or state vectors in the ensemble are equally weighted.
.. math::
\mathbf{X} = [x_1, x_2, ..., x_M]
"""

ensemble: StateVectors = Property(doc='''An ensemble of state vectors which represent
the state''')

timestamp: datetime.datetime = Property(
default=None, doc="Timestamp of the state. Default None.")

@classmethod
def from_gaussian_state(self, gaussian_state, num_vectors):
"""
Returns an EnsembleState instance, from a given
GaussianState object.
Parameters
----------
gaussian_state : :class:`~.GaussianState`
The GaussianState used to create the new EnsembleState.
num_vectors : int
The number of desired column vectors present in the ensemble.
Returns
-------
:class:`~.EnsembleState`
Instance of EnsembleState.
"""
mean = gaussian_state.state_vector.reshape((gaussian_state.ndim,))
covar = gaussian_state.covar
timestamp = gaussian_state.timestamp

return EnsembleState(ensemble=self.generate_ensemble(mean, covar, num_vectors),
timestamp=timestamp)

@classmethod
def generate_ensemble(self, mean, covar, num_vectors):
"""
Returns a StateVectors wrapped ensemble of state vectors, from a given
mean and covariance matrix.
Parameters
----------
mean : :class:`~.numpy.ndarray`
The mean value of the distribution being sampled to generate
ensemble.
covar : :class:`~.numpy.ndarray`
The covariance matrix of the distribution being sampled to
generate ensemble.
num_vectors : int
The number of desired column vectors present in the ensemble,
or the number of "samples".
Returns
-------
:class:`~.EnsembleState`
Instance of EnsembleState.
"""
# This check is necessary, because the StateVector wrapper does
# funny things with dimension.
rng = np.random.default_rng()
if mean.ndim != 1:
mean = mean.reshape(len(mean))
try:
ensemble = StateVectors(
[StateVector((rng.multivariate_normal(mean, covar)))
for n in range(num_vectors)])
# If covar is univariate, then use the univariate noise generation function.
except ValueError:
ensemble = StateVectors(
[StateVector((rng.normal(mean, covar))) for n in range(num_vectors)])

return ensemble

@property
def ndim(self):
"""Number of dimensions in state vectors"""
return np.shape(self.ensemble)[0]

@property
def num_vectors(self):
"""Number of columns in state ensemble"""
return np.shape(self.ensemble)[1]

@property
def mean(self):
"""The state mean, numerically equivalent to state vector"""
return np.average(self.ensemble, axis=1)

@property
def state_vector(self):
"""State mean in StateVector wrapper."""
return StateVector(self.mean)

@property
def covar(self):
"""Sample covariance matrix for ensemble"""
return np.cov(self.ensemble)

@property
def sqrt_covar(self):
"""sqrt of sample covariance matrix for ensemble, useful for
some EnKF algorithms"""
return ((self.ensemble - np.tile(self.mean, self.num_vectors)) /
np.sqrt(self.num_vectors - 1))


State.register(EnsembleState) # noqa: E305


class CategoricalState(State):
r"""CategoricalState type.
Expand Down
76 changes: 73 additions & 3 deletions stonesoup/types/tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import scipy.linalg

from ..angle import Bearing
from ..array import StateVector, CovarianceMatrix
from ..array import StateVector, StateVectors, CovarianceMatrix
from ..groundtruth import GroundTruthState
from ..numeric import Probability
from ..particle import Particle
from ..state import CreatableFromState
from ..state import State, GaussianState, ParticleState, StateMutableSequence, \
WeightedGaussianState, SqrtGaussianState, CategoricalState, CompositeState
from ..state import State, GaussianState, ParticleState, EnsembleState, \
StateMutableSequence, WeightedGaussianState, SqrtGaussianState, CategoricalState, \
CompositeState

from ...base import Property


Expand Down Expand Up @@ -218,6 +220,74 @@ def test_particlestate_angle():
assert np.allclose(state.covar, CovarianceMatrix([[0.01, -1.5], [-1.5, 225]]))


def test_ensemblestate():

# 1D
state_vector1 = StateVector(np.array([1.5]))
state_vector2 = StateVector(np.array([0.5]))
list_of_state_vectors = [state_vector1, state_vector2]
ensemble = StateVectors(list_of_state_vectors)

# Test state without timestamp
state = EnsembleState(ensemble)
assert np.allclose(state.state_vector, StateVector([[1]]))
assert np.allclose(state.covar, CovarianceMatrix([[0.5]]))

# Test state with timestamp
timestamp = datetime.datetime(2021, 2, 25, 22, 29, 2)
state = EnsembleState(ensemble, timestamp=timestamp)
assert np.allclose(state.state_vector, StateVector([[1]]))
assert np.allclose(state.covar, CovarianceMatrix([[0.5]]))
assert state.timestamp == timestamp

# 2D
state_vector1 = StateVector(np.array([1.5, 0.75]))
state_vector2 = StateVector(np.array([0.5, 1.25]))
ensemble = StateVectors([state_vector1, state_vector2])

state = EnsembleState(ensemble)
assert np.allclose(state.state_vector, StateVector([[1], [1]]))
assert np.allclose(state.covar, CovarianceMatrix([[0.5, -0.25], [-0.25, 0.125]]))
assert np.allclose(state.sqrt_covar @ state.sqrt_covar.T, state.covar)

# Test generate_ensemble class method.
# 1 Dimensional
test_mean_1d = np.array([0])
test_covar_1d = np.array([1])
ensemble1d = state.generate_ensemble(mean=test_mean_1d,
covar=test_covar_1d, num_vectors=5)
assert np.shape(ensemble1d) == (1, 5)
assert isinstance(ensemble1d, StateVectors)

# 2 Dimensional
# Lets pass in a state vector mean(as opposed to an array) while we're at it
test_mean_2d = StateVector([1, 1])
test_covar_2d = CovarianceMatrix(np.eye(2))
ensemble2d = state.generate_ensemble(mean=test_mean_2d,
covar=test_covar_2d, num_vectors=5)
assert np.shape(ensemble2d) == (2, 5)
assert isinstance(ensemble2d, StateVectors)


def test_ensemblestate_gaussian_init():
"""Test initialising with an existing gaussian state object"""

# Initialize GaussianState
mean = StateVector([[25], [25], [25], [25]])
covar = CovarianceMatrix(np.eye(4))
timestamp = datetime.datetime(2021, 2, 26, 16, 35, 42)
gaussian_state = GaussianState(mean, covar, timestamp)
# Generate EnsembleState
num_vectors = 50
ensemble_state = EnsembleState.from_gaussian_state(gaussian_state, num_vectors)

assert isinstance(ensemble_state.state_vector, StateVector)
assert isinstance(ensemble_state.ensemble, StateVectors)
assert isinstance(ensemble_state.covar, CovarianceMatrix)
assert isinstance(ensemble_state.timestamp, datetime.datetime)
assert ensemble_state.timestamp == timestamp


def test_state_mutable_sequence_state():
state_vector = StateVector([[0]])
timestamp = datetime.datetime(2018, 1, 1, 14)
Expand Down
Loading

0 comments on commit c9c5e8a

Please sign in to comment.