Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensemble Kalman Filter Algorithm and Data Type #625

Merged
merged 16 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ MIT License
© Crown Copyright 2017-2022 Defence Science and Technology Laboratory UK
© Crown Copyright 2018-2022 Defence Research and Development Canada / Recherche et développement pour la défense Canada
© Copyright 2018-2022 University of Liverpool UK
© Copyright 2020-2022 John Hiles

Portions of this work were created while John Hiles was employed by Wright
State University under Contract No. FA8650-18-2-1645 to the Air Force Research
Laboratory, and the University has released its rights to this Student work.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
6 changes: 6 additions & 0 deletions docs/source/stonesoup.predictor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Particle
.. automodule:: stonesoup.predictor.particle
:show-inheritance:

Ensemble
--------

.. automodule:: stonesoup.predictor.ensemble
:show-inheritance:

Categorical
-----------

Expand Down
6 changes: 6 additions & 0 deletions docs/source/stonesoup.updater.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Particle
.. automodule:: stonesoup.updater.particle
:show-inheritance:

Ensemble
--------

.. automodule:: stonesoup.updater.ensemble
:show-inheritance:

Information
-----------

Expand Down
52 changes: 52 additions & 0 deletions stonesoup/predictor/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from .base import Property
from ..models.transition import TransitionModel
from .kalman import KalmanPredictor
from ..types.array import StateVectors
from ..types.state import State
from ..types.prediction import Prediction


class EnsemblePredictor(KalmanPredictor):
r"""Ensemble Kalman Filter Predictor class

The EnKF predicts the state by treating each column of the ensemble matrix
as a state vector. The state is propagated through time by applying the
transition function to each member (vector) of the ensemble.

.. math::

\hat{X}_k = [f(x_1), f(x_2), ..., f(x_M)]

"""
transition_model: TransitionModel = Property(doc="The transition model to be used.")

def predict(self, prior, timestamp=None, **kwargs):
"""Ensemble Kalman Filter prediction step

Parameters
----------
prior : :class:`~.EnsembleState`
A prior state object
control_input : :class:`~.State`, optional
The control input. It will only have an effect if
:attr:`control_model` is not `None` (the default is `None`)
timestamp: :class:`datetime.datetime`, optional
A timestamp signifying when the prediction is performed
(the default is `None`)

Returns
-------
: :class:`~.EnsembleStatePrediction`
The predicted state
"""

# Compute time_interval
time_interval = self._predict_over_interval(prior, timestamp)
# This block of code propagates each column through the transition model.
pred_ensemble = StateVectors(
[self.transition_model.function(State(state_vector=ensemble_member),
noise=True, time_interval=time_interval)
for ensemble_member in prior.ensemble.T])

return Prediction.from_state(prior, pred_ensemble, timestamp=timestamp,
transition_model=self.transition_model)
44 changes: 44 additions & 0 deletions stonesoup/predictor/tests/test_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# coding: utf-8
import datetime

import numpy as np

from ...models.transition.linear import ConstantVelocity
from ...predictor.ensemble import EnsemblePredictor
from ...types.state import GaussianState, EnsembleState
from ...types.array import StateVector, CovarianceMatrix


def test_ensemble():
# Initialise a transition model
transition_model = ConstantVelocity(noise_diff_coeff=0)

# Define time related variables
timestamp = datetime.datetime(2021, 2, 27, 17, 27, 48)
timediff = 1 # 1 second
new_timestamp = timestamp + datetime.timedelta(seconds=timediff)
time_interval = new_timestamp - timestamp

# Define prior state
mean = StateVector([[10], [10]])
covar = CovarianceMatrix(np.eye(2))
gaussian_state = GaussianState(mean, covar, timestamp)
num_vectors = 50
prior_state = EnsembleState.from_gaussian_state(gaussian_state, num_vectors)
prior_ensemble = prior_state.ensemble

# Create Predictor object, run prediction
predictor = EnsemblePredictor(transition_model)
prediction = predictor.predict(prior_state, timestamp=new_timestamp)

# Evaluate mean and covariance
eval_ensemble = transition_model.matrix(timestamp=new_timestamp,
time_interval=time_interval) @ prior_ensemble
eval_mean = StateVector((np.average(eval_ensemble, axis=1)))
eval_cov = np.cov(eval_ensemble)

# Compare evaluated mean and covariance with predictor results
assert np.allclose(prediction.mean, eval_mean)
assert np.allclose(prediction.ensemble, eval_ensemble)
assert np.allclose(prediction.covar, eval_cov)
assert prediction.timestamp == new_timestamp
19 changes: 17 additions & 2 deletions stonesoup/types/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from .array import CovarianceMatrix
from .base import Type
from .state import (State, GaussianState, ParticleState, SqrtGaussianState, InformationState,
TaggedWeightedGaussianState, WeightedGaussianState, CategoricalState)
from .state import (State, GaussianState, ParticleState, EnsembleState,
SqrtGaussianState, InformationState, TaggedWeightedGaussianState,
WeightedGaussianState, CategoricalState)
from ..base import Property
from ..models.transition.base import TransitionModel
from ..types.state import CreatableFromState, CompositeState
Expand Down Expand Up @@ -118,6 +119,20 @@ class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState):
"""


class EnsembleStatePrediction(Prediction, EnsembleState):
"""EnsembleStatePrediction type

This is a simple Ensemble measurement prediction object.
"""


class EnsembleMeasurementPrediction(MeasurementPrediction, EnsembleState):
"""EnsembleMeasurementPrediction type

This is a simple Ensemble measurement prediction object.
"""


class CategoricalStatePrediction(Prediction, CategoricalState):
"""Categorical state prediction type"""

Expand Down
122 changes: 121 additions & 1 deletion stonesoup/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from .array import StateVector, CovarianceMatrix, PrecisionMatrix
from .array import StateVector, StateVectors, CovarianceMatrix, PrecisionMatrix
from .base import Type
from .numeric import Probability
from .particle import Particles
Expand Down Expand Up @@ -474,6 +474,126 @@ def covar(self):
State.register(ParticleState) # noqa: E305


class EnsembleState(Type):
r"""Ensemble State type

This is an Ensemble state object which describes the system state as a
ensemble of state vectors for use in Ensemble based filters.

This approach is functionally identical to the Particle state type except
it doesn't use any weighting for any of the "particles" or ensemble members.
All "particles" or state vectors in the ensemble are equally weighted.

.. math::

\mathbf{X} = [x_1, x_2, ..., x_M]

"""

ensemble: StateVectors = Property(doc='''An ensemble of state vectors which represent
the state''')

timestamp: datetime.datetime = Property(
default=None, doc="Timestamp of the state. Default None.")

@classmethod
def from_gaussian_state(self, gaussian_state, num_vectors):
"""
Returns an EnsembleState instance, from a given
GaussianState object.

Parameters
----------
gaussian_state : :class:`~.GaussianState`
The GaussianState used to create the new EnsembleState.
num_vectors : int
The number of desired column vectors present in the ensemble.
Returns
-------
:class:`~.EnsembleState`
Instance of EnsembleState.
"""
mean = gaussian_state.state_vector.reshape((gaussian_state.ndim,))
covar = gaussian_state.covar
timestamp = gaussian_state.timestamp

return EnsembleState(ensemble=self.generate_ensemble(mean, covar, num_vectors),
timestamp=timestamp)

@classmethod
def generate_ensemble(self, mean, covar, num_vectors):
"""
Returns a StateVectors wrapped ensemble of state vectors, from a given
mean and covariance matrix.

Parameters
----------
mean : :class:`~.numpy.ndarray`
The mean value of the distribution being sampled to generate
ensemble.
covar : :class:`~.numpy.ndarray`
The covariance matrix of the distribution being sampled to
generate ensemble.
num_vectors : int
The number of desired column vectors present in the ensemble,
or the number of "samples".
Returns
-------
:class:`~.EnsembleState`
Instance of EnsembleState.
"""
# This check is necessary, because the StateVector wrapper does
# funny things with dimension.
rng = np.random.default_rng()
if mean.ndim != 1:
mean = mean.reshape(len(mean))
try:
ensemble = StateVectors(
[StateVector((rng.multivariate_normal(mean, covar)))
for n in range(num_vectors)])
# If covar is univariate, then use the univariate noise generation function.
except ValueError:
ensemble = StateVectors(
[StateVector((rng.normal(mean, covar))) for n in range(num_vectors)])

return ensemble

@property
def ndim(self):
"""Number of dimensions in state vectors"""
return np.shape(self.ensemble)[0]

@property
def num_vectors(self):
"""Number of columns in state ensemble"""
return np.shape(self.ensemble)[1]

@property
def mean(self):
"""The state mean, numerically equivalent to state vector"""
return np.average(self.ensemble, axis=1)

@property
def state_vector(self):
"""State mean in StateVector wrapper."""
return StateVector(self.mean)

@property
def covar(self):
"""Sample covariance matrix for ensemble"""
return np.cov(self.ensemble)

@property
def sqrt_covar(self):
"""sqrt of sample covariance matrix for ensemble, useful for
some EnKF algorithms"""
return ((self.ensemble - np.tile(self.mean, self.num_vectors)) /
np.sqrt(self.num_vectors - 1))


State.register(EnsembleState) # noqa: E305


class CategoricalState(State):
r"""CategoricalState type.

Expand Down
76 changes: 73 additions & 3 deletions stonesoup/types/tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import scipy.linalg

from ..angle import Bearing
from ..array import StateVector, CovarianceMatrix
from ..array import StateVector, StateVectors, CovarianceMatrix
from ..groundtruth import GroundTruthState
from ..numeric import Probability
from ..particle import Particle
from ..state import CreatableFromState
from ..state import State, GaussianState, ParticleState, StateMutableSequence, \
WeightedGaussianState, SqrtGaussianState, CategoricalState, CompositeState
from ..state import State, GaussianState, ParticleState, EnsembleState, \
StateMutableSequence, WeightedGaussianState, SqrtGaussianState, CategoricalState, \
CompositeState

from ...base import Property


Expand Down Expand Up @@ -218,6 +220,74 @@ def test_particlestate_angle():
assert np.allclose(state.covar, CovarianceMatrix([[0.01, -1.5], [-1.5, 225]]))


def test_ensemblestate():

# 1D
state_vector1 = StateVector(np.array([1.5]))
state_vector2 = StateVector(np.array([0.5]))
list_of_state_vectors = [state_vector1, state_vector2]
ensemble = StateVectors(list_of_state_vectors)

# Test state without timestamp
state = EnsembleState(ensemble)
assert np.allclose(state.state_vector, StateVector([[1]]))
assert np.allclose(state.covar, CovarianceMatrix([[0.5]]))

# Test state with timestamp
timestamp = datetime.datetime(2021, 2, 25, 22, 29, 2)
state = EnsembleState(ensemble, timestamp=timestamp)
assert np.allclose(state.state_vector, StateVector([[1]]))
assert np.allclose(state.covar, CovarianceMatrix([[0.5]]))
assert state.timestamp == timestamp

# 2D
state_vector1 = StateVector(np.array([1.5, 0.75]))
state_vector2 = StateVector(np.array([0.5, 1.25]))
ensemble = StateVectors([state_vector1, state_vector2])

state = EnsembleState(ensemble)
assert np.allclose(state.state_vector, StateVector([[1], [1]]))
assert np.allclose(state.covar, CovarianceMatrix([[0.5, -0.25], [-0.25, 0.125]]))
assert np.allclose(state.sqrt_covar @ state.sqrt_covar.T, state.covar)

# Test generate_ensemble class method.
# 1 Dimensional
test_mean_1d = np.array([0])
test_covar_1d = np.array([1])
ensemble1d = state.generate_ensemble(mean=test_mean_1d,
covar=test_covar_1d, num_vectors=5)
assert np.shape(ensemble1d) == (1, 5)
assert isinstance(ensemble1d, StateVectors)

# 2 Dimensional
# Lets pass in a state vector mean(as opposed to an array) while we're at it
test_mean_2d = StateVector([1, 1])
test_covar_2d = CovarianceMatrix(np.eye(2))
ensemble2d = state.generate_ensemble(mean=test_mean_2d,
covar=test_covar_2d, num_vectors=5)
assert np.shape(ensemble2d) == (2, 5)
assert isinstance(ensemble2d, StateVectors)


def test_ensemblestate_gaussian_init():
"""Test initialising with an existing gaussian state object"""

# Initialize GaussianState
mean = StateVector([[25], [25], [25], [25]])
covar = CovarianceMatrix(np.eye(4))
timestamp = datetime.datetime(2021, 2, 26, 16, 35, 42)
gaussian_state = GaussianState(mean, covar, timestamp)
# Generate EnsembleState
num_vectors = 50
ensemble_state = EnsembleState.from_gaussian_state(gaussian_state, num_vectors)

assert isinstance(ensemble_state.state_vector, StateVector)
assert isinstance(ensemble_state.ensemble, StateVectors)
assert isinstance(ensemble_state.covar, CovarianceMatrix)
assert isinstance(ensemble_state.timestamp, datetime.datetime)
assert ensemble_state.timestamp == timestamp


def test_state_mutable_sequence_state():
state_vector = StateVector([[0]])
timestamp = datetime.datetime(2018, 1, 1, 14)
Expand Down
Loading