Skip to content

Commit

Permalink
Refactor in progess for new OcpFes class
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin CO committed Feb 24, 2025
1 parent c80be5d commit 82146ad
Show file tree
Hide file tree
Showing 14 changed files with 254 additions and 500 deletions.
4 changes: 2 additions & 2 deletions cocofest/integration/ivp_fes.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,12 @@ def build_initial_guess_from_ocp(self, ocp, stim_idx_at_node_list=None):
)
for i in range(self.n_shooting)
]
for j in range(self.model._sum_stim_truncation)
for j in range(self.model.sum_stim_truncation)
]

else:
pi = self.pulse_intensity[0] if isinstance(self.pulse_intensity, list) else self.pulse_width
initial_guess_list = [[pi] * self.model._sum_stim_truncation] * self.n_shooting
initial_guess_list = [[pi] * self.model.sum_stim_truncation] * self.n_shooting

u.add(key, initial_guess=initial_guess_list, phase=0, interpolation=InterpolationType.EACH_FRAME)

Expand Down
36 changes: 33 additions & 3 deletions cocofest/models/ding2003.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Callable
from math import gcd
from fractions import Fraction

import numpy as np
from casadi import MX, exp, vertcat
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(
super().__init__()
self._model_name = model_name
self._muscle_name = muscle_name
self._sum_stim_truncation = sum_stim_truncation
self.sum_stim_truncation = sum_stim_truncation
self._with_fatigue = False
self.pulse_apparition_time = None
self.stim_time = stim_time if stim_time else []
Expand Down Expand Up @@ -392,12 +394,12 @@ def declare_ding_variables(
ConfigureProblem.configure_dynamics_function(ocp, nlp, dyn_func=self.dynamics)

def _get_additional_previous_stim_time(self):
while len(self.previous_stim["time"]) < self._sum_stim_truncation:
while len(self.previous_stim["time"]) < self.sum_stim_truncation:
self.previous_stim["time"].insert(0, -10000000)
return self.previous_stim

def get_numerical_data_time_series(self, n_shooting, final_time, all_stim_time=None):
truncation = self._sum_stim_truncation
truncation = self.sum_stim_truncation
# --- Set the previous stim time for the numerical data time series (mandatory to avoid nan values) --- #
self.previous_stim = self._get_additional_previous_stim_time()
stim_time = (
Expand Down Expand Up @@ -427,3 +429,31 @@ def get_numerical_data_time_series(self, n_shooting, final_time, all_stim_time=N
stim_time_array = np.transpose(temp_result, (2, 1, 0))

return {"stim_time": stim_time_array}, stim_idx_at_node_list

def get_n_shooting(self, final_time: float) -> int:
"""
Prepare the n_shooting for the ocp in order to have a time step that is a multiple of the stimulation time.
Returns
-------
int
The number of shooting points
"""
# Represent the final time as a Fraction for exact arithmetic.
T_final = Fraction(final_time).limit_denominator()
n_shooting = 1

for t in self.stim_time:

t_frac = Fraction(t).limit_denominator() # Convert the stimulation time to an exact fraction.
norm = t_frac / T_final # Compute the normalized time: t / final_time.
d = norm.denominator # The denominator in the reduced fraction gives the requirement.
n_shooting = n_shooting * d // gcd(n_shooting, d)

if n_shooting >= 1000:
print(
f"Warning: The number of shooting nodes is very high n = {n_shooting}.\n"
"The optimization might be long, consider using stimulation time with even spacing (common frequency)."
)

return n_shooting
2 changes: 1 addition & 1 deletion cocofest/models/dynamical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def declare_model_variables(
StateConfigure().configure_last_pulse_width(ocp, nlp, muscle_name=str(muscle_model.muscle_name))
if isinstance(muscle_model, DingModelPulseIntensityFrequency):
StateConfigure().configure_pulse_intensity(
ocp, nlp, muscle_name=str(muscle_model.muscle_name), truncation=muscle_model._sum_stim_truncation
ocp, nlp, muscle_name=str(muscle_model.muscle_name), truncation=muscle_model.sum_stim_truncation
)
if self.activate_residual_torque:
ConfigureProblem.configure_tau(ocp, nlp, as_states=False, as_controls=True)
Expand Down
4 changes: 2 additions & 2 deletions cocofest/models/hmed2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def declare_ding_variables(
A list of values to pass to the dynamics at each node. Experimental external forces should be included here.
"""
StateConfigure().configure_all_fes_model_states(ocp, nlp, fes_model=self)
StateConfigure().configure_pulse_intensity(ocp, nlp, truncation=self._sum_stim_truncation)
StateConfigure().configure_pulse_intensity(ocp, nlp, truncation=self.sum_stim_truncation)
ConfigureProblem.configure_dynamics_function(ocp, nlp, dyn_func=self.dynamics)

def min_pulse_intensity(self):
Expand All @@ -309,7 +309,7 @@ def min_pulse_intensity(self):
return (np.arctanh(-self.cr) / self.bs) + self.Is

def _get_additional_previous_stim_time(self):
while len(self.previous_stim["time"]) < self._sum_stim_truncation:
while len(self.previous_stim["time"]) < self.sum_stim_truncation:
self.previous_stim["time"].insert(0, -10000000)
self.previous_stim["pulse_intensity"].insert(0, 50)
return self.previous_stim
2 changes: 1 addition & 1 deletion cocofest/models/hmed2018_with_fatigue.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,5 +312,5 @@ def declare_ding_variables(
A list of values to pass to the dynamics at each node. Experimental external forces should be included here.
"""
StateConfigure().configure_all_fes_model_states(ocp, nlp, fes_model=self)
StateConfigure().configure_pulse_intensity(ocp, nlp, truncation=self._sum_stim_truncation)
StateConfigure().configure_pulse_intensity(ocp, nlp, truncation=self.sum_stim_truncation)
ConfigureProblem.configure_dynamics_function(ocp, nlp, dyn_func=self.dynamics)
12 changes: 6 additions & 6 deletions cocofest/optimization/fes_identification_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def prepare_ocp(

numerical_data_time_series, stim_idx_at_node_list = model.get_numerical_data_time_series(n_shooting, final_time)

dynamics = OcpFesId._declare_dynamics(model=model, numerical_data_timeseries=numerical_data_time_series)
x_bounds, x_init = OcpFesId._set_bounds(
dynamics = OcpFesId.declare_dynamics(model=model, numerical_data_timeseries=numerical_data_time_series)
x_bounds, x_init = OcpFesId.set_x_bounds(
model=model,
force_tracking=objective["force_tracking"],
discontinuity_in_ocp=discontinuity_in_ocp,
Expand All @@ -118,7 +118,7 @@ def prepare_ocp(
if isinstance(model, DingModelPulseWidthFrequency)
else pulse_intensity["fixed"] if isinstance(model, DingModelPulseIntensityFrequency) else None
)
u_bounds, u_init = OcpFesId._set_u_bounds(
u_bounds, u_init = OcpFesId.set_u_bounds(
model=model, control_value=control_value, stim_idx_at_node_list=stim_idx_at_node_list, n_shooting=n_shooting
)

Expand Down Expand Up @@ -179,7 +179,7 @@ def _sanity_check_id(
)

@staticmethod
def _set_bounds(
def set_x_bounds(
model: FesModel = None,
force_tracking=None,
discontinuity_in_ocp=None,
Expand Down Expand Up @@ -310,7 +310,7 @@ def _set_phase_transition(discontinuity_in_ocp):
return phase_transitions

@staticmethod
def _set_u_bounds(model, control_value: list, stim_idx_at_node_list: list, n_shooting: int):
def set_u_bounds(model, control_value: list, stim_idx_at_node_list: list, n_shooting: int):
# Controls bounds
u_bounds = BoundsList()
# Controls initial guess
Expand Down Expand Up @@ -339,7 +339,7 @@ def _set_u_bounds(model, control_value: list, stim_idx_at_node_list: list, n_sho
)
for i in range(n_shooting)
]
for j in range(model._sum_stim_truncation)
for j in range(model.sum_stim_truncation)
]

u_init.add(key="pulse_intensity", initial_guess=np.array(control_list)[:, 0], phase=0)
Expand Down
4 changes: 2 additions & 2 deletions cocofest/optimization/fes_nmpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def advance_window_initial_guess_parameters(self, sol, **advance_options):
self.parameter_init[key].init[:, :] = reshaped_parameter[self.initial_guess_param_index, :]

def update_stim(self):
truncation_term = self.nlp[0].model._sum_stim_truncation
truncation_term = self.nlp[0].model.sum_stim_truncation
solution_stimulation_time = self.nlp[0].model.stim_time[-truncation_term:]
previous_stim_time = [x - self.phase_time[0] for x in solution_stimulation_time]
previous_stim = {"time": previous_stim_time}
Expand Down Expand Up @@ -143,7 +143,7 @@ def create_model_from_list(self, models: list):
stim_time = [val for sublist in stim_time for val in sublist]

combined_model = DingModelPulseWidthFrequencyWithFatigue(
stim_time=stim_time, sum_stim_truncation=self.nlp[0].model._sum_stim_truncation
stim_time=stim_time, sum_stim_truncation=self.nlp[0].model.sum_stim_truncation
)
return combined_model

Expand Down
Loading

0 comments on commit 82146ad

Please sign in to comment.