Skip to content

Commit

Permalink
Suggestion for consistent names
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Jul 19, 2024
1 parent 1d58512 commit 4669f4a
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions movement/validators/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from movement.utils.logging import log_error, log_warning


def _list_of_str(value: str | Iterable[Any]) -> list[str]:
def _convert_to_list_of_str(value: str | Iterable[Any]) -> list[str]:
"""Try to coerce the value into a list of strings."""
if isinstance(value, str):
log_warning(
Expand All @@ -25,15 +25,26 @@ def _list_of_str(value: str | Iterable[Any]) -> list[str]:
)


def _ensure_type_ndarray(value: Any) -> None:
def _convert_fps_to_none_if_invalid(fps: float | None) -> float | None:
"""Set fps to None if a non-positive float is passed."""
if fps is not None and fps <= 0:
log_warning(
f"Invalid fps value ({fps}). Expected a positive number. "
"Setting fps to None."
)
return None
return fps


def _validate_type_ndarray(value: Any) -> None:
"""Raise ValueError the value is a not numpy array."""
if not isinstance(value, np.ndarray):
raise log_error(
ValueError, f"Expected a numpy array, but got {type(value)}."
)


def _ensure_shape(attribute, value: np.ndarray, expected_shape: tuple):
def _validate_array_shape(attribute, value: np.ndarray, expected_shape: tuple):
"""Raise ValueError if the value does not have the expected shape."""
if value.shape != expected_shape:
raise log_error(
Expand All @@ -43,17 +54,6 @@ def _ensure_shape(attribute, value: np.ndarray, expected_shape: tuple):
)


def _set_fps_to_none_if_invalid(fps: float | None) -> float | None:
"""Set fps to None if a non-positive float is passed."""
if fps is not None and fps <= 0:
log_warning(
f"Invalid fps value ({fps}). Expected a positive number. "
"Setting fps to None."
)
return None
return fps


def _validate_list_length(attribute, value: list | None, expected_length: int):
"""Raise a ValueError if the list does not have the expected length."""
if (value is not None) and (len(value) != expected_length):
Expand Down Expand Up @@ -98,16 +98,16 @@ class ValidPosesDataset:
confidence_array: np.ndarray | None = field(default=None)
individual_names: list[str] | None = field(
default=None,
converter=converters.optional(_list_of_str),
converter=converters.optional(_convert_to_list_of_str),
)
keypoint_names: list[str] | None = field(
default=None,
converter=converters.optional(_list_of_str),
converter=converters.optional(_convert_to_list_of_str),
)
fps: float | None = field(
default=None,
converter=converters.pipe( # type: ignore
converters.optional(float), _set_fps_to_none_if_invalid
converters.optional(float), _convert_fps_to_none_if_invalid
),
)
source_software: str | None = field(
Expand All @@ -118,7 +118,7 @@ class ValidPosesDataset:
# Add validators
@position_array.validator
def _validate_position_array(self, attribute, value):
_ensure_type_ndarray(value)
_validate_type_ndarray(value)
if value.ndim != 4:
raise log_error(
ValueError,
Expand All @@ -135,9 +135,9 @@ def _validate_position_array(self, attribute, value):
@confidence_array.validator
def _validate_confidence_array(self, attribute, value):
if value is not None:
_ensure_type_ndarray(value)
_validate_type_ndarray(value)

_ensure_shape(
_validate_array_shape(
attribute, value, expected_shape=self.position_array.shape[:-1]
)

Expand Down Expand Up @@ -230,14 +230,14 @@ class ValidBboxesDataset:
individual_names: list[str] | None = field(
default=None,
converter=converters.optional(
_list_of_str
_convert_to_list_of_str
), # force into list of strings if not
)
frame_array: np.ndarray | None = field(default=None)
fps: float | None = field(
default=None,
converter=converters.pipe( # type: ignore
converters.optional(float), _set_fps_to_none_if_invalid
converters.optional(float), _convert_fps_to_none_if_invalid
),
)
source_software: str | None = field(
Expand All @@ -249,7 +249,7 @@ class ValidBboxesDataset:
@position_array.validator
@shape_array.validator
def _validate_position_and_shape_arrays(self, attribute, value):
_ensure_type_ndarray(value)
_validate_type_ndarray(value)

# check last dimension (spatial) has 2 coordinates
n_expected_spatial_coordinates = 2
Expand Down Expand Up @@ -281,9 +281,9 @@ def _validate_individual_names(self, attribute, value):
@confidence_array.validator
def _validate_confidence_array(self, attribute, value):
if value is not None:
_ensure_type_ndarray(value)
_validate_type_ndarray(value)

_ensure_shape(
_validate_array_shape(
attribute, value, expected_shape=self.position_array.shape[:-1]
)

Expand All @@ -292,11 +292,11 @@ def _validate_frame_array(
self, attribute, value
): # ---- ADD check for contiguous frames
if value is not None:
_ensure_type_ndarray(value)
_validate_type_ndarray(value)

# should be a column vector (n_frames, 1)

_ensure_shape(
_validate_array_shape(
attribute,
value,
expected_shape=(self.position_array.shape[0], 1),
Expand Down

0 comments on commit 4669f4a

Please sign in to comment.