diff --git a/stonesoup/types/_util.py b/stonesoup/types/_util.py new file mode 100644 index 000000000..13791e16d --- /dev/null +++ b/stonesoup/types/_util.py @@ -0,0 +1,60 @@ +""" +Back port from Python 3.8: https://github.com/python/cpython/blob/4f100fe9f1c691145e3fa959ef324646e303cdf3/Lib/functools.py#L924-L976 + +LICENSE: https://github.com/python/cpython/blob/4f100fe9f1c691145e3fa959ef324646e303cdf3/LICENSE +Copyright (c) 2001-2022 Python Software Foundation. All rights reserved. +""" +# flake8: noqa +# TODO: Remove once support for Python 3.7 dropped; replace with functools + +from threading import RLock + +_NOT_FOUND = object() + + +class cached_property: # pragma: no cover + def __init__(self, func): + self.func = func + self.attrname = None + self.__doc__ = func.__doc__ + self.lock = RLock() + + def __set_name__(self, owner, name): + if self.attrname is None: + self.attrname = name + elif name != self.attrname: + raise TypeError( + "Cannot assign the same cached_property to two different names " + f"({self.attrname!r} and {name!r})." + ) + + def __get__(self, instance, owner=None): + if instance is None: + return self + if self.attrname is None: + raise TypeError( + "Cannot use cached_property instance without calling __set_name__ on it.") + try: + cache = instance.__dict__ + except AttributeError: # not all objects have __dict__ (e.g. class defines slots) + msg = ( + f"No '__dict__' attribute on {type(instance).__name__!r} " + f"instance to cache {self.attrname!r} property." + ) + raise TypeError(msg) from None + val = cache.get(self.attrname, _NOT_FOUND) + if val is _NOT_FOUND: + with self.lock: + # check if another thread filled cache while we awaited lock + val = cache.get(self.attrname, _NOT_FOUND) + if val is _NOT_FOUND: + val = self.func(instance) + try: + cache[self.attrname] = val + except TypeError: + msg = ( + f"The '__dict__' attribute on {type(instance).__name__!r} instance " + f"does not support item assignment for caching {self.attrname!r} property." + ) + raise TypeError(msg) from None + return val diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py index 7f5dbf0d2..a753108cf 100644 --- a/stonesoup/types/state.py +++ b/stonesoup/types/state.py @@ -12,6 +12,7 @@ from .base import Type from .particle import Particle from .numeric import Probability +from ._util import cached_property # TODO: Change to functools once support for Python 3.7 dropped class State(Type): @@ -433,7 +434,15 @@ class ParticleState(State): """Particle State type This is a particle state object which describes the state as a - distribution of particles""" + distribution of particles + + Note + ---- + Once either :attr:`mean` or :attr:`covar` are called, both :attr:`state_vector` + and :attr:`weight` NumPy arrays will no longer be writable due to caching. If + replacing :attr:`state_vector` or :attr:`covar` on the state, the cache will + be cleared. + """ state_vector: StateVectors = Property(doc='State vectors.') weight: MutableSequence[Probability] = Property(default=None, doc='Weights of particles') @@ -484,6 +493,33 @@ def __getitem__(self, item): parent=p) return particle + def _clear_cache(self): + if 'mean' in self.__dict__: + del self.__dict__["mean"] + if 'covar' in self.__dict__: + del self.__dict__["covar"] + + @state_vector.setter + def state_vector(self, value): + self._clear_cache() + if value is not None: + value = np.asanyarray(value) + setattr(self, type(self).state_vector._property_name, value) + + @weight.setter + def weight(self, value): + self._clear_cache() + if value is not None: + value = np.asanyarray(value) + setattr(self, type(self).weight._property_name, value) + + @fixed_covar.setter + def fixed_covar(self, value): + # Don't need to worry about mean + if 'covar' in self.__dict__: + del self.__dict__["covar"] + setattr(self, type(self).fixed_covar._property_name, value) + @property def particles(self): return [particle for particle in self] @@ -495,22 +531,20 @@ def __len__(self): def ndim(self): return self.state_vector.shape[0] - @property + @cached_property def mean(self): """The state mean, equivalent to state vector""" - result = np.average(self.state_vector, - axis=1, - weights=self.weight) - # Convert type as may have type of weights - return result + self.state_vector.flags.writeable = False + self.weight.flags.writeable = False + return np.average(self.state_vector, axis=1, weights=self.weight) - @property + @cached_property def covar(self): if self.fixed_covar is not None: return self.fixed_covar - cov = np.cov(self.state_vector, ddof=0, aweights=np.array(self.weight)) - # Fix one dimensional covariances being returned with zero dimension - return cov + self.state_vector.flags.writeable = False + self.weight.flags.writeable = False + return np.cov(self.state_vector, ddof=0, aweights=self.weight) State.register(ParticleState) # noqa: E305 diff --git a/stonesoup/types/tests/test_state.py b/stonesoup/types/tests/test_state.py index 178b6f81e..7d1c5ab21 100644 --- a/stonesoup/types/tests/test_state.py +++ b/stonesoup/types/tests/test_state.py @@ -246,6 +246,39 @@ def test_particlestate_angle(): assert np.allclose(state.covar, CovarianceMatrix([[0.01, -1.5], [-1.5, 225]])) +def test_particlestate_cache(): + num_particles = 10 + weight = Probability(1/num_particles) + particles = StateVectors(np.concatenate( + (np.tile([[0]], num_particles//2), np.tile([[100]], num_particles//2)), axis=1)) + weights = np.tile(weight, num_particles) + + state = ParticleState(particles, weight=weights) + assert np.allclose(state.mean, StateVector([[50]])) + assert np.allclose(state.covar, CovarianceMatrix([[2500]])) + + with pytest.raises(ValueError, match="read-only"): + state.state_vector += 10 + + with pytest.raises(ValueError, match="read-only"): + state.weight *= 0.5 + + state.state_vector = particles + 50 # Cache cleared + with pytest.raises(ValueError, match="read-only"): + state.weight *= 0.5 # But still not writable + state.weight = state.weight * 0.5 + assert np.allclose(state.mean, StateVector([[100]])) + assert np.allclose(state.covar, CovarianceMatrix([[2500]])) + + state = ParticleState(particles, weight=weights, fixed_covar=np.array([[1]])) + assert np.allclose(state.mean, StateVector([[50]])) + assert np.allclose(state.covar, CovarianceMatrix([[1]])) + + state.fixed_covar = np.array([[2]]) + assert np.allclose(state.mean, StateVector([[50]])) + assert np.allclose(state.covar, CovarianceMatrix([[2]])) + + def test_ensemblestate(): # 1D