Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function to linearly interpolate state mutable sequences #872

Merged
merged 4 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/stonesoup.functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ Functions

.. automodule:: stonesoup.functions.orbital

.. automodule:: stonesoup.functions.interpolate

gawebb-dstl marked this conversation as resolved.
Show resolved Hide resolved
112 changes: 112 additions & 0 deletions stonesoup/functions/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import copy
import datetime
import warnings
from typing import Union, List, Iterable

import numpy as np
from scipy.interpolate import interp1d

from ..types.state import StateMutableSequence, State


def time_range(start_time: datetime.datetime, end_time: datetime.datetime,
timestep: datetime.timedelta = datetime.timedelta(seconds=1)) \
-> Iterable[datetime.datetime]:
"""
Produces a range of datetime object between ``start_time`` (inclusive) and ``end_time``
(inclusive)

Parameters
----------
start_time: datetime.datetime time range start (inclusive)
end_time: datetime.datetime time range end (inclusive)
timestep: datetime.timedelta default value is 1 second

Returns
-------
Generator[datetime.datetime]

"""
duration = end_time - start_time
n_time_steps = duration / timestep
for x in range(int(n_time_steps) + 1):
yield start_time + x * timestep


def interpolate_state_mutable_sequence(sms: StateMutableSequence,
times: Union[datetime.datetime, List[datetime.datetime]],
) -> Union[StateMutableSequence, State]:
"""
This function performs linear interpolation on a :class:`~.StateMutableSequence`. The function
has two slightly different forms:

If an individual :class:`~datetime.datetime` is inputted for the variable ``times`` then a
:class:`~.State` is returned corresponding to ``times``.

If a list of :class:`~datetime.datetime` is inputted for the variable ``times`` then a
:class:`~.StateMutableSequence` is returned with the states in the sequence corresponding to
``times``.

Note
----
This function does **not** extrapolate. Times outside the range of the time range of ``sms``
are discarded and warning is given. If all ``times`` values are outside the time range of
``sms`` then an ``IndexError`` is raised.

Unique states for each time are required for interpolation. If there are multiple states with
the same time in ``sms`` the later state in the sequence is used.
"""

# If single time is used, insert time into list and run again.
# A StateMutableSequence is produced by the inner function call.
# The single state is taken from that StateMutableSequence
if isinstance(times, datetime.datetime):
new_sms = interpolate_state_mutable_sequence(sms, [times])
return new_sms.state

# Track metadata removed and no interpolation can be performed on the metadata
new_sms = copy.copy(sms)
if hasattr(new_sms, "metadatas"):
new_sms.metadatas = list()
gawebb-dstl marked this conversation as resolved.
Show resolved Hide resolved

# This step ensure unique states for each timestamp. The last state for a timestamp is used
# with earlier states not being used.
time_state_dict = {state.timestamp: state
for state in sms}

# Filter times if required
max_state_time = sms[-1].timestamp
min_state_time = sms[0].timestamp
if max(times) > max_state_time or min(times) < min_state_time:
new_times = [time
for time in times
if min_state_time <= time <= max_state_time]

if len(new_times) == 0:
raise IndexError(f"All times are outside of the state mutable sequence's time range "
f"({min_state_time} -> {max_state_time})")

removed_times = set(times).difference(new_times)
warnings.warn(f"Trying to interpolate states which are outside the time range "
f"({min_state_time} -> {max_state_time}) of the state mutable sequence. The "
f"following times aren't included in the output {removed_times}")

times = new_times

# Find times that require interpolation
times_to_interpolate = sorted(set(times).difference(time_state_dict.keys()))

if len(times_to_interpolate) > 0:
# Only interpolate if required
state_vectors = [state.state_vector for state in time_state_dict.values()]
state_times = [time.timestamp() for time in time_state_dict.keys()]
interpolate_object = interp1d(state_times, np.stack(state_vectors, axis=0), axis=0)
gawebb-dstl marked this conversation as resolved.
Show resolved Hide resolved

interp_output = interpolate_object([time.timestamp() for time in times_to_interpolate])

for idx, time in enumerate(times_to_interpolate):
time_state_dict[time] = State(interp_output[idx, :, :], timestamp=time)

new_sms.states = [time_state_dict[time] for time in times]

return new_sms
109 changes: 109 additions & 0 deletions stonesoup/functions/tests/test_interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import datetime
from typing import Tuple, List

import numpy as np
import pytest

from ..interpolate import time_range, interpolate_state_mutable_sequence
from ...types.state import State, StateMutableSequence


@pytest.mark.parametrize("input_kwargs, expected",
[(dict(start_time=datetime.datetime(2023, 1, 1, 0, 0),
end_time=datetime.datetime(2023, 1, 1, 0, 0, 35),
timestep=datetime.timedelta(seconds=7)),
[datetime.datetime(2023, 1, 1, 0, 0),
datetime.datetime(2023, 1, 1, 0, 0, 7),
datetime.datetime(2023, 1, 1, 0, 0, 14),
datetime.datetime(2023, 1, 1, 0, 0, 21),
datetime.datetime(2023, 1, 1, 0, 0, 28),
datetime.datetime(2023, 1, 1, 0, 0, 35)
]
),
(dict(start_time=datetime.datetime(1970, 1, 1, 0, 0),
end_time=datetime.datetime(1970, 1, 1, 0, 0, 6)),
[datetime.datetime(1970, 1, 1, 0, 0),
datetime.datetime(1970, 1, 1, 0, 0, 1),
datetime.datetime(1970, 1, 1, 0, 0, 2),
datetime.datetime(1970, 1, 1, 0, 0, 3),
datetime.datetime(1970, 1, 1, 0, 0, 4),
datetime.datetime(1970, 1, 1, 0, 0, 5),
datetime.datetime(1970, 1, 1, 0, 0, 6)
]
)
])
def test_time_range(input_kwargs, expected):
generated_times = list(time_range(**input_kwargs))
assert generated_times == expected


t0 = datetime.datetime(2023, 9, 1)
t_max = t0 + datetime.timedelta(seconds=10)
out_of_range_time = t_max + datetime.timedelta(seconds=10)


def calculate_state(time: datetime.datetime) -> State:
""" This function maps a datetime.datetime to a State. This allows interpolated values to be
checked easily. """
n_seconds = (time-t0).seconds
return State(
timestamp=time,
state_vector=[
10 + n_seconds*1.3e-4,
94 - n_seconds*4.7e-5,
106 + n_seconds/43,
]
)


@pytest.fixture
def gen_test_data() -> Tuple[StateMutableSequence, List[datetime.datetime]]:

sms = StateMutableSequence([calculate_state(time)
for time in time_range(t0, t_max, datetime.timedelta(seconds=0.25))
])

interp_times = list(time_range(t0, t_max, datetime.timedelta(seconds=0.1)))

return sms, interp_times


def test_interpolate_state_mutable_sequence(gen_test_data):
sms, interp_times = gen_test_data
new_sms = interpolate_state_mutable_sequence(sms, interp_times)

assert isinstance(new_sms, StateMutableSequence)

for state in new_sms:
interp_sv = state.state_vector
true_sv = calculate_state(state.timestamp).state_vector
np.testing.assert_allclose(interp_sv, true_sv, rtol=1e-3, atol=1e-7)


def test_interpolate_individual_time(gen_test_data):
sms, interp_times = gen_test_data

for time in interp_times[0:50]:
interp_state = interpolate_state_mutable_sequence(sms, time)

assert isinstance(interp_state, State)

interp_sv = interp_state.state_vector
true_sv = calculate_state(time).state_vector
np.testing.assert_allclose(interp_sv, true_sv, rtol=1e-3, atol=1e-7)


def test_interpolate_warning(gen_test_data):
sms, _ = gen_test_data
times = [t0, t_max, out_of_range_time]

with pytest.warns(UserWarning):
_ = interpolate_state_mutable_sequence(sms, times)


def test_interpolate_error(gen_test_data):
sms, _ = gen_test_data
time = out_of_range_time

with pytest.raises(IndexError):
_ = interpolate_state_mutable_sequence(sms, time)