From 530cc1f35e40e50d4b35de30f75b75ef6fb78d08 Mon Sep 17 00:00:00 2001 From: Eloi Date: Thu, 7 Nov 2024 09:30:40 +0100 Subject: [PATCH] Fix ImageFilter to allow Gaussian filter without filter_size (#8189) Fixes #8127 Update `ImageFilter` to handle Gaussian filter without requiring `filter_size`. * Modify `monai/transforms/utility/array.py` to allow Gaussian filter without `filter_size`. - Adjust `_check_filter_format` method to skip `filter_size` check for Gaussian filter. Indeed Gauss filter is the only one in the list that doesn't require a filter_size. * Add unit test in `tests/test_image_filter.py` for Gaussian filter without `filter_size`. - Verify output shape matches input shape. Note that this method is compliant with the dictionnary version since this one load the fixed version. Signed-off-by: Eloi --------- Signed-off-by: Eloi Navet Signed-off-by: Eloi Signed-off-by: Eloi eloi.navet@gmail.com --- monai/transforms/utility/array.py | 4 ++-- tests/test_image_filter.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 72dd189009..1b3c59afdb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1609,9 +1609,9 @@ def _check_all_values_uneven(self, x: tuple) -> None: def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None: if isinstance(filter, str): - if not filter_size: + if filter != "gauss" and not filter_size: # Gauss is the only filter that does not require `filter_size` raise ValueError("`filter_size` must be specified when specifying filters by string.") - if filter_size % 2 == 0: + if filter_size and filter_size % 2 == 0: raise ValueError("`filter_size` should be a single uneven integer.") if filter not in self.supported_filters: raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.") diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 76e38d94f4..fb08b2295d 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -134,6 +134,12 @@ def test_pass_empty_metadata_dict(self): out_tensor = filter(image) self.assertTrue(isinstance(out_tensor, MetaTensor)) + def test_gaussian_filter_without_filter_size(self): + "Test Gaussian filter without specifying filter_size" + filter = ImageFilter("gauss", sigma=2) + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + class TestImageFilterDict(unittest.TestCase):