Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor validators #204

Merged
merged 6 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions docs/source/api_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,22 @@ Save poses
to_sleap_analysis_file
to_dlc_style_df

Validators
----------
.. currentmodule:: movement.io.validators
Validators - Files
------------------
.. currentmodule:: movement.validators.files
.. autosummary::
:toctree: api

ValidFile
ValidHDF5
ValidDeepLabCutCSV

Validators - Datasets
----------------------
.. currentmodule:: movement.validators.datasets
.. autosummary::
:toctree: api

ValidPosesDataset

Sample Data
Expand Down
8 changes: 2 additions & 6 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
from sleap_io.model.labels import Labels

from movement import MovementDataset
from movement.io.validators import (
ValidDeepLabCutCSV,
ValidFile,
ValidHDF5,
ValidPosesDataset,
)
from movement.logging import log_error, log_warning
from movement.validators.datasets import ValidPosesDataset
from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion movement/io/save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import pandas as pd
import xarray as xr

from movement.io.validators import ValidFile
from movement.logging import log_error
from movement.validators.files import ValidFile

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion movement/move_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import xarray as xr

from movement.analysis import kinematics
from movement.io.validators import ValidPosesDataset
from movement.logging import log_error
from movement.validators.datasets import ValidPosesDataset

logger = logging.getLogger(__name__)

Expand Down
Empty file added movement/validators/__init__.py
Empty file.
178 changes: 178 additions & 0 deletions movement/validators/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""`attrs` classes for validating data structures."""

from collections.abc import Iterable
from typing import Any

import numpy as np
from attrs import converters, define, field, validators

from movement.logging import log_error, log_warning


def _list_of_str(value: str | Iterable[Any]) -> list[str]:
"""Try to coerce the value into a list of strings."""
if isinstance(value, str):
log_warning(
f"Invalid value ({value}). Expected a list of strings. "
"Converting to a list of length 1."
)
return [value]
elif isinstance(value, Iterable):
return [str(item) for item in value]
else:
raise log_error(
ValueError, f"Invalid value ({value}). Expected a list of strings."
)


def _ensure_type_ndarray(value: Any) -> None:
"""Raise ValueError the value is a not numpy array."""
if not isinstance(value, np.ndarray):
raise log_error(
ValueError, f"Expected a numpy array, but got {type(value)}."
)


def _set_fps_to_none_if_invalid(fps: float | None) -> float | None:
"""Set fps to None if a non-positive float is passed."""
if fps is not None and fps <= 0:
log_warning(
f"Invalid fps value ({fps}). Expected a positive number. "
"Setting fps to None."
)
return None
return fps


def _validate_list_length(
attribute: str, value: list | None, expected_length: int
):
"""Raise a ValueError if the list does not have the expected length."""
if (value is not None) and (len(value) != expected_length):
raise log_error(
ValueError,
f"Expected `{attribute}` to have length {expected_length}, "
f"but got {len(value)}.",
)


@define(kw_only=True)
class ValidPosesDataset:
    """Class for validating data intended for a ``movement`` dataset.

    Attributes
    ----------
    position_array : np.ndarray
        Array of shape (n_frames, n_individuals, n_keypoints, n_space)
        containing the poses.
    confidence_array : np.ndarray, optional
        Array of shape (n_frames, n_individuals, n_keypoints) containing
        the point-wise confidence scores.
        If None (default), the scores will be set to an array of NaNs.
    individual_names : list of str, optional
        List of unique names for the individuals in the video. If None
        (default), the individuals will be named "individual_0",
        "individual_1", etc.
    keypoint_names : list of str, optional
        List of unique names for the keypoints in the skeleton. If None
        (default), the keypoints will be named "keypoint_0", "keypoint_1",
        etc.
    fps : float, optional
        Frames per second of the video. Defaults to None.
    source_software : str, optional
        Name of the software from which the poses were loaded.
        Defaults to None.

    Raises
    ------
    ValueError
        If ``position_array`` is not a 4-dimensional numpy array with 2 or
        3 entries along its last axis, if ``confidence_array`` is not a
        numpy array whose shape matches ``position_array.shape[:-1]``, or
        if ``individual_names`` / ``keypoint_names`` do not have the
        expected length.

    """

    # Define class attributes
    position_array: np.ndarray = field()
    confidence_array: np.ndarray | None = field(default=None)
    individual_names: list[str] | None = field(
        default=None,
        converter=converters.optional(_list_of_str),
    )
    keypoint_names: list[str] | None = field(
        default=None,
        converter=converters.optional(_list_of_str),
    )
    # fps is first cast to float (if given), then reset to None if invalid
    fps: float | None = field(
        default=None,
        converter=converters.pipe(  # type: ignore
            converters.optional(float), _set_fps_to_none_if_invalid
        ),
    )
    source_software: str | None = field(
        default=None,
        validator=validators.optional(validators.instance_of(str)),
    )

    # Add validators
    # NOTE: attrs passes an ``Attribute`` object as ``attribute``; error
    # messages use ``attribute.name`` so they show only the field name
    # rather than the object's full repr.
    @position_array.validator
    def _validate_position_array(self, attribute, value):
        """Check that the position array is a 4D array with 2 or 3 coords."""
        _ensure_type_ndarray(value)
        if value.ndim != 4:
            raise log_error(
                ValueError,
                f"Expected `{attribute.name}` to have 4 dimensions, "
                f"but got {value.ndim}.",
            )
        if value.shape[-1] not in [2, 3]:
            raise log_error(
                ValueError,
                f"Expected `{attribute.name}` to have 2 or 3 spatial "
                f"dimensions, but got {value.shape[-1]}.",
            )

    @confidence_array.validator
    def _validate_confidence_array(self, attribute, value):
        """Check that the confidence array (if any) matches the positions."""
        if value is not None:
            _ensure_type_ndarray(value)
            # One confidence score per point: drop the trailing space axis
            expected_shape = self.position_array.shape[:-1]
            if value.shape != expected_shape:
                raise log_error(
                    ValueError,
                    f"Expected `{attribute.name}` to have shape "
                    f"{expected_shape}, but got {value.shape}.",
                )

    @individual_names.validator
    def _validate_individual_names(self, attribute, value):
        """Check that there is one name per individual."""
        if self.source_software == "LightningPose":
            # LightningPose only supports a single individual
            _validate_list_length(attribute.name, value, 1)
        else:
            _validate_list_length(
                attribute.name, value, self.position_array.shape[1]
            )

    @keypoint_names.validator
    def _validate_keypoint_names(self, attribute, value):
        """Check that there is one name per keypoint."""
        _validate_list_length(
            attribute.name, value, self.position_array.shape[2]
        )

    def __attrs_post_init__(self):
        """Assign default values to optional attributes (if None)."""
        if self.confidence_array is None:
            self.confidence_array = np.full(
                self.position_array.shape[:-1], np.nan, dtype="float32"
            )
            log_warning(
                "Confidence array was not provided. "
                "Setting to an array of NaNs."
            )
        if self.individual_names is None:
            self.individual_names = [
                f"individual_{i}" for i in range(self.position_array.shape[1])
            ]
            log_warning(
                "Individual names were not provided. "
                f"Setting to {self.individual_names}."
            )
        if self.keypoint_names is None:
            self.keypoint_names = [
                f"keypoint_{i}" for i in range(self.position_array.shape[2])
            ]
            log_warning(
                "Keypoint names were not provided. "
                f"Setting to {self.keypoint_names}."
            )
Loading