From 0af6b044bd5976dd337a634a5233f0147e690ff6 Mon Sep 17 00:00:00 2001 From: G Webb <29946934+gawebb-dstl@users.noreply.github.com> Date: Tue, 26 Sep 2023 09:34:20 +0100 Subject: [PATCH 1/4] Added interpolate file with four options of methods. Functions inplace to test them and speed check them --- stonesoup/functions/interpolate.py | 261 +++++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 stonesoup/functions/interpolate.py diff --git a/stonesoup/functions/interpolate.py b/stonesoup/functions/interpolate.py new file mode 100644 index 000000000..21e2be0dd --- /dev/null +++ b/stonesoup/functions/interpolate.py @@ -0,0 +1,261 @@ +import copy +import datetime +import timeit +import warnings +from typing import Union, List, Iterable + +import numpy as np +from scipy.interpolate import interp1d + +from stonesoup.types.array import StateVector +from stonesoup.types.state import StateMutableSequence, State + + +def time_range(start_time: datetime.datetime, end_time: datetime.datetime, + timestep: datetime.timedelta = datetime.timedelta(seconds=1)) \ + -> Iterable[datetime.datetime]: + """ + Produces a range of datetime object between ``start_time`` (inclusive) and ``end_time`` + (inclusive) + + Parameters + ---------- + start_time: datetime.datetime time range start (inclusive) + end_time: datetime.datetime time range end (inclusive) + timestep: datetime.timedelta default value is 1 second + + Returns + ------- + Generator[datetime.datetime] + + """ + duration = end_time - start_time + n_time_steps = duration / timestep + for x in range(int(n_time_steps) + 1): + yield start_time + x * timestep + + +def interpolate_state_mutable_sequence(sms: StateMutableSequence, + times: Union[datetime.datetime, List[datetime.datetime]], + option: int) -> Union[StateMutableSequence, State]: + + if isinstance(times, datetime.datetime): + new_sms = interpolate_state_mutable_sequence(sms, [times]) + return new_sms.state + + new_sms = copy.copy(sms) + + # Removes multiple states with the same timestamp. Not needed for most sequences + time_state_dict = {state.timestamp: state + for state in sms} + + # Filter times if required + max_state_time = sms[-1].timestamp + min_state_time = sms[0].timestamp + if max(times) > max_state_time or min(times) < min_state_time: + new_times = [time + for time in times + if min_state_time <= time <= max_state_time] + + if len(new_times) == 0: + raise IndexError(f"All times are outside of the state mutable sequence's time range " + f"({min_state_time} -> {max_state_time})") + + removed_times = set(times).difference(new_times) + warnings.warn(f"Trying to interpolate states which are outside the time range " + f"({min_state_time} -> {max_state_time}) of the state mutable sequence. The " + f"following times aren't included in the output {removed_times}") + + times = new_times + + # Return short-cut method if no interpolation is required + if set(times).issubset(time_state_dict.keys()): + # All states are present. No interpolation is required + new_sms.states = [time_state_dict[time] + for time in times] + + return new_sms + + if option == 1: + return option_1(sms, times) + elif option == 2: + return option_2(sms, times) + elif option == 3: + return option_3(sms, times) + elif option == 4: + return option_4(sms, times) + + +def option_1(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: + new_sms = copy.copy(sms) + if hasattr(new_sms, "metadatas"): + new_sms.metadatas = list() + new_sms.states = list() + + state_iter = iter(sms) + aft = bef = next(state_iter) + + for timestamp in times: + + # want `bef` to be state just before timestamp, and `aft` to be state just after + while aft.timestamp < timestamp: + bef = aft + while bef.timestamp == aft.timestamp: + aft = next(state_iter) + + if aft.timestamp == timestamp: + # if `aft` happens to have timestamp exactly equal, just take its state vector + sv = aft.state_vector + else: + # otherwise, get the point on the line connecting `bef` and `aft` + # that lies exactly where `timestamp` would be (assuming constant velocity) + bef_sv = bef.state_vector + aft_sv = aft.state_vector + + frac = (timestamp - bef.timestamp) / (aft.timestamp - bef.timestamp) + + sv = bef_sv + ((aft_sv - bef_sv) * frac) + + # use `bef` as template for new state (since metadata might change in the future, + # at `aft`) + new_state = State.from_state(bef, state_vector=sv, timestamp=timestamp) + new_state.state_vector = sv + new_sms.append(new_state) + + return new_sms + + +def option_2(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: + new_sms = copy.copy(sms) + if hasattr(new_sms, "metadatas"): + new_sms.metadatas = list() + new_sms.states = list() + + time_state_dict = {state.timestamp: state + for state in sms} + + float_times = [state.timestamp.timestamp() for state in sms] + float_state_vectors = [[np.double(state.state_vector[i]) + for state in sms] + for i, _ in enumerate(sms[0].state_vector)] + + for timestamp in times: + if timestamp in time_state_dict: + new_sms.append(time_state_dict[timestamp]) + else: + + output = np.zeros(len(sms[0].state_vector)) + for i in range(len(output)): + a_states = float_state_vectors[i] + output[i] = np.interp(timestamp.timestamp(), float_times, a_states) + + new_sms.append(State(StateVector(output), timestamp=timestamp)) + + return new_sms + + +def option_3(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: + + new_sms = copy.copy(sms) + if hasattr(new_sms, "metadatas"): + new_sms.metadatas = list() + + # Removes multiple states with the same timestamp. Not needed for most sequences + time_state_dict = {state.timestamp: state + for state in sms} + + svs = [state.state_vector for state in time_state_dict.values()] + state_times = [time.timestamp() for time in time_state_dict.keys()] + interpolate_object = interp1d(state_times, np.stack(svs, axis=0), axis=0) + + interp_output = interpolate_object([time.timestamp() for time in times]) + + new_sms.states = [State(interp_output[idx, :, :], timestamp=time) + for idx, time in enumerate(times)] + + return new_sms + + +def option_4(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: + new_sms = copy.copy(sms) + if hasattr(new_sms, "metadatas"): + new_sms.metadatas = list() + + # Removes multiple states with the same timestamp. Not needed for most sequences + time_state_dict = {state.timestamp: state + for state in sms} + + times_to_interpolate = sorted(set(times).difference(time_state_dict.keys())) + + svs = [state.state_vector for state in time_state_dict.values()] + state_times = [time.timestamp() for time in time_state_dict.keys()] + interpolate_object = interp1d(state_times, np.stack(svs, axis=0), axis=0) + + interp_output = interpolate_object([time.timestamp() for time in times_to_interpolate]) + + for idx, time in enumerate(times_to_interpolate): + time_state_dict[time] = State(interp_output[idx, :, :], timestamp=time) + + new_sms.states = [time_state_dict[time] for time in times] + + return new_sms + + +def state_vector_fun(time: datetime.datetime) -> State: + today = datetime.datetime(datetime.datetime.today().year, + datetime.datetime.today().month, + datetime.datetime.today().day) + num = (time-today).seconds + return State( + timestamp=time, + state_vector=[ + num*1.3e-4, + num*4.7e-5, + num/43, + # np.sin(num*1e-6) + ] + ) + + +def gen_test_data(): + t0 = datetime.datetime.today() + max_seconds = 1e3 + t_max = t0 + datetime.timedelta(seconds=max_seconds) + + sms = StateMutableSequence([state_vector_fun(time) + for time in time_range(t0, t_max, datetime.timedelta(seconds=0.25)) + ]) + + interp_times = list(time_range(t0, t_max, datetime.timedelta(seconds=0.1))) + + return sms, interp_times + + +def individual_speed_test(i): + sms, interp_times = gen_test_data() + interpolate_state_mutable_sequence(sms, interp_times, i) + + +def run_all_speed_tests(): + for i in [1, 2, 3, 4]: + time_taken = timeit.timeit(setup="from __main__ import individual_speed_test", + stmt=f"individual_speed_test({i})", number=10) + print(f"Option {i} took {time_taken}") + + +def check_correct(): + sms, interp_times = gen_test_data() + for i in [1, 2, 3, 4]: + new_sms = interpolate_state_mutable_sequence(sms, interp_times, i) + for state in new_sms: + interp_sv = state.state_vector + true_sv = state_vector_fun(state.timestamp).state_vector + if not np.all(np.isclose(interp_sv, true_sv, rtol=1e-3)): + print(f"Option {i} failed") + break + print(f"Option {i} worked") + + +if __name__ == '__main__': + run_all_speed_tests() + check_correct() From bd860478a1c2126e29d3c7eaf1b5a54ab9d30ab4 Mon Sep 17 00:00:00 2001 From: G Webb <29946934+gawebb-dstl@users.noreply.github.com> Date: Tue, 26 Sep 2023 12:29:36 +0100 Subject: [PATCH 2/4] Moved tests into test_interpolate.py Removed speed tests and options Added more words --- docs/source/stonesoup.functions.rst | 2 + stonesoup/functions/interpolate.py | 227 +++--------------- stonesoup/functions/tests/test_interpolate.py | 109 +++++++++ 3 files changed, 150 insertions(+), 188 deletions(-) create mode 100644 stonesoup/functions/tests/test_interpolate.py diff --git a/docs/source/stonesoup.functions.rst b/docs/source/stonesoup.functions.rst index e1e9aea03..df4ac6b1d 100644 --- a/docs/source/stonesoup.functions.rst +++ b/docs/source/stonesoup.functions.rst @@ -5,3 +5,5 @@ Functions .. automodule:: stonesoup.functions.orbital +.. automodule:: stonesoup.functions.interpolate + diff --git a/stonesoup/functions/interpolate.py b/stonesoup/functions/interpolate.py index 21e2be0dd..47bed69e5 100644 --- a/stonesoup/functions/interpolate.py +++ b/stonesoup/functions/interpolate.py @@ -1,14 +1,12 @@ import copy import datetime -import timeit import warnings from typing import Union, List, Iterable import numpy as np from scipy.interpolate import interp1d -from stonesoup.types.array import StateVector -from stonesoup.types.state import StateMutableSequence, State +from ..types.state import StateMutableSequence, State def time_range(start_time: datetime.datetime, end_time: datetime.datetime, @@ -37,15 +35,42 @@ def time_range(start_time: datetime.datetime, end_time: datetime.datetime, def interpolate_state_mutable_sequence(sms: StateMutableSequence, times: Union[datetime.datetime, List[datetime.datetime]], - option: int) -> Union[StateMutableSequence, State]: + ) -> Union[StateMutableSequence, State]: + """ + This function performs linear interpolation on a :class:`~.StateMutableSequence`. The function + has two slightly different forms: + + If an individual :class:`~datetime.datetime` is inputted for the variable ``times`` then a + :class:`~.State` is returned corresponding to ``times``. + + If a list of :class:`~datetime.datetime` is inputted for the variable ``times`` then a + :class:`~.StateMutableSequence` is returned with the states in the sequence corresponding to + ``times``. + + Note + ---- + This function does **not** extrapolate. Times outside the range of the time range of ``sms`` + are discarded and warning is given. If all ``times`` values are outside the time range of + ``sms`` then an ``IndexError`` is raised. + + Unique states for each time are required for interpolation. If there are multiple states with + the same time in ``sms`` the later state in the sequence is used. + """ + # If single time is used, insert time into list and run again. + # A StateMutableSequence is produced by the inner function call. + # The single state is taken from that StateMutableSequence if isinstance(times, datetime.datetime): new_sms = interpolate_state_mutable_sequence(sms, [times]) return new_sms.state + # Track metadata removed and no interpolation can be performed on the metadata new_sms = copy.copy(sms) + if hasattr(new_sms, "metadatas"): + new_sms.metadatas = list() - # Removes multiple states with the same timestamp. Not needed for most sequences + # This step ensure unique states for each timestamp. The last state for a timestamp is used + # with earlier states not being used. time_state_dict = {state.timestamp: state for state in sms} @@ -68,194 +93,20 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, times = new_times - # Return short-cut method if no interpolation is required - if set(times).issubset(time_state_dict.keys()): - # All states are present. No interpolation is required - new_sms.states = [time_state_dict[time] - for time in times] - - return new_sms - - if option == 1: - return option_1(sms, times) - elif option == 2: - return option_2(sms, times) - elif option == 3: - return option_3(sms, times) - elif option == 4: - return option_4(sms, times) - - -def option_1(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: - new_sms = copy.copy(sms) - if hasattr(new_sms, "metadatas"): - new_sms.metadatas = list() - new_sms.states = list() - - state_iter = iter(sms) - aft = bef = next(state_iter) - - for timestamp in times: - - # want `bef` to be state just before timestamp, and `aft` to be state just after - while aft.timestamp < timestamp: - bef = aft - while bef.timestamp == aft.timestamp: - aft = next(state_iter) - - if aft.timestamp == timestamp: - # if `aft` happens to have timestamp exactly equal, just take its state vector - sv = aft.state_vector - else: - # otherwise, get the point on the line connecting `bef` and `aft` - # that lies exactly where `timestamp` would be (assuming constant velocity) - bef_sv = bef.state_vector - aft_sv = aft.state_vector - - frac = (timestamp - bef.timestamp) / (aft.timestamp - bef.timestamp) - - sv = bef_sv + ((aft_sv - bef_sv) * frac) - - # use `bef` as template for new state (since metadata might change in the future, - # at `aft`) - new_state = State.from_state(bef, state_vector=sv, timestamp=timestamp) - new_state.state_vector = sv - new_sms.append(new_state) - - return new_sms - - -def option_2(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: - new_sms = copy.copy(sms) - if hasattr(new_sms, "metadatas"): - new_sms.metadatas = list() - new_sms.states = list() - - time_state_dict = {state.timestamp: state - for state in sms} - - float_times = [state.timestamp.timestamp() for state in sms] - float_state_vectors = [[np.double(state.state_vector[i]) - for state in sms] - for i, _ in enumerate(sms[0].state_vector)] - - for timestamp in times: - if timestamp in time_state_dict: - new_sms.append(time_state_dict[timestamp]) - else: - - output = np.zeros(len(sms[0].state_vector)) - for i in range(len(output)): - a_states = float_state_vectors[i] - output[i] = np.interp(timestamp.timestamp(), float_times, a_states) - - new_sms.append(State(StateVector(output), timestamp=timestamp)) - - return new_sms - - -def option_3(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: - - new_sms = copy.copy(sms) - if hasattr(new_sms, "metadatas"): - new_sms.metadatas = list() - - # Removes multiple states with the same timestamp. Not needed for most sequences - time_state_dict = {state.timestamp: state - for state in sms} - - svs = [state.state_vector for state in time_state_dict.values()] - state_times = [time.timestamp() for time in time_state_dict.keys()] - interpolate_object = interp1d(state_times, np.stack(svs, axis=0), axis=0) - - interp_output = interpolate_object([time.timestamp() for time in times]) - - new_sms.states = [State(interp_output[idx, :, :], timestamp=time) - for idx, time in enumerate(times)] - - return new_sms - - -def option_4(sms: StateMutableSequence, times: List[datetime.datetime]) -> StateMutableSequence: - new_sms = copy.copy(sms) - if hasattr(new_sms, "metadatas"): - new_sms.metadatas = list() - - # Removes multiple states with the same timestamp. Not needed for most sequences - time_state_dict = {state.timestamp: state - for state in sms} - + # Find times that require interpolation times_to_interpolate = sorted(set(times).difference(time_state_dict.keys())) - svs = [state.state_vector for state in time_state_dict.values()] - state_times = [time.timestamp() for time in time_state_dict.keys()] - interpolate_object = interp1d(state_times, np.stack(svs, axis=0), axis=0) + if len(times_to_interpolate) > 0: + # Only interpolate if required + state_vectors = [state.state_vector for state in time_state_dict.values()] + state_times = [time.timestamp() for time in time_state_dict.keys()] + interpolate_object = interp1d(state_times, np.stack(state_vectors, axis=0), axis=0) - interp_output = interpolate_object([time.timestamp() for time in times_to_interpolate]) + interp_output = interpolate_object([time.timestamp() for time in times_to_interpolate]) - for idx, time in enumerate(times_to_interpolate): - time_state_dict[time] = State(interp_output[idx, :, :], timestamp=time) + for idx, time in enumerate(times_to_interpolate): + time_state_dict[time] = State(interp_output[idx, :, :], timestamp=time) new_sms.states = [time_state_dict[time] for time in times] return new_sms - - -def state_vector_fun(time: datetime.datetime) -> State: - today = datetime.datetime(datetime.datetime.today().year, - datetime.datetime.today().month, - datetime.datetime.today().day) - num = (time-today).seconds - return State( - timestamp=time, - state_vector=[ - num*1.3e-4, - num*4.7e-5, - num/43, - # np.sin(num*1e-6) - ] - ) - - -def gen_test_data(): - t0 = datetime.datetime.today() - max_seconds = 1e3 - t_max = t0 + datetime.timedelta(seconds=max_seconds) - - sms = StateMutableSequence([state_vector_fun(time) - for time in time_range(t0, t_max, datetime.timedelta(seconds=0.25)) - ]) - - interp_times = list(time_range(t0, t_max, datetime.timedelta(seconds=0.1))) - - return sms, interp_times - - -def individual_speed_test(i): - sms, interp_times = gen_test_data() - interpolate_state_mutable_sequence(sms, interp_times, i) - - -def run_all_speed_tests(): - for i in [1, 2, 3, 4]: - time_taken = timeit.timeit(setup="from __main__ import individual_speed_test", - stmt=f"individual_speed_test({i})", number=10) - print(f"Option {i} took {time_taken}") - - -def check_correct(): - sms, interp_times = gen_test_data() - for i in [1, 2, 3, 4]: - new_sms = interpolate_state_mutable_sequence(sms, interp_times, i) - for state in new_sms: - interp_sv = state.state_vector - true_sv = state_vector_fun(state.timestamp).state_vector - if not np.all(np.isclose(interp_sv, true_sv, rtol=1e-3)): - print(f"Option {i} failed") - break - print(f"Option {i} worked") - - -if __name__ == '__main__': - run_all_speed_tests() - check_correct() diff --git a/stonesoup/functions/tests/test_interpolate.py b/stonesoup/functions/tests/test_interpolate.py new file mode 100644 index 000000000..a825f4a06 --- /dev/null +++ b/stonesoup/functions/tests/test_interpolate.py @@ -0,0 +1,109 @@ +import datetime +from typing import Tuple, List + +import numpy as np +import pytest + +from ..interpolate import time_range, interpolate_state_mutable_sequence +from ...types.state import State, StateMutableSequence + + +@pytest.mark.parametrize("input_kwargs, expected", + [(dict(start_time=datetime.datetime(2023, 1, 1, 0, 0), + end_time=datetime.datetime(2023, 1, 1, 0, 0, 35), + timestep=datetime.timedelta(seconds=7)), + [datetime.datetime(2023, 1, 1, 0, 0), + datetime.datetime(2023, 1, 1, 0, 0, 7), + datetime.datetime(2023, 1, 1, 0, 0, 14), + datetime.datetime(2023, 1, 1, 0, 0, 21), + datetime.datetime(2023, 1, 1, 0, 0, 28), + datetime.datetime(2023, 1, 1, 0, 0, 35) + ] + ), + (dict(start_time=datetime.datetime(1970, 1, 1, 0, 0), + end_time=datetime.datetime(1970, 1, 1, 0, 0, 6)), + [datetime.datetime(1970, 1, 1, 0, 0), + datetime.datetime(1970, 1, 1, 0, 0, 1), + datetime.datetime(1970, 1, 1, 0, 0, 2), + datetime.datetime(1970, 1, 1, 0, 0, 3), + datetime.datetime(1970, 1, 1, 0, 0, 4), + datetime.datetime(1970, 1, 1, 0, 0, 5), + datetime.datetime(1970, 1, 1, 0, 0, 6) + ] + ) + ]) +def test_time_range(input_kwargs, expected): + generated_times = list(time_range(**input_kwargs)) + assert generated_times == expected + + +t0 = datetime.datetime(2023, 9, 1) +t_max = t0 + datetime.timedelta(seconds=10) +out_of_range_time = t_max + datetime.timedelta(seconds=10) + + +def calculate_state(time: datetime.datetime) -> State: + """ This function maps a datetime.datetime to a State. This allows interpolated values to be + checked easily. """ + n_seconds = (time-t0).seconds + return State( + timestamp=time, + state_vector=[ + 10 + n_seconds*1.3e-4, + 94 - n_seconds*4.7e-5, + 106 + n_seconds/43, + ] + ) + + +@pytest.fixture +def gen_test_data() -> Tuple[StateMutableSequence, List[datetime.datetime]]: + + sms = StateMutableSequence([calculate_state(time) + for time in time_range(t0, t_max, datetime.timedelta(seconds=0.25)) + ]) + + interp_times = list(time_range(t0, t_max, datetime.timedelta(seconds=0.1))) + + return sms, interp_times + + +def test_interpolate_state_mutable_sequence(gen_test_data): + sms, interp_times = gen_test_data + new_sms = interpolate_state_mutable_sequence(sms, interp_times) + + assert isinstance(new_sms, StateMutableSequence) + + for state in new_sms: + interp_sv = state.state_vector + true_sv = calculate_state(state.timestamp).state_vector + np.testing.assert_allclose(interp_sv, true_sv, rtol=1e-3, atol=1e-7) + + +def test_interpolate_individual_time(gen_test_data): + sms, interp_times = gen_test_data + + for time in interp_times[0:50]: + interp_state = interpolate_state_mutable_sequence(sms, time) + + assert isinstance(interp_state, State) + + interp_sv = interp_state.state_vector + true_sv = calculate_state(time).state_vector + np.testing.assert_allclose(interp_sv, true_sv, rtol=1e-3, atol=1e-7) + + +def test_interpolate_warning(gen_test_data): + sms, _ = gen_test_data + times = [t0, t_max, out_of_range_time] + + with pytest.warns(UserWarning): + _ = interpolate_state_mutable_sequence(sms, times) + + +def test_interpolate_error(gen_test_data): + sms, _ = gen_test_data + time = out_of_range_time + + with pytest.raises(IndexError): + _ = interpolate_state_mutable_sequence(sms, time) From b5741c7dd68e33cccbcebe345233d0fc9a777ddd Mon Sep 17 00:00:00 2001 From: G Webb Date: Thu, 5 Oct 2023 18:23:08 +0100 Subject: [PATCH 3/4] Removed legacy function scipy.interpolate.interp1d and replaced with multiple calls to numpy.interp --- stonesoup/functions/interpolate.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/stonesoup/functions/interpolate.py b/stonesoup/functions/interpolate.py index 47bed69e5..9ab260b4d 100644 --- a/stonesoup/functions/interpolate.py +++ b/stonesoup/functions/interpolate.py @@ -4,8 +4,8 @@ from typing import Union, List, Iterable import numpy as np -from scipy.interpolate import interp1d +from ..types.array import StateVectors from ..types.state import StateMutableSequence, State @@ -98,14 +98,18 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, if len(times_to_interpolate) > 0: # Only interpolate if required - state_vectors = [state.state_vector for state in time_state_dict.values()] - state_times = [time.timestamp() for time in time_state_dict.keys()] - interpolate_object = interp1d(state_times, np.stack(state_vectors, axis=0), axis=0) - - interp_output = interpolate_object([time.timestamp() for time in times_to_interpolate]) - - for idx, time in enumerate(times_to_interpolate): - time_state_dict[time] = State(interp_output[idx, :, :], timestamp=time) + state_vectors = StateVectors([state.state_vector for state in time_state_dict.values()]) + state_timestamps = [time.timestamp() for time in time_state_dict.keys()] + interp_timestamps = [time.timestamp() for time in times_to_interpolate] + + interp_output = np.empty((sms.state.ndim, len(times_to_interpolate))) + for element_index in range(sms.state.ndim): + interp_output[element_index][:] = np.interp(x=interp_timestamps, + xp=state_timestamps, + fp=state_vectors[element_index][:]) + + for state_index, time in enumerate(times_to_interpolate): + time_state_dict[time] = State(interp_output[:, state_index], timestamp=time) new_sms.states = [time_state_dict[time] for time in times] From 45f2f10a380bc9c2c058a3fa2613a03adf42b1c1 Mon Sep 17 00:00:00 2001 From: G Webb Date: Mon, 9 Oct 2023 11:24:12 +0100 Subject: [PATCH 4/4] Minor edits Co-authored-by: Steven Hiscocks --- docs/source/stonesoup.functions.rst | 7 +++++++ stonesoup/functions/interpolate.py | 6 ++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/source/stonesoup.functions.rst b/docs/source/stonesoup.functions.rst index df4ac6b1d..46d4ae450 100644 --- a/docs/source/stonesoup.functions.rst +++ b/docs/source/stonesoup.functions.rst @@ -3,7 +3,14 @@ Functions .. automodule:: stonesoup.functions +Orbital +------- + .. automodule:: stonesoup.functions.orbital + +Interpolation +------------- + .. automodule:: stonesoup.functions.interpolate diff --git a/stonesoup/functions/interpolate.py b/stonesoup/functions/interpolate.py index 9ab260b4d..18f408171 100644 --- a/stonesoup/functions/interpolate.py +++ b/stonesoup/functions/interpolate.py @@ -55,6 +55,8 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, Unique states for each time are required for interpolation. If there are multiple states with the same time in ``sms`` the later state in the sequence is used. + + For :class:`~.Track` inputs the *metadatas* is removed as it can't be interpolated. """ # If single time is used, insert time into list and run again. @@ -104,9 +106,9 @@ def interpolate_state_mutable_sequence(sms: StateMutableSequence, interp_output = np.empty((sms.state.ndim, len(times_to_interpolate))) for element_index in range(sms.state.ndim): - interp_output[element_index][:] = np.interp(x=interp_timestamps, + interp_output[element_index, :] = np.interp(x=interp_timestamps, xp=state_timestamps, - fp=state_vectors[element_index][:]) + fp=state_vectors[element_index, :]) for state_index, time in enumerate(times_to_interpolate): time_state_dict[time] = State(interp_output[:, state_index], timestamp=time)