Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small edits to ValidBboxesDataset (1/4) #230

Merged
merged 13 commits into from
Jul 22, 2024
111 changes: 76 additions & 35 deletions movement/validators/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from collections.abc import Iterable
from typing import Any

import attrs
import numpy as np
from attrs import converters, define, field, validators

from movement.utils.logging import log_error, log_warning


def _list_of_str(value: str | Iterable[Any]) -> list[str]:
def _convert_to_list_of_str(value: str | Iterable[Any]) -> list[str]:
"""Try to coerce the value into a list of strings."""
if isinstance(value, str):
log_warning(
Expand All @@ -25,15 +26,7 @@ def _list_of_str(value: str | Iterable[Any]) -> list[str]:
)


def _ensure_type_ndarray(value: Any) -> None:
"""Raise ValueError if the value is not a numpy array."""
if not isinstance(value, np.ndarray):
raise log_error(
ValueError, f"Expected a numpy array, but got {type(value)}."
)


def _set_fps_to_none_if_invalid(fps: float | None) -> float | None:
def _convert_fps_to_none_if_invalid(fps: float | None) -> float | None:
"""Set fps to None if a non-positive float is passed."""
if fps is not None and fps <= 0:
log_warning(
Expand All @@ -44,7 +37,29 @@ def _set_fps_to_none_if_invalid(fps: float | None) -> float | None:
return fps


def _validate_list_length(attribute, value: list | None, expected_length: int):
def _validate_type_ndarray(value: Any) -> None:
"""Raise ValueError if the value is not a numpy array."""
if not isinstance(value, np.ndarray):
# NOTE(review): the result of log_error is raised directly, so it
# presumably logs the message and returns the exception - confirm in
# movement.utils.logging.
raise log_error(
ValueError, f"Expected a numpy array, but got {type(value)}."
)


def _validate_array_shape(
attribute: attrs.Attribute, value: np.ndarray, expected_shape: tuple
):
"""Raise ValueError if the value does not have the expected shape.

Parameters
----------
attribute : attrs.Attribute
The attribute being validated. Its name is used in the error message.
value : np.ndarray
The array whose shape is checked.
expected_shape : tuple
The shape that ``value.shape`` must be equal to.

Raises
------
ValueError
If ``value.shape`` differs from ``expected_shape``.

"""
if value.shape != expected_shape:
raise log_error(
ValueError,
f"Expected '{attribute.name}' to have shape {expected_shape}, "
f"but got {value.shape}.",
)


def _validate_list_length(
attribute: attrs.Attribute, value: list | None, expected_length: int
):
"""Raise a ValueError if the list does not have the expected length."""
if (value is not None) and (len(value) != expected_length):
raise log_error(
Expand Down Expand Up @@ -88,16 +103,16 @@ class ValidPosesDataset:
confidence_array: np.ndarray | None = field(default=None)
individual_names: list[str] | None = field(
default=None,
converter=converters.optional(_list_of_str),
converter=converters.optional(_convert_to_list_of_str),
)
keypoint_names: list[str] | None = field(
default=None,
converter=converters.optional(_list_of_str),
converter=converters.optional(_convert_to_list_of_str),
)
fps: float | None = field(
default=None,
converter=converters.pipe( # type: ignore
converters.optional(float), _set_fps_to_none_if_invalid
converters.optional(float), _convert_fps_to_none_if_invalid
),
)
source_software: str | None = field(
Expand All @@ -108,7 +123,7 @@ class ValidPosesDataset:
# Add validators
@position_array.validator
def _validate_position_array(self, attribute, value):
_ensure_type_ndarray(value)
_validate_type_ndarray(value)
if value.ndim != 4:
raise log_error(
ValueError,
Expand All @@ -125,14 +140,11 @@ def _validate_position_array(self, attribute, value):
@confidence_array.validator
def _validate_confidence_array(self, attribute, value):
if value is not None:
_ensure_type_ndarray(value)
expected_shape = self.position_array.shape[:-1]
if value.shape != expected_shape:
raise log_error(
ValueError,
f"Expected '{attribute.name}' to have shape "
f"{expected_shape}, but got {value.shape}.",
)
_validate_type_ndarray(value)

_validate_array_shape(
attribute, value, expected_shape=self.position_array.shape[:-1]
)

@individual_names.validator
def _validate_individual_names(self, attribute, value):
Expand Down Expand Up @@ -200,6 +212,11 @@ class ValidBboxesDataset:
If None (default), bounding boxes are assigned names based on the size
of the `position_array`. The names will be in the format of `id_<N>`,
where <N> is an integer from 0 to `position_array.shape[1]-1`.
frame_array : np.ndarray, optional
Array of shape (n_frames, 1) containing the frame numbers for which
bounding boxes are defined. If None (default), frame numbers will
be assigned based on the first dimension of the `position_array`,
starting from 0.
fps : float, optional
Frames per second defining the sampling rate of the data.
Defaults to None.
Expand All @@ -218,13 +235,14 @@ class ValidBboxesDataset:
individual_names: list[str] | None = field(
default=None,
converter=converters.optional(
_list_of_str
_convert_to_list_of_str
), # force into list of strings if not
)
frame_array: np.ndarray | None = field(default=None)
fps: float | None = field(
default=None,
converter=converters.pipe( # type: ignore
converters.optional(float), _set_fps_to_none_if_invalid
converters.optional(float), _convert_fps_to_none_if_invalid
),
)
source_software: str | None = field(
Expand All @@ -236,7 +254,7 @@ class ValidBboxesDataset:
@position_array.validator
@shape_array.validator
def _validate_position_and_shape_arrays(self, attribute, value):
_ensure_type_ndarray(value)
_validate_type_ndarray(value)

# check last dimension (spatial) has 2 coordinates
n_expected_spatial_coordinates = 2
Expand Down Expand Up @@ -268,14 +286,29 @@ def _validate_individual_names(self, attribute, value):
@confidence_array.validator
def _validate_confidence_array(self, attribute, value):
if value is not None:
_ensure_type_ndarray(value)
_validate_type_ndarray(value)

_validate_array_shape(
attribute, value, expected_shape=self.position_array.shape[:-1]
)

expected_shape = self.position_array.shape[:-1]
if value.shape != expected_shape:
@frame_array.validator
def _validate_frame_array(self, attribute, value):
"""Validate a user-provided frame array; None skips all checks."""
if value is not None:
_validate_type_ndarray(value)

# should be a column vector (n_frames, 1)
_validate_array_shape(
attribute,
value,
expected_shape=(self.position_array.shape[0], 1),
)

# check frame numbers are consecutive: each row must be exactly one
# more than the previous row
if not np.all(np.diff(value, axis=0) == 1):
raise log_error(
ValueError,
f"Expected '{attribute.name}' to have shape "
f"{expected_shape}, but got {value.shape}.",
f"Frame numbers in {attribute.name} are not continuous.",
)

# Define defaults
Expand All @@ -284,7 +317,7 @@ def __attrs_post_init__(self):

If no confidence_array is provided, set it to an array of NaNs.
If no individual names are provided, assign them unique IDs per frame,
starting with 1 ("id_1")
starting with 0 ("id_0").
"""
if self.confidence_array is None:
self.confidence_array = np.full(
Expand All @@ -293,17 +326,25 @@ def __attrs_post_init__(self):
dtype="float32",
)
log_warning(
"Confidence array was not provided."
"Confidence array was not provided. "
"Setting to an array of NaNs."
)

if self.individual_names is None:
self.individual_names = [
f"id_{i+1}" for i in range(self.position_array.shape[1])
f"id_{i}" for i in range(self.position_array.shape[1])
]
log_warning(
"Individual names for the bounding boxes "
"were not provided. "
"Setting to 1-based IDs that are unique per frame: \n"
"Setting to 0-based IDs that are unique per frame: \n"
f"{self.individual_names}.\n"
)

if self.frame_array is None:
n_frames = self.position_array.shape[0]
self.frame_array = np.arange(n_frames).reshape(-1, 1)
log_warning(
"Frame numbers were not provided. "
"Setting to an array of 0-based integers."
)
1 change: 1 addition & 0 deletions tests/test_unit/test_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def validate_metadata(metadata: dict[str, dict]) -> None:
"sha256sum",
"type",
"source_software",
"type",
"fps",
"species",
"number_of_individuals",
Expand Down
58 changes: 55 additions & 3 deletions tests/test_unit/test_validators/test_datasets_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_bboxes_dataset_validator_confidence_array(
):
"""Test that invalid confidence arrays raise the appropriate errors."""
with expected_exception as excinfo:
poses = ValidBboxesDataset(
ds = ValidBboxesDataset(
position_array=request.getfixturevalue("valid_bboxes_inputs")[
"position_array"
],
Expand All @@ -347,10 +347,62 @@ def test_bboxes_dataset_validator_confidence_array(
)
if confidence_array is None:
assert np.all(
np.isnan(poses.confidence_array)
np.isnan(ds.confidence_array)
) # assert it is a NaN array
assert (
poses.confidence_array.shape == poses.position_array.shape[:-1]
ds.confidence_array.shape == ds.position_array.shape[:-1]
) # assert shape matches position array
else:
assert str(excinfo.value) == log_message


@pytest.mark.parametrize(
"frame_array, expected_exception, log_message",
[
(
np.arange(10).reshape(-1, 2),
pytest.raises(ValueError),
"Expected 'frame_array' to have shape (10, 1), " "but got (5, 2).",
), # frame_array should be a column vector
(
[1, 2, 3],
pytest.raises(ValueError),
f"Expected a numpy array, but got {type(list())}.",
), # not an ndarray, should raise ValueError
(
np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 11]).reshape(-1, 1),
pytest.raises(ValueError),
"Frame numbers in frame_array are not continuous.",
), # frame numbers are not continuous
(
None,
does_not_raise(),
"",
), # valid, should return an array of frame numbers starting from 0
],
)
def test_bboxes_dataset_validator_frame_array(
frame_array, expected_exception, log_message, request
):
"""Test that invalid frame arrays raise the appropriate errors.

Also check that, when no frame array is passed, the dataset is assigned
a default column vector of 0-based frame numbers.
"""
with expected_exception as excinfo:
ds = ValidBboxesDataset(
position_array=request.getfixturevalue("valid_bboxes_inputs")[
"position_array"
],
shape_array=request.getfixturevalue("valid_bboxes_inputs")[
"shape_array"
],
individual_names=request.getfixturevalue("valid_bboxes_inputs")[
"individual_names"
],
frame_array=frame_array,
)

if frame_array is None:
n_frames = ds.position_array.shape[0]
default_frame_array = np.arange(n_frames).reshape(-1, 1)
assert np.array_equal(ds.frame_array, default_frame_array)
assert ds.frame_array.shape == (ds.position_array.shape[0], 1)
else:
assert str(excinfo.value) == log_message
Loading