Skip to content

Commit

Permalink
Merge pull request #872 from dstl/interpolate_func_2
Browse files Browse the repository at this point in the history
Function to linearly interpolate state mutable sequences
  • Loading branch information
sdhiscocks authored Oct 31, 2023
2 parents 4a79905 + 45f2f10 commit 0cc3c61
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/source/stonesoup.functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,14 @@ Functions

.. automodule:: stonesoup.functions

Orbital
-------

.. automodule:: stonesoup.functions.orbital


Interpolation
-------------

.. automodule:: stonesoup.functions.interpolate

118 changes: 118 additions & 0 deletions stonesoup/functions/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import copy
import datetime
import warnings
from typing import Union, List, Iterable

import numpy as np

from ..types.array import StateVectors
from ..types.state import StateMutableSequence, State


def time_range(start_time: datetime.datetime, end_time: datetime.datetime,
timestep: datetime.timedelta = datetime.timedelta(seconds=1)) \
-> Iterable[datetime.datetime]:
"""
Produces a range of datetime object between ``start_time`` (inclusive) and ``end_time``
(inclusive)
Parameters
----------
start_time: datetime.datetime time range start (inclusive)
end_time: datetime.datetime time range end (inclusive)
timestep: datetime.timedelta default value is 1 second
Returns
-------
Generator[datetime.datetime]
"""
duration = end_time - start_time
n_time_steps = duration / timestep
for x in range(int(n_time_steps) + 1):
yield start_time + x * timestep


def interpolate_state_mutable_sequence(sms: StateMutableSequence,
times: Union[datetime.datetime, List[datetime.datetime]],
) -> Union[StateMutableSequence, State]:
"""
This function performs linear interpolation on a :class:`~.StateMutableSequence`. The function
has two slightly different forms:
If an individual :class:`~datetime.datetime` is inputted for the variable ``times`` then a
:class:`~.State` is returned corresponding to ``times``.
If a list of :class:`~datetime.datetime` is inputted for the variable ``times`` then a
:class:`~.StateMutableSequence` is returned with the states in the sequence corresponding to
``times``.
Note
----
This function does **not** extrapolate. Times outside the range of the time range of ``sms``
are discarded and warning is given. If all ``times`` values are outside the time range of
``sms`` then an ``IndexError`` is raised.
Unique states for each time are required for interpolation. If there are multiple states with
the same time in ``sms`` the later state in the sequence is used.
For :class:`~.Track` inputs the *metadatas* is removed as it can't be interpolated.
"""

# If single time is used, insert time into list and run again.
# A StateMutableSequence is produced by the inner function call.
# The single state is taken from that StateMutableSequence
if isinstance(times, datetime.datetime):
new_sms = interpolate_state_mutable_sequence(sms, [times])
return new_sms.state

# Track metadata removed and no interpolation can be performed on the metadata
new_sms = copy.copy(sms)
if hasattr(new_sms, "metadatas"):
new_sms.metadatas = list()

# This step ensure unique states for each timestamp. The last state for a timestamp is used
# with earlier states not being used.
time_state_dict = {state.timestamp: state
for state in sms}

# Filter times if required
max_state_time = sms[-1].timestamp
min_state_time = sms[0].timestamp
if max(times) > max_state_time or min(times) < min_state_time:
new_times = [time
for time in times
if min_state_time <= time <= max_state_time]

if len(new_times) == 0:
raise IndexError(f"All times are outside of the state mutable sequence's time range "
f"({min_state_time} -> {max_state_time})")

removed_times = set(times).difference(new_times)
warnings.warn(f"Trying to interpolate states which are outside the time range "
f"({min_state_time} -> {max_state_time}) of the state mutable sequence. The "
f"following times aren't included in the output {removed_times}")

times = new_times

# Find times that require interpolation
times_to_interpolate = sorted(set(times).difference(time_state_dict.keys()))

if len(times_to_interpolate) > 0:
# Only interpolate if required
state_vectors = StateVectors([state.state_vector for state in time_state_dict.values()])
state_timestamps = [time.timestamp() for time in time_state_dict.keys()]
interp_timestamps = [time.timestamp() for time in times_to_interpolate]

interp_output = np.empty((sms.state.ndim, len(times_to_interpolate)))
for element_index in range(sms.state.ndim):
interp_output[element_index, :] = np.interp(x=interp_timestamps,
xp=state_timestamps,
fp=state_vectors[element_index, :])

for state_index, time in enumerate(times_to_interpolate):
time_state_dict[time] = State(interp_output[:, state_index], timestamp=time)

new_sms.states = [time_state_dict[time] for time in times]

return new_sms
109 changes: 109 additions & 0 deletions stonesoup/functions/tests/test_interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import datetime
from typing import Tuple, List

import numpy as np
import pytest

from ..interpolate import time_range, interpolate_state_mutable_sequence
from ...types.state import State, StateMutableSequence


@pytest.mark.parametrize("input_kwargs, expected",
[(dict(start_time=datetime.datetime(2023, 1, 1, 0, 0),
end_time=datetime.datetime(2023, 1, 1, 0, 0, 35),
timestep=datetime.timedelta(seconds=7)),
[datetime.datetime(2023, 1, 1, 0, 0),
datetime.datetime(2023, 1, 1, 0, 0, 7),
datetime.datetime(2023, 1, 1, 0, 0, 14),
datetime.datetime(2023, 1, 1, 0, 0, 21),
datetime.datetime(2023, 1, 1, 0, 0, 28),
datetime.datetime(2023, 1, 1, 0, 0, 35)
]
),
(dict(start_time=datetime.datetime(1970, 1, 1, 0, 0),
end_time=datetime.datetime(1970, 1, 1, 0, 0, 6)),
[datetime.datetime(1970, 1, 1, 0, 0),
datetime.datetime(1970, 1, 1, 0, 0, 1),
datetime.datetime(1970, 1, 1, 0, 0, 2),
datetime.datetime(1970, 1, 1, 0, 0, 3),
datetime.datetime(1970, 1, 1, 0, 0, 4),
datetime.datetime(1970, 1, 1, 0, 0, 5),
datetime.datetime(1970, 1, 1, 0, 0, 6)
]
)
])
def test_time_range(input_kwargs, expected):
generated_times = list(time_range(**input_kwargs))
assert generated_times == expected


t0 = datetime.datetime(2023, 9, 1)
t_max = t0 + datetime.timedelta(seconds=10)
out_of_range_time = t_max + datetime.timedelta(seconds=10)


def calculate_state(time: datetime.datetime) -> State:
""" This function maps a datetime.datetime to a State. This allows interpolated values to be
checked easily. """
n_seconds = (time-t0).seconds
return State(
timestamp=time,
state_vector=[
10 + n_seconds*1.3e-4,
94 - n_seconds*4.7e-5,
106 + n_seconds/43,
]
)


@pytest.fixture
def gen_test_data() -> Tuple[StateMutableSequence, List[datetime.datetime]]:

sms = StateMutableSequence([calculate_state(time)
for time in time_range(t0, t_max, datetime.timedelta(seconds=0.25))
])

interp_times = list(time_range(t0, t_max, datetime.timedelta(seconds=0.1)))

return sms, interp_times


def test_interpolate_state_mutable_sequence(gen_test_data):
sms, interp_times = gen_test_data
new_sms = interpolate_state_mutable_sequence(sms, interp_times)

assert isinstance(new_sms, StateMutableSequence)

for state in new_sms:
interp_sv = state.state_vector
true_sv = calculate_state(state.timestamp).state_vector
np.testing.assert_allclose(interp_sv, true_sv, rtol=1e-3, atol=1e-7)


def test_interpolate_individual_time(gen_test_data):
sms, interp_times = gen_test_data

for time in interp_times[0:50]:
interp_state = interpolate_state_mutable_sequence(sms, time)

assert isinstance(interp_state, State)

interp_sv = interp_state.state_vector
true_sv = calculate_state(time).state_vector
np.testing.assert_allclose(interp_sv, true_sv, rtol=1e-3, atol=1e-7)


def test_interpolate_warning(gen_test_data):
sms, _ = gen_test_data
times = [t0, t_max, out_of_range_time]

with pytest.warns(UserWarning):
_ = interpolate_state_mutable_sequence(sms, times)


def test_interpolate_error(gen_test_data):
sms, _ = gen_test_data
time = out_of_range_time

with pytest.raises(IndexError):
_ = interpolate_state_mutable_sequence(sms, time)

0 comments on commit 0cc3c61

Please sign in to comment.