diff --git a/albumentations/augmentations/dropout/coarse_dropout.py b/albumentations/augmentations/dropout/coarse_dropout.py index 0d02ba4ab..407d61a81 100644 --- a/albumentations/augmentations/dropout/coarse_dropout.py +++ b/albumentations/augmentations/dropout/coarse_dropout.py @@ -4,14 +4,13 @@ from warnings import warn import numpy as np -from pydantic import AfterValidator, Field, model_validator -from typing_extensions import Self +from pydantic import AfterValidator import albumentations.augmentations.dropout.functional as fdropout from albumentations.augmentations.dropout.transforms import BaseDropout from albumentations.core.bbox_utils import denormalize_bboxes from albumentations.core.pydantic import check_range_bounds, nondecreasing -from albumentations.core.types import ColorType, DropoutFillValue, Number, ScalarType +from albumentations.core.types import ColorType, DropoutFillValue, ScalarType __all__ = ["CoarseDropout", "ConstrainedCoarseDropout", "Erasing"] @@ -92,91 +91,29 @@ class CoarseDropout(BaseDropout): """ class InitSchema(BaseDropout.InitSchema): - min_holes: int | None = Field(ge=0) - max_holes: int | None = Field(ge=0) num_holes_range: Annotated[ tuple[int, int], AfterValidator(check_range_bounds(1, None)), AfterValidator(nondecreasing), ] - min_height: ScalarType | None = Field(ge=0) - max_height: ScalarType | None = Field(ge=0) hole_height_range: Annotated[ tuple[ScalarType, ScalarType], AfterValidator(nondecreasing), AfterValidator(check_range_bounds(0, None)), ] - min_width: ScalarType | None = Field(ge=0) - max_width: ScalarType | None = Field(ge=0) hole_width_range: Annotated[ tuple[ScalarType, ScalarType], AfterValidator(nondecreasing), AfterValidator(check_range_bounds(0, None)), ] - @staticmethod - def update_range( - min_value: Number | None, - max_value: Number | None, - default_range: tuple[Number, Number], - ) -> tuple[Number, Number]: - return (min_value or max_value, max_value) if max_value is not None else default_range - - @staticmethod - def validate_range(range_value: tuple[float, float], range_name: str, minimum: float = 0) -> None: - if not minimum <= range_value[0] <= range_value[1]: - raise ValueError( - f"First value in {range_name} should be less or equal than the second value " - f"and at least {minimum}. Got: {range_value}", - ) - if isinstance(range_value[0], float) and not all(0 <= x <= 1 for x in range_value): - raise ValueError(f"All values in {range_name} should be in [0, 1] range. Got: {range_value}") - - @model_validator(mode="after") - def check_num_holes_and_dimensions(self) -> Self: - if self.min_holes is not None: - warn("`min_holes` is deprecated. Use num_holes_range instead.", DeprecationWarning, stacklevel=2) - if self.max_holes is not None: - warn("`max_holes` is deprecated. Use num_holes_range instead.", DeprecationWarning, stacklevel=2) - if self.min_height is not None: - warn("`min_height` is deprecated. Use hole_height_range instead.", DeprecationWarning, stacklevel=2) - if self.max_height is not None: - warn("`max_height` is deprecated. Use hole_height_range instead.", DeprecationWarning, stacklevel=2) - if self.min_width is not None: - warn("`min_width` is deprecated. Use hole_width_range instead.", DeprecationWarning, stacklevel=2) - if self.max_width is not None: - warn("`max_width` is deprecated. Use hole_width_range instead.", DeprecationWarning, stacklevel=2) - - if self.max_holes is not None: - self.num_holes_range = self.update_range(self.min_holes, self.max_holes, self.num_holes_range) - - self.validate_range(self.num_holes_range, "num_holes_range", minimum=1) - - if self.max_height is not None: - self.hole_height_range = self.update_range(self.min_height, self.max_height, self.hole_height_range) - self.validate_range(self.hole_height_range, "hole_height_range") - - if self.max_width is not None: - self.hole_width_range = self.update_range(self.min_width, self.max_width, self.hole_width_range) - self.validate_range(self.hole_width_range, "hole_width_range") - - return self - def __init__( self, - max_holes: int | None = None, - max_height: ScalarType | None = None, - max_width: ScalarType | None = None, - min_holes: int | None = None, - min_height: ScalarType | None = None, - min_width: ScalarType | None = None, - fill_value: DropoutFillValue | None = None, - mask_fill_value: ColorType | None = None, - num_holes_range: tuple[int, int] = (1, 1), - hole_height_range: tuple[ScalarType, ScalarType] = (8, 8), - hole_width_range: tuple[ScalarType, ScalarType] = (8, 8), + num_holes_range: tuple[int, int] = (1, 2), + hole_height_range: tuple[ScalarType, ScalarType] = (0.1, 0.2), + hole_width_range: tuple[ScalarType, ScalarType] = (0.1, 0.2), fill: DropoutFillValue = 0, fill_mask: ColorType | None = None, p: float = 0.5, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6cd5c73a5..760e317d8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1223,52 +1223,6 @@ def test_image_compression_invalid_input(params): A.ImageCompression(**params) -@pytest.mark.parametrize( - "params, expected", - [ - # Default values - ( - {}, - { - "num_holes_range": (1, 1), - "hole_height_range": (8, 8), - "hole_width_range": (8, 8), - }, - ), - # Boundary values - ({"num_holes_range": (2, 3)}, {"num_holes_range": (2, 3)}), - ({"hole_height_range": (1, 12)}, {"hole_height_range": (1, 12)}), - ({"hole_width_range": (1, 12)}, {"hole_width_range": (1, 12)}), - # Random fill value - ({"fill": "random"}, {"fill": "random"}), - ({"fill": (255, 255, 255)}, {"fill": (255, 255, 255)}), - # Deprecated values handling - ({"min_holes": 1, "max_holes": 5}, {"num_holes_range": (1, 5)}), - ({"min_height": 2, "max_height": 6}, {"hole_height_range": (2, 6)}), - ({"min_width": 3, "max_width": 7}, {"hole_width_range": (3, 7)}), - ], -) -def test_coarse_dropout_functionality(params, expected): - aug = A.CoarseDropout(**params, p=1) - aug_dict = aug.to_dict()["transform"] - for key, value in expected.items(): - assert aug_dict[key] == value, f"Failed on {key} with value {value}" - - -@pytest.mark.parametrize( - "params", - [ - ({"num_holes_range": (5, 1)}), # Invalid range - ({"num_holes_range": (0, 3)}), # Invalid range - ({"hole_height_range": (2.1, 3)}), # Invalid type - ({"hole_height_range": ("a", "b")}), # Invalid type - ], -) -def test_coarse_dropout_invalid_input(params): - with pytest.raises(Exception): - _ = A.CoarseDropout(**params, p=1) - - @pytest.mark.parametrize( ["augmentation_cls", "params"], get_2d_transforms(