From d2b31c80359e047bd0a1e0d82f08eba930b37b96 Mon Sep 17 00:00:00 2001
From: sfmig <33267254+sfmig@users.noreply.github.com>
Date: Tue, 4 Jun 2024 15:27:36 +0100
Subject: [PATCH 1/5] Split validators into modules

---
 movement/io/load_poses.py                     |   4 +-
 movement/io/save_poses.py                     |   2 +-
 movement/io/validators/__init__.py            |   0
 movement/io/validators/datasets.py            | 178 +++++++++++++++++
 .../io/{validators.py => validators/files.py} | 179 +-----------------
 movement/move_accessor.py                     |   2 +-
 tests/test_unit/test_validators.py            |   4 +-
 7 files changed, 188 insertions(+), 181 deletions(-)
 create mode 100644 movement/io/validators/__init__.py
 create mode 100644 movement/io/validators/datasets.py
 rename movement/io/{validators.py => validators/files.py} (50%)

diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index e2b30910..5debf89c 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -12,11 +12,11 @@
 from sleap_io.model.labels import Labels
 
 from movement import MovementDataset
-from movement.io.validators import (
+from movement.io.validators.datasets import ValidPosesDataset
+from movement.io.validators.files import (
     ValidDeepLabCutCSV,
     ValidFile,
     ValidHDF5,
-    ValidPosesDataset,
 )
 from movement.logging import log_error, log_warning
 
diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py
index e0ee5778..3d6b29b8 100644
--- a/movement/io/save_poses.py
+++ b/movement/io/save_poses.py
@@ -9,7 +9,7 @@
 import pandas as pd
 import xarray as xr
 
-from movement.io.validators import ValidFile
+from movement.io.validators.files import ValidFile
 from movement.logging import log_error
 
 logger = logging.getLogger(__name__)
diff --git a/movement/io/validators/__init__.py b/movement/io/validators/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/movement/io/validators/datasets.py b/movement/io/validators/datasets.py
new file mode 100644
index 00000000..159c7348
--- /dev/null
+++ b/movement/io/validators/datasets.py
@@ -0,0 +1,178 @@
+"""`attrs` classes for validating data structures."""
+
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import numpy as np
+from attrs import converters, define, field, validators
+
+from movement.logging import log_error, log_warning
+
+
+def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]:
+    """Try to coerce the value into a list of strings."""
+    if isinstance(value, str):
+        log_warning(
+            f"Invalid value ({value}). Expected a list of strings. "
+            "Converting to a list of length 1."
+        )
+        return [value]
+    elif isinstance(value, Iterable):
+        return [str(item) for item in value]
+    else:
+        raise log_error(
+            ValueError, f"Invalid value ({value}). Expected a list of strings."
+        )
+
+
+def _ensure_type_ndarray(value: Any) -> None:
+    """Raise ValueError the value is a not numpy array."""
+    if not isinstance(value, np.ndarray):
+        raise log_error(
+            ValueError, f"Expected a numpy array, but got {type(value)}."
+        )
+
+
+def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]:
+    """Set fps to None if a non-positive float is passed."""
+    if fps is not None and fps <= 0:
+        log_warning(
+            f"Invalid fps value ({fps}). Expected a positive number. "
+            "Setting fps to None."
+        )
+        return None
+    return fps
+
+
+def _validate_list_length(
+    attribute: str, value: Optional[list], expected_length: int
+):
+    """Raise a ValueError if the list does not have the expected length."""
+    if (value is not None) and (len(value) != expected_length):
+        raise log_error(
+            ValueError,
+            f"Expected `{attribute}` to have length {expected_length}, "
+            f"but got {len(value)}.",
+        )
+
+
+@define(kw_only=True)
+class ValidPosesDataset:
+    """Class for validating data intended for a ``movement`` dataset.
+
+    Attributes
+    ----------
+    position_array : np.ndarray
+        Array of shape (n_frames, n_individuals, n_keypoints, n_space)
+        containing the poses.
+    confidence_array : np.ndarray, optional
+        Array of shape (n_frames, n_individuals, n_keypoints) containing
+        the point-wise confidence scores.
+        If None (default), the scores will be set to an array of NaNs.
+    individual_names : list of str, optional
+        List of unique names for the individuals in the video. If None
+        (default), the individuals will be named "individual_0",
+        "individual_1", etc.
+    keypoint_names : list of str, optional
+        List of unique names for the keypoints in the skeleton. If None
+        (default), the keypoints will be named "keypoint_0", "keypoint_1",
+        etc.
+    fps : float, optional
+        Frames per second of the video. Defaults to None.
+    source_software : str, optional
+        Name of the software from which the poses were loaded.
+        Defaults to None.
+
+    """
+
+    # Define class attributes
+    position_array: np.ndarray = field()
+    confidence_array: Optional[np.ndarray] = field(default=None)
+    individual_names: Optional[list[str]] = field(
+        default=None,
+        converter=converters.optional(_list_of_str),
+    )
+    keypoint_names: Optional[list[str]] = field(
+        default=None,
+        converter=converters.optional(_list_of_str),
+    )
+    fps: Optional[float] = field(
+        default=None,
+        converter=converters.pipe(  # type: ignore
+            converters.optional(float), _set_fps_to_none_if_invalid
+        ),
+    )
+    source_software: Optional[str] = field(
+        default=None,
+        validator=validators.optional(validators.instance_of(str)),
+    )
+
+    # Add validators
+    @position_array.validator
+    def _validate_position_array(self, attribute, value):
+        _ensure_type_ndarray(value)
+        if value.ndim != 4:
+            raise log_error(
+                ValueError,
+                f"Expected `{attribute}` to have 4 dimensions, "
+                f"but got {value.ndim}.",
+            )
+        if value.shape[-1] not in [2, 3]:
+            raise log_error(
+                ValueError,
+                f"Expected `{attribute}` to have 2 or 3 spatial dimensions, "
+                f"but got {value.shape[-1]}.",
+            )
+
+    @confidence_array.validator
+    def _validate_confidence_array(self, attribute, value):
+        if value is not None:
+            _ensure_type_ndarray(value)
+            expected_shape = self.position_array.shape[:-1]
+            if value.shape != expected_shape:
+                raise log_error(
+                    ValueError,
+                    f"Expected `{attribute}` to have shape "
+                    f"{expected_shape}, but got {value.shape}.",
+                )
+
+    @individual_names.validator
+    def _validate_individual_names(self, attribute, value):
+        if self.source_software == "LightningPose":
+            # LightningPose only supports a single individual
+            _validate_list_length(attribute, value, 1)
+        else:
+            _validate_list_length(
+                attribute, value, self.position_array.shape[1]
+            )
+
+    @keypoint_names.validator
+    def _validate_keypoint_names(self, attribute, value):
+        _validate_list_length(attribute, value, self.position_array.shape[2])
+
+    def __attrs_post_init__(self):
+        """Assign default values to optional attributes (if None)."""
+        if self.confidence_array is None:
+            self.confidence_array = np.full(
+                (self.position_array.shape[:-1]), np.nan, dtype="float32"
+            )
+            log_warning(
+                "Confidence array was not provided."
+                "Setting to an array of NaNs."
+            )
+        if self.individual_names is None:
+            self.individual_names = [
+                f"individual_{i}" for i in range(self.position_array.shape[1])
+            ]
+            log_warning(
+                "Individual names were not provided. "
+                f"Setting to {self.individual_names}."
+            )
+        if self.keypoint_names is None:
+            self.keypoint_names = [
+                f"keypoint_{i}" for i in range(self.position_array.shape[2])
+            ]
+            log_warning(
+                "Keypoint names were not provided. "
+                f"Setting to {self.keypoint_names}."
+            )
diff --git a/movement/io/validators.py b/movement/io/validators/files.py
similarity index 50%
rename from movement/io/validators.py
rename to movement/io/validators/files.py
index 6174f945..b7e7ca9e 100644
--- a/movement/io/validators.py
+++ b/movement/io/validators/files.py
@@ -1,15 +1,13 @@
-"""`attrs` classes for validating file paths and data structures."""
+"""`attrs` classes for validating file paths."""
 
 import os
-from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Literal
 
 import h5py
-import numpy as np
-from attrs import converters, define, field, validators
+from attrs import define, field, validators
 
-from movement.logging import log_error, log_warning
+from movement.logging import log_error
 
 
 @define
@@ -200,172 +198,3 @@ def csv_file_contains_expected_levels(self, attribute, value):
                 ".csv header rows do not match the known format for "
                 "DeepLabCut pose estimation output files.",
             )
-
-
-def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]:
-    """Try to coerce the value into a list of strings."""
-    if isinstance(value, str):
-        log_warning(
-            f"Invalid value ({value}). Expected a list of strings. "
-            "Converting to a list of length 1."
-        )
-        return [value]
-    elif isinstance(value, Iterable):
-        return [str(item) for item in value]
-    else:
-        raise log_error(
-            ValueError, f"Invalid value ({value}). Expected a list of strings."
-        )
-
-
-def _ensure_type_ndarray(value: Any) -> None:
-    """Raise ValueError the value is a not numpy array."""
-    if not isinstance(value, np.ndarray):
-        raise log_error(
-            ValueError, f"Expected a numpy array, but got {type(value)}."
-        )
-
-
-def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]:
-    """Set fps to None if a non-positive float is passed."""
-    if fps is not None and fps <= 0:
-        log_warning(
-            f"Invalid fps value ({fps}). Expected a positive number. "
-            "Setting fps to None."
-        )
-        return None
-    return fps
-
-
-def _validate_list_length(
-    attribute: str, value: Optional[list], expected_length: int
-):
-    """Raise a ValueError if the list does not have the expected length."""
-    if (value is not None) and (len(value) != expected_length):
-        raise log_error(
-            ValueError,
-            f"Expected `{attribute}` to have length {expected_length}, "
-            f"but got {len(value)}.",
-        )
-
-
-@define(kw_only=True)
-class ValidPosesDataset:
-    """Class for validating data intended for a ``movement`` dataset.
-
-    Attributes
-    ----------
-    position_array : np.ndarray
-        Array of shape (n_frames, n_individuals, n_keypoints, n_space)
-        containing the poses.
-    confidence_array : np.ndarray, optional
-        Array of shape (n_frames, n_individuals, n_keypoints) containing
-        the point-wise confidence scores.
-        If None (default), the scores will be set to an array of NaNs.
-    individual_names : list of str, optional
-        List of unique names for the individuals in the video. If None
-        (default), the individuals will be named "individual_0",
-        "individual_1", etc.
-    keypoint_names : list of str, optional
-        List of unique names for the keypoints in the skeleton. If None
-        (default), the keypoints will be named "keypoint_0", "keypoint_1",
-        etc.
-    fps : float, optional
-        Frames per second of the video. Defaults to None.
-    source_software : str, optional
-        Name of the software from which the poses were loaded.
-        Defaults to None.
-
-    """
-
-    # Define class attributes
-    position_array: np.ndarray = field()
-    confidence_array: Optional[np.ndarray] = field(default=None)
-    individual_names: Optional[list[str]] = field(
-        default=None,
-        converter=converters.optional(_list_of_str),
-    )
-    keypoint_names: Optional[list[str]] = field(
-        default=None,
-        converter=converters.optional(_list_of_str),
-    )
-    fps: Optional[float] = field(
-        default=None,
-        converter=converters.pipe(  # type: ignore
-            converters.optional(float), _set_fps_to_none_if_invalid
-        ),
-    )
-    source_software: Optional[str] = field(
-        default=None,
-        validator=validators.optional(validators.instance_of(str)),
-    )
-
-    # Add validators
-    @position_array.validator
-    def _validate_position_array(self, attribute, value):
-        _ensure_type_ndarray(value)
-        if value.ndim != 4:
-            raise log_error(
-                ValueError,
-                f"Expected `{attribute}` to have 4 dimensions, "
-                f"but got {value.ndim}.",
-            )
-        if value.shape[-1] not in [2, 3]:
-            raise log_error(
-                ValueError,
-                f"Expected `{attribute}` to have 2 or 3 spatial dimensions, "
-                f"but got {value.shape[-1]}.",
-            )
-
-    @confidence_array.validator
-    def _validate_confidence_array(self, attribute, value):
-        if value is not None:
-            _ensure_type_ndarray(value)
-            expected_shape = self.position_array.shape[:-1]
-            if value.shape != expected_shape:
-                raise log_error(
-                    ValueError,
-                    f"Expected `{attribute}` to have shape "
-                    f"{expected_shape}, but got {value.shape}.",
-                )
-
-    @individual_names.validator
-    def _validate_individual_names(self, attribute, value):
-        if self.source_software == "LightningPose":
-            # LightningPose only supports a single individual
-            _validate_list_length(attribute, value, 1)
-        else:
-            _validate_list_length(
-                attribute, value, self.position_array.shape[1]
-            )
-
-    @keypoint_names.validator
-    def _validate_keypoint_names(self, attribute, value):
-        _validate_list_length(attribute, value, self.position_array.shape[2])
-
-    def __attrs_post_init__(self):
-        """Assign default values to optional attributes (if None)."""
-        if self.confidence_array is None:
-            self.confidence_array = np.full(
-                (self.position_array.shape[:-1]), np.nan, dtype="float32"
-            )
-            log_warning(
-                "Confidence array was not provided."
-                "Setting to an array of NaNs."
-            )
-        if self.individual_names is None:
-            self.individual_names = [
-                f"individual_{i}" for i in range(self.position_array.shape[1])
-            ]
-            log_warning(
-                "Individual names were not provided. "
-                f"Setting to {self.individual_names}."
-            )
-        if self.keypoint_names is None:
-            self.keypoint_names = [
-                f"keypoint_{i}" for i in range(self.position_array.shape[2])
-            ]
-            log_warning(
-                "Keypoint names were not provided. "
-                f"Setting to {self.keypoint_names}."
-            )
diff --git a/movement/move_accessor.py b/movement/move_accessor.py
index 17c268dc..0031cdf4 100644
--- a/movement/move_accessor.py
+++ b/movement/move_accessor.py
@@ -6,7 +6,7 @@
 import xarray as xr
 
 from movement.analysis import kinematics
-from movement.io.validators import ValidPosesDataset
+from movement.io.validators.datasets import ValidPosesDataset
 
 logger = logging.getLogger(__name__)
 
diff --git a/tests/test_unit/test_validators.py b/tests/test_unit/test_validators.py
index fc44f94f..af9c5c1b 100644
--- a/tests/test_unit/test_validators.py
+++ b/tests/test_unit/test_validators.py
@@ -3,11 +3,11 @@
 import numpy as np
 import pytest
 
-from movement.io.validators import (
+from movement.io.validators.datasets import ValidPosesDataset
+from movement.io.validators.files import (
     ValidDeepLabCutCSV,
     ValidFile,
     ValidHDF5,
-    ValidPosesDataset,
 )
 
 

From 1c9c686aa112613c9ac3248ef3d0f820bf970fc8 Mon Sep 17 00:00:00 2001
From: sfmig <33267254+sfmig@users.noreply.github.com>
Date: Tue, 4 Jun 2024 15:29:32 +0100
Subject: [PATCH 2/5] Fix API index

---
 docs/source/api_index.rst | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst
index a76a19f1..50905275 100644
--- a/docs/source/api_index.rst
+++ b/docs/source/api_index.rst
@@ -29,15 +29,20 @@ Save poses
     to_sleap_analysis_file
     to_dlc_style_df
 
-Validators
-----------
-.. currentmodule:: movement.io.validators
+Validators - Files
+------------------
+.. currentmodule:: movement.io.validators.files
+.. autosummary::
+    :toctree: api
+
+    ValidPosesDataset
+
+Validators - Datasets
+----------------------
+.. currentmodule:: movement.io.validators.datasets
 .. autosummary::
     :toctree: api
 
-    ValidFile
-    ValidHDF5
-    ValidDeepLabCutCSV
     ValidPosesDataset
 
 Sample Data

From b26842a7929b5add78cc98e50590c3dae98fec7c Mon Sep 17 00:00:00 2001
From: sfmig <33267254+sfmig@users.noreply.github.com>
Date: Tue, 4 Jun 2024 17:37:01 +0100
Subject: [PATCH 3/5] Fix API reference

---
 docs/source/api_index.rst | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst
index 50905275..6d894a53 100644
--- a/docs/source/api_index.rst
+++ b/docs/source/api_index.rst
@@ -35,7 +35,9 @@ Validators - Files
 .. autosummary::
     :toctree: api
 
-    ValidPosesDataset
+    ValidFile
+    ValidHDF5
+    ValidDeepLabCutCSV
 
 Validators - Datasets
 ----------------------

From 17af1f77339da1f2c1562e39a05ad5b1667133c4 Mon Sep 17 00:00:00 2001
From: sfmig <33267254+sfmig@users.noreply.github.com>
Date: Wed, 12 Jun 2024 11:29:11 +0100
Subject: [PATCH 4/5] Fix docstring

---
 movement/io/validators/datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/movement/io/validators/datasets.py b/movement/io/validators/datasets.py
index 159c7348..3e00895d 100644
--- a/movement/io/validators/datasets.py
+++ b/movement/io/validators/datasets.py
@@ -26,7 +26,7 @@ def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]:
 
 
 def _ensure_type_ndarray(value: Any) -> None:
-    """Raise ValueError the value is a not numpy array."""
+    """Raise ValueError if the value is not a numpy array."""
     if not isinstance(value, np.ndarray):
         raise log_error(
             ValueError, f"Expected a numpy array, but got {type(value)}."

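Note on the helpers touched above: the converters and checks added in PATCH 1/5 and retouched in PATCH 4/5 are private functions of movement/io/validators/datasets.py. A minimal sketch of their behaviour at this point in the series (before the module moves in the next patch); illustrative only, not part of the patches:

    import numpy as np

    # Private helpers, imported here purely for illustration.
    from movement.io.validators.datasets import _ensure_type_ndarray, _list_of_str

    # A bare string is coerced to a one-element list, and a warning is logged.
    assert _list_of_str("mouse") == ["mouse"]

    # Any other iterable has its items cast to str.
    assert _list_of_str([1, 2]) == ["1", "2"]

    # _ensure_type_ndarray raises a ValueError (via log_error) for non-arrays
    # and passes silently for numpy arrays.
    try:
        _ensure_type_ndarray([0.0, 1.0])
    except ValueError as err:
        print(err)  # Expected a numpy array, but got <class 'list'>.
    _ensure_type_ndarray(np.zeros(3))
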
From e087b9d22553e4ea9d1e4cd98d948bb85f8154de Mon Sep 17 00:00:00 2001
From: sfmig <33267254+sfmig@users.noreply.github.com>
Date: Wed, 12 Jun 2024 11:31:32 +0100
Subject: [PATCH 5/5] Move validator module one level up

---
 docs/source/api_index.rst                | 4 ++--
 movement/io/load_poses.py                | 8 ++------
 movement/io/save_poses.py                | 2 +-
 movement/move_accessor.py                | 2 +-
 movement/{io => }/validators/__init__.py | 0
 movement/{io => }/validators/datasets.py | 0
 movement/{io => }/validators/files.py    | 0
 tests/test_unit/test_validators.py       | 8 ++------
 8 files changed, 8 insertions(+), 16 deletions(-)
 rename movement/{io => }/validators/__init__.py (100%)
 rename movement/{io => }/validators/datasets.py (100%)
 rename movement/{io => }/validators/files.py (100%)

diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst
index 6d894a53..09454002 100644
--- a/docs/source/api_index.rst
+++ b/docs/source/api_index.rst
@@ -31,7 +31,7 @@ Save poses
 
 Validators - Files
 ------------------
-.. currentmodule:: movement.io.validators.files
+.. currentmodule:: movement.validators.files
 .. autosummary::
     :toctree: api
 
@@ -41,7 +41,7 @@ Validators - Files
 
 Validators - Datasets
 ----------------------
-.. currentmodule:: movement.io.validators.datasets
+.. currentmodule:: movement.validators.datasets
 .. autosummary::
     :toctree: api
 
diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index 5debf89c..5f461f2e 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -12,13 +12,9 @@
 from sleap_io.model.labels import Labels
 
 from movement import MovementDataset
-from movement.io.validators.datasets import ValidPosesDataset
-from movement.io.validators.files import (
-    ValidDeepLabCutCSV,
-    ValidFile,
-    ValidHDF5,
-)
 from movement.logging import log_error, log_warning
+from movement.validators.datasets import ValidPosesDataset
+from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5
 
 logger = logging.getLogger(__name__)
 
diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py
index 3d6b29b8..d151ada7 100644
--- a/movement/io/save_poses.py
+++ b/movement/io/save_poses.py
@@ -9,7 +9,7 @@
 import pandas as pd
 import xarray as xr
 
-from movement.io.validators.files import ValidFile
 from movement.logging import log_error
+from movement.validators.files import ValidFile
 
 logger = logging.getLogger(__name__)
diff --git a/movement/move_accessor.py b/movement/move_accessor.py
index 0031cdf4..03f41c18 100644
--- a/movement/move_accessor.py
+++ b/movement/move_accessor.py
@@ -6,7 +6,7 @@
 import xarray as xr
 
 from movement.analysis import kinematics
-from movement.io.validators.datasets import ValidPosesDataset
+from movement.validators.datasets import ValidPosesDataset
 
 logger = logging.getLogger(__name__)
 
diff --git a/movement/io/validators/__init__.py b/movement/validators/__init__.py
similarity index 100%
rename from movement/io/validators/__init__.py
rename to movement/validators/__init__.py
diff --git a/movement/io/validators/datasets.py b/movement/validators/datasets.py
similarity index 100%
rename from movement/io/validators/datasets.py
rename to movement/validators/datasets.py
diff --git a/movement/io/validators/files.py b/movement/validators/files.py
similarity index 100%
rename from movement/io/validators/files.py
rename to movement/validators/files.py
diff --git a/tests/test_unit/test_validators.py b/tests/test_unit/test_validators.py
index af9c5c1b..60002e7c 100644
--- a/tests/test_unit/test_validators.py
+++ b/tests/test_unit/test_validators.py
@@ -3,12 +3,8 @@
 import numpy as np
 import pytest
 
-from movement.io.validators.datasets import ValidPosesDataset
-from movement.io.validators.files import (
-    ValidDeepLabCutCSV,
-    ValidFile,
-    ValidHDF5,
-)
+from movement.validators.datasets import ValidPosesDataset
+from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5
 
 
 class TestValidators:
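
Taken together, the series leaves ValidPosesDataset importable from movement.validators.datasets (and the file validators from movement.validators.files). A minimal usage sketch against the final state of the series; illustrative only, not part of the patches:

    import numpy as np

    from movement.validators.datasets import ValidPosesDataset

    # 100 frames, 2 individuals, 3 keypoints, 2 spatial dimensions (x, y).
    position = np.random.rand(100, 2, 3, 2)

    # The class is defined with @define(kw_only=True), so all fields are
    # keyword-only. Omitted optional fields are filled in __attrs_post_init__.
    ds = ValidPosesDataset(position_array=position, fps=30.0)

    print(ds.individual_names)        # ['individual_0', 'individual_1']
    print(ds.keypoint_names)          # ['keypoint_0', 'keypoint_1', 'keypoint_2']
    print(ds.confidence_array.shape)  # (100, 2, 3), filled with NaNs

    # A non-positive fps is converted to None (with a warning) rather than
    # raising, whereas a wrongly shaped position_array raises a ValueError.
    assert ValidPosesDataset(position_array=position, fps=-5).fps is None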