Skip to content

Commit

Permalink
Merge pull request #8 from adriensizaret/main-new-metrics-dev
Browse files Browse the repository at this point in the history
MoS Adjustements
  • Loading branch information
anboen authored Dec 13, 2024
2 parents 331fa57 + afc5b57 commit 7f40ef9
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 59 deletions.
132 changes: 76 additions & 56 deletions gaitalytics/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _get_progression_vector(self, trial: model.Trial) -> xr.DataArray:
def _get_sagittal_vector(self, trial: model.Trial) -> xr.DataArray:
"""Calculate the sagittal vector for a trial.
The sagittal vector is the vector normal to the sagittal plane.
The sagittal vector is the vector normal to the sagittal plane. Note that this vector will always be pointing towards the right side of the body.
Args:
trial: The trial for which to calculate the sagittal vector.
Expand Down Expand Up @@ -507,19 +507,23 @@ class SpatialFeatures(_PointDependentFeature):
- step_width
- minimal_toe_clearance
- AP_margin_of_stability
- AP_base_of_support
- AP_xcom
- ML_margin_of_stability
- ML_base_of_support
- ML_xcom
"""

def _calculate(self, trial: model.Trial) -> xr.DataArray:
"""Calculate the spatial features for a trial.
Definitions of the spatial features:
Step length & Step width: Hollmann et al. 2011 (doi: 10.1016/j.gaitpost.2011.03.024)
Margin of stability: Jinfeng et al. 2021 (doi: 10.1152/jn.00091.2021)
Margin of stability: Jinfeng et al. 2021 (doi: 10.1152/jn.00091.2021), Curtze et al. 2024 (doi: 10.1016/j.jbiomech.2024.112045)
Minimal toe clearance: Schulz 2017 (doi: 10.1016/j.jbiomech.2017.02.024)
Args:
trial: The trial for which to calculate the features.
trial: The trial for which to calculate the features.
Returns:
An xarray DataArray containing the calculated features.
Expand Down Expand Up @@ -567,9 +571,9 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
results_dict.update(
self._calculate_ap_margin_of_stability(
trial,
marker_dict["ipsi_toe_2"], # type: ignore
marker_dict["ipsi_heel"], # type: ignore
marker_dict["contra_toe_2"], # type: ignore
marker_dict["xcom"],
marker_dict["xcom"]
)
)

Expand All @@ -578,7 +582,7 @@ def _calculate(self, trial: model.Trial) -> xr.DataArray:
trial,
marker_dict["ipsi_ankle"],
marker_dict["contra_ankle"],
marker_dict["xcom"],
marker_dict["xcom"]
)
)
except KeyError:
Expand Down Expand Up @@ -704,7 +708,7 @@ def _calculate_stride_length(
ipsi_marker: mapping.MappedMarkers,
contra_marker: mapping.MappedMarkers,
) -> dict[str, np.ndarray]:
"""Calculate the stride length for a trial.
"""Calculate the stride length for a trial. It is computed as the two consecutive step lengths constituting the gait cycle.
Args:
trial: The trial for which to calculate the stride length.
Expand Down Expand Up @@ -838,26 +842,28 @@ def _find_mtc_index(

return None if not mtc_i else min(mtc_i, key=lambda i: toe_z[i]) # type: ignore

def _calculate_ap_margin_of_stability(
self,
trial: model.Trial,
ipsi_toe_marker: mapping.MappedMarkers,
contra_toe_marker: mapping.MappedMarkers,
xcom_marker: mapping.MappedMarkers,
) -> dict[str, np.ndarray]:
"""Calculate the anterior-posterior margin of stability at heel strike
def _calculate_ap_margin_of_stability(self,
trial: model.Trial,
ipsi_heel_marker: mapping.MappedMarkers,
contra_toe_marker: mapping.MappedMarkers,
xcom_marker: mapping.MappedMarkers,
) -> dict[str, np.ndarray]:
"""Calculate the anterio-posterior margin of stability at heel strike. Result should be interpreted according to Curtze et al. (2024)
Args:
trial: The trial for which to calculate the AP margin of stability
ipsi_toe_marker: The ipsi-lateral toe marker
contra_toe_marker: The contra-lateral toe marker
ipsi_heel_marker: The ipsi-lateral heel marker
contra_marker: The contra-lateral toe marker
xcom_marker: The extrapolated center of mass marker
Returns:
The calculated anterior-posterior margin of stability in a dict
dict: A dictionary containing:
- "AP_margin_of_stability": The calculated anterio-posterior margin of stability.
- "AP_base_of_support": The calculated anterio-posterior base of support.
- "AP_XCOM": The calculated anterio-posterior position of the extrapolated center of mass relative to the back foot.
"""
event_times = self.get_event_times(trial.events)

ipsi_toe = self._get_marker_data(trial, ipsi_toe_marker).sel(
ipsi_heel = self._get_marker_data(trial, ipsi_heel_marker).sel(
time=event_times[0], method="nearest"
)
contra_toe = self._get_marker_data(trial, contra_toe_marker).sel(
Expand All @@ -866,37 +872,42 @@ def _calculate_ap_margin_of_stability(
xcom = self._get_marker_data(trial, xcom_marker).sel(
time=event_times[0], method="nearest"
)

progress_axis = self._get_progression_vector(trial)
progress_axis = linalg.normalize_vector(progress_axis)

projected_ipsi = linalg.project_point_on_vector(ipsi_toe, progress_axis)
projected_contra = linalg.project_point_on_vector(contra_toe, progress_axis)
projected_xcom = linalg.project_point_on_vector(xcom, progress_axis)

bos_len = linalg.calculate_distance(projected_ipsi, projected_contra).values
xcom_len = linalg.calculate_distance(projected_contra, projected_xcom).values

mos = bos_len - xcom_len

return {"AP_margin_of_stability": mos}

def _calculate_ml_margin_of_stability(
self,
trial: model.Trial,
ipsi_ankle_marker: mapping.MappedMarkers,
contra_ankle_marker: mapping.MappedMarkers,
xcom_marker: mapping.MappedMarkers,
) -> dict[str, np.ndarray]:
"""Calculate the medio-lateral margin of stability at heel strike

front_marker = linalg.get_point_in_front(ipsi_heel, contra_toe, progress_axis)
back_marker = linalg.get_point_behind(ipsi_heel, contra_toe, progress_axis)

bos_vect = front_marker - back_marker
xcom_vect = xcom - back_marker

bos_proj = abs(linalg.signed_projection_norm(bos_vect, progress_axis))
xcom_proj = linalg.signed_projection_norm(xcom_vect, progress_axis)
mos = bos_proj - xcom_proj

return {"AP_margin_of_stability": mos,
"AP_base_of_support": bos_proj,
"AP_xcom": xcom_proj}

def _calculate_ml_margin_of_stability(self,
trial: model.Trial,
ipsi_ankle_marker: mapping.MappedMarkers,
contra_ankle_marker: mapping.MappedMarkers,
xcom_marker: mapping.MappedMarkers
) -> dict[str, np.ndarray]:
"""Calculate the medio-lateral margin of stability at heel strike. Result should be interpreted according to Curtze et al. (2024)
Args:
trial: The trial for which to calculate the AP margin of stability
ipsi_ankle_marker: The ipsi-lateral lateral ankle marker
contra_ankle_marker: The contra-lateral lateral ankle marker
trial: The trial for which to calculate the ml margin of stability
ipsi_toe_marker: The ipsi-lateral lateral ankle marker
contra_marker: The contra-lateral lateral ankle marker
xcom_marker: The extrapolated center of mass marker
Returns:
The calculated anterio-posterior margin of stability in a dict
dict: A dictionary containing:
- "ML_margin_of_stability": The calculated medio-lateral margin of stability.
- "ML_base_of_support": The calculated medio-lateral base of support.
- "ML_xcom": The calculated medio-lateral position of the extrapolated center of mass relative to the back foot.
"""
event_times = self.get_event_times(trial.events)

Expand All @@ -908,18 +919,27 @@ def _calculate_ml_margin_of_stability(
)
xcom = self._get_marker_data(trial, xcom_marker).sel(
time=event_times[0], method="nearest"
)
)

sagittal_axis = self._get_sagittal_vector(trial)
sagittal_axis = linalg.normalize_vector(sagittal_axis)

projected_ipsi = linalg.project_point_on_vector(ipsi_ankle, sagittal_axis)
projected_contra = linalg.project_point_on_vector(contra_ankle, sagittal_axis)
projected_xcom = linalg.project_point_on_vector(xcom, sagittal_axis)

bos_len = linalg.calculate_distance(projected_contra, projected_ipsi).values
xcom_len = linalg.calculate_distance(projected_contra, projected_xcom).values

mos = bos_len - xcom_len

return {"ML_margin_of_stability": mos}

if trial.events.attrs["context"] == "Left":
#Rotate sagittal axis so it points towards the left side of the body
sagittal_axis = -sagittal_axis

# Lateral is the furthest point in the direction of the sagittal axis
lateral_point = linalg.get_point_in_front(ipsi_ankle, contra_ankle, sagittal_axis)
# Medial is the closest point in the direction of the sagittal axis
medial_point = linalg.get_point_behind(ipsi_ankle, contra_ankle, sagittal_axis)

bos_vect = lateral_point - medial_point
xcom_vect = xcom - medial_point

bos_proj = abs(linalg.signed_projection_norm(bos_vect, sagittal_axis))
xcom_proj = linalg.signed_projection_norm(xcom_vect, sagittal_axis)
mos = bos_proj - xcom_proj

return {"ML_margin_of_stability": mos,
"ML_base_of_support": bos_proj,
"ML_xcom": xcom_proj}
1 change: 0 additions & 1 deletion gaitalytics/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""This module provides classes for reading biomechanical file-types."""

import math
from abc import abstractmethod, ABC
from pathlib import Path
Expand Down
56 changes: 54 additions & 2 deletions gaitalytics/utils/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,26 @@ def project_point_on_vector(point: xr.DataArray, vector: xr.DataArray) -> xr.Dat
"""
return vector * point.dot(vector, dim="axis")

def signed_projection_norm(vector: xr.DataArray, onto: xr.DataArray) -> xr.DataArray:
"""Compute the signed norm of the projection of a vector onto another vector. <br>
If the projection is in the same direction as the onto vector, the norm is positive. <br>
If the projection is in the opposite direction, the norm is negative. <br>
def get_normal_vector(vector1: xr.DataArray, vector2: xr.DataArray):
Args:
vector: The vector to be projected.
onto: The vector to project onto.
Returns:
An xarray DataArray containing the signed norm of the projected vector.
"""
projection = onto * vector.dot(onto, dim="axis") / onto.dot(onto, dim="axis")
projection_norm = projection.meca.norm(dim="axis")
sign = xr.where(vector.dot(onto, dim="axis") > 0, 1, -1)
sign = xr.where(vector.dot(onto, dim="axis") == 0, 0, sign)
return projection_norm * sign


def get_normal_vector(vector1: xr.DataArray, vector2: xr.DataArray) -> xr.DataArray:
"""Create a vector with norm = 1 normal to two other vectors.
Args:
Expand Down Expand Up @@ -80,4 +98,38 @@ def calculate_speed_norm(position: xr.DataArray, dt: float = 0.01) -> np.ndarray
speed_values = np.sqrt(velocity_squared_sum)
speed_values = np.append(speed_values, speed_values[-1])

return speed_values
return xr.DataArray(speed_values, dims=["time"], coords={"time": position.coords["time"]})

def get_point_in_front(point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray) -> xr.DataArray:
"""Determine which point is in front of the other according to the direction vector.
Args:
point_a: The first point.
point_b: The second point.
direction_vector: The direction vector.
Returns:
The point that is in front according to the direction vector.
"""
direction_vector = direction_vector / direction_vector.meca.norm(dim="axis")
vector_b_to_a = point_a - point_b
signed_distance = vector_b_to_a.dot(direction_vector, dim="axis")

return point_a if signed_distance > 0 else point_b

def get_point_behind(point_a: xr.DataArray, point_b: xr.DataArray, direction_vector: xr.DataArray) -> xr.DataArray:
"""Determine which point is behind the other according to the direction vector.
Args:
point_a: The first point.
point_b: The second point.
direction_vector: The direction vector.
Returns:
The point that is behind according to the direction vector.
"""
direction_vector = direction_vector / direction_vector.meca.norm(dim="axis")
vector_b_to_a = point_a - point_b
signed_distance = vector_b_to_a.dot(direction_vector, dim="axis")

return point_b if signed_distance > 0 else point_a
74 changes: 74 additions & 0 deletions tests/full/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
import xarray as xr
import numpy as np

from gaitalytics.utils.linalg import (
calculate_distance,
project_point_on_vector,
signed_projection_norm,
get_normal_vector,
normalize_vector,
calculate_speed_norm,
get_point_in_front,
get_point_behind
)

@pytest.fixture
def sample_data():
point_a = xr.DataArray([1, 2, 3], dims=["axis"], coords={"axis": ["x", "y", "z"]})
point_b = xr.DataArray([4, 5, 6], dims=["axis"], coords={"axis": ["x", "y", "z"]})
vector_a = xr.DataArray([1, 0, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
vector_b = xr.DataArray([0, 1, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
return point_a, point_b, vector_a, vector_b

def test_calculate_distance(sample_data):
point_a, point_b, _, _ = sample_data
distance = calculate_distance(point_a, point_b)
expected_distance = np.sqrt(27)
assert distance == pytest.approx(expected_distance)

def test_project_point_on_vector(sample_data):
point_a, _, vector_a, _ = sample_data
projected_point = project_point_on_vector(point_a, vector_a)
expected_projection = xr.DataArray([1, 0, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
xr.testing.assert_allclose(projected_point, expected_projection)

def test_signed_projection_norm(sample_data):
_, _, vector_a, vector_b = sample_data
signed_norm = signed_projection_norm(vector_a, vector_b)
expected_signed_norm = 0.0
assert signed_norm == pytest.approx(expected_signed_norm)

def test_get_normal_vector(sample_data):
_, _, vector_a, vector_b = sample_data
normal_vector = get_normal_vector(vector_a, vector_b)
expected_normal_vector = xr.DataArray([0, 0, 1], dims=["axis"], coords={"axis": ["x", "y", "z"]})
xr.testing.assert_allclose(normal_vector, expected_normal_vector)

def test_normalize_vector(sample_data):
_, _, vector_a, _ = sample_data
normalized_vector = normalize_vector(vector_a)
expected_normalized_vector = xr.DataArray([1, 0, 0], dims=["axis"], coords={"axis": ["x", "y", "z"]})
xr.testing.assert_allclose(normalized_vector, expected_normalized_vector)

def test_calculate_speed_norm():
position = xr.DataArray(
np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]),
dims=["axis", "time"],
coords={"axis": ["x", "y", "z"], "time": [0, 1, 2]}
)
speed = calculate_speed_norm(position, dt=1.0)
expected_speed = xr.DataArray([np.sqrt(3), np.sqrt(3), np.sqrt(3)], dims=["time"], coords={"time": [0, 1, 2]})
xr.testing.assert_allclose(speed, expected_speed)

def test_get_point_in_front(sample_data):
point_a, point_b, vector_a, _ = sample_data
point_in_front = get_point_in_front(point_a, point_b, vector_a)
expected_point_in_front = point_b
xr.testing.assert_allclose(point_in_front, expected_point_in_front)

def test_get_point_behind(sample_data):
point_a, point_b, vector_a, _ = sample_data
point_behind = get_point_behind(point_a, point_b, vector_a)
expected_point_behind = point_a
xr.testing.assert_allclose(point_behind, expected_point_behind)

0 comments on commit 7f40ef9

Please sign in to comment.