From 970f1da031c620c96deac60408872ad5780c0112 Mon Sep 17 00:00:00 2001 From: Steven Hiscocks Date: Wed, 5 Apr 2023 11:49:38 +0100 Subject: [PATCH] Fix np.mean for StateVectors Previously unusable due to incorrect arguments being passed --- stonesoup/types/array.py | 2 +- stonesoup/types/tests/test_array.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/stonesoup/types/array.py b/stonesoup/types/array.py index f1cf02ea8..657913758 100644 --- a/stonesoup/types/array.py +++ b/stonesoup/types/array.py @@ -167,7 +167,7 @@ def __array_function__(self, func, types, args, kwargs): def _mean(state_vectors, axis=None, dtype=None, out=None, keepdims=np._NoValue): if state_vectors.dtype != np.object_: # Can just use standard numpy mean if not using custom objects - return np.mean(axis, dtype, out, keepdims) + return np.mean(np.asarray(state_vectors), axis, dtype, out, keepdims) elif axis == 1 and out is None: state_vector = np.average(state_vectors, axis) if dtype: diff --git a/stonesoup/types/tests/test_array.py b/stonesoup/types/tests/test_array.py index 70c37fa9b..b9315325e 100644 --- a/stonesoup/types/tests/test_array.py +++ b/stonesoup/types/tests/test_array.py @@ -39,6 +39,14 @@ def test_statevectors(): assert isinstance(sv, StateVector) +def test_statevectors_mean(): + svs = StateVectors([[1., 2., 3.], [4., 5., 6.]]) + mean = StateVector([[2., 5.]]) + + assert np.allclose(np.average(svs, axis=1), mean) + assert np.allclose(np.mean(svs, axis=1, keepdims=True), mean) + + def test_standard_statevector_indexing(): state_vector_array = np.array([[1], [2], [3], [4]]) state_vector = StateVector(state_vector_array)