diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py index 3c7cafae..11a3b9ea 100644 --- a/tests/test_unit/test_filtering.py +++ b/tests/test_unit/test_filtering.py @@ -10,10 +10,36 @@ savgol_filter, ) +# Dataset fixtures +valid_datasets_without_nans = [ + "valid_poses_dataset", + "valid_bboxes_dataset", +] +valid_datasets_with_nans = [ + f"{dataset}_with_nan" for dataset in valid_datasets_without_nans +] +all_valid_datasets = valid_datasets_without_nans + valid_datasets_with_nans + + +# Expected number of nans in the position array per individual, +# for each dataset +expected_n_nans_in_position_per_indiv = { + "valid_poses_dataset": {0: 0, 1: 0}, + # filtering should not introduce nans if input has no nans + "valid_bboxes_dataset": {0: 0, 1: 0}, + # filtering should not introduce nans if input has no nans + "valid_poses_dataset_with_nan": {0: 7, 1: 0}, + # individual with index 0 has 7 frames with nans in position after + # filtering individual with index 1 has no nans after filtering + "valid_bboxes_dataset_with_nan": {0: 7, 1: 0}, + # individual with index 0 has 7 frames with nans in position after + # filtering individual with index 0 has no nans after filtering +} + @pytest.mark.parametrize( "valid_dataset_with_nan", - ["valid_poses_dataset_with_nan", "valid_bboxes_dataset_with_nan"], + valid_datasets_with_nans, ) @pytest.mark.parametrize( "max_gap, expected_n_nans_in_position", [(None, 0), (0, 3), (1, 2), (2, 0)] @@ -106,12 +132,7 @@ def test_filter_by_confidence_on_position( @pytest.mark.parametrize( "valid_dataset", - [ - "valid_poses_dataset", - "valid_bboxes_dataset", - "valid_poses_dataset_with_nan", - "valid_bboxes_dataset_with_nan", - ], + all_valid_datasets, ) @pytest.mark.parametrize("window_size", [2, 4]) def test_median_filter_on_position(valid_dataset, window_size, request): @@ -134,30 +155,7 @@ def test_median_filter_on_position(valid_dataset, window_size, request): @pytest.mark.parametrize( ("valid_dataset, expected_n_nans_in_position_per_indiv"), - [ - ( - "valid_poses_dataset", - {0: 0, 1: 0}, - ), # median filtering should not introduce nans if input has no nans - ( - "valid_bboxes_dataset", - {0: 0, 1: 0}, - ), # median filtering should not introduce nans if input has no nans - ( - "valid_poses_dataset_with_nan", - {0: 7, 1: 0}, - ), - # individual with index 0 has 7 frames with nans in position after - # filtering - # individual with index 1 has no nans after filtering - ( - "valid_bboxes_dataset_with_nan", - {0: 7, 1: 0}, - ), - # individual with index 0 has 7 frames with nans in position after - # filtering - # individual with index 0 has no nans after filtering - ], + [(k, v) for k, v in expected_n_nans_in_position_per_indiv.items()], ) def test_median_filter_with_nans_on_position( valid_dataset, @@ -203,12 +201,7 @@ def test_median_filter_with_nans_on_position( @pytest.mark.parametrize( "valid_dataset", - [ - "valid_poses_dataset", - "valid_bboxes_dataset", - "valid_poses_dataset_with_nan", - "valid_bboxes_dataset_with_nan", - ], + all_valid_datasets, ) @pytest.mark.parametrize("window, polyorder", [(2, 1), (4, 2)]) def test_savgol_filter_on_position(valid_dataset, window, polyorder, request): @@ -233,30 +226,7 @@ def test_savgol_filter_on_position(valid_dataset, window, polyorder, request): @pytest.mark.parametrize( ("valid_dataset, expected_n_nans_in_position_per_indiv"), - [ - ( - "valid_poses_dataset", - {0: 0, 1: 0}, - ), # median filtering should not introduce nans if input has no nans - ( - "valid_bboxes_dataset", - {0: 0, 1: 0}, - ), # median filtering should not introduce nans if input has no nans - ( - "valid_poses_dataset_with_nan", - {0: 7, 1: 0}, - ), - # individual with index 0 has 7 frames with nans in position after - # filtering - # individual with index 1 has no nans after filtering - ( - "valid_bboxes_dataset_with_nan", - {0: 7, 1: 0}, - ), - # individual with index 0 has 7 frames with nans in position after - # filtering - # individual with index 0 has no nans after filtering - ], + [(k, v) for k, v in expected_n_nans_in_position_per_indiv.items()], ) def test_savgol_filter_with_nans_on_position( valid_dataset, expected_n_nans_in_position_per_indiv, helpers, request @@ -301,12 +271,7 @@ def test_savgol_filter_with_nans_on_position( @pytest.mark.parametrize( "valid_dataset", - [ - "valid_poses_dataset", - "valid_bboxes_dataset", - "valid_poses_dataset_with_nan", - "valid_bboxes_dataset_with_nan", - ], + all_valid_datasets, ) @pytest.mark.parametrize( "override_kwargs",