Skip to content

Commit

Permalink
Updated CoarseDropout (#2241)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus authored Jan 7, 2025
1 parent 62c5499 commit 327a6b2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 114 deletions.
73 changes: 5 additions & 68 deletions albumentations/augmentations/dropout/coarse_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from warnings import warn

import numpy as np
from pydantic import AfterValidator, Field, model_validator
from typing_extensions import Self
from pydantic import AfterValidator

import albumentations.augmentations.dropout.functional as fdropout
from albumentations.augmentations.dropout.transforms import BaseDropout
from albumentations.core.bbox_utils import denormalize_bboxes
from albumentations.core.pydantic import check_range_bounds, nondecreasing
from albumentations.core.types import ColorType, DropoutFillValue, Number, ScalarType
from albumentations.core.types import ColorType, DropoutFillValue, ScalarType

__all__ = ["CoarseDropout", "ConstrainedCoarseDropout", "Erasing"]

Expand Down Expand Up @@ -92,91 +91,29 @@ class CoarseDropout(BaseDropout):
"""

class InitSchema(BaseDropout.InitSchema):
min_holes: int | None = Field(ge=0)
max_holes: int | None = Field(ge=0)
num_holes_range: Annotated[
tuple[int, int],
AfterValidator(check_range_bounds(1, None)),
AfterValidator(nondecreasing),
]

min_height: ScalarType | None = Field(ge=0)
max_height: ScalarType | None = Field(ge=0)
hole_height_range: Annotated[
tuple[ScalarType, ScalarType],
AfterValidator(nondecreasing),
AfterValidator(check_range_bounds(0, None)),
]

min_width: ScalarType | None = Field(ge=0)
max_width: ScalarType | None = Field(ge=0)
hole_width_range: Annotated[
tuple[ScalarType, ScalarType],
AfterValidator(nondecreasing),
AfterValidator(check_range_bounds(0, None)),
]

@staticmethod
def update_range(
min_value: Number | None,
max_value: Number | None,
default_range: tuple[Number, Number],
) -> tuple[Number, Number]:
return (min_value or max_value, max_value) if max_value is not None else default_range

@staticmethod
def validate_range(range_value: tuple[float, float], range_name: str, minimum: float = 0) -> None:
if not minimum <= range_value[0] <= range_value[1]:
raise ValueError(
f"First value in {range_name} should be less or equal than the second value "
f"and at least {minimum}. Got: {range_value}",
)
if isinstance(range_value[0], float) and not all(0 <= x <= 1 for x in range_value):
raise ValueError(f"All values in {range_name} should be in [0, 1] range. Got: {range_value}")

@model_validator(mode="after")
def check_num_holes_and_dimensions(self) -> Self:
if self.min_holes is not None:
warn("`min_holes` is deprecated. Use num_holes_range instead.", DeprecationWarning, stacklevel=2)
if self.max_holes is not None:
warn("`max_holes` is deprecated. Use num_holes_range instead.", DeprecationWarning, stacklevel=2)
if self.min_height is not None:
warn("`min_height` is deprecated. Use hole_height_range instead.", DeprecationWarning, stacklevel=2)
if self.max_height is not None:
warn("`max_height` is deprecated. Use hole_height_range instead.", DeprecationWarning, stacklevel=2)
if self.min_width is not None:
warn("`min_width` is deprecated. Use hole_width_range instead.", DeprecationWarning, stacklevel=2)
if self.max_width is not None:
warn("`max_width` is deprecated. Use hole_width_range instead.", DeprecationWarning, stacklevel=2)

if self.max_holes is not None:
self.num_holes_range = self.update_range(self.min_holes, self.max_holes, self.num_holes_range)

self.validate_range(self.num_holes_range, "num_holes_range", minimum=1)

if self.max_height is not None:
self.hole_height_range = self.update_range(self.min_height, self.max_height, self.hole_height_range)
self.validate_range(self.hole_height_range, "hole_height_range")

if self.max_width is not None:
self.hole_width_range = self.update_range(self.min_width, self.max_width, self.hole_width_range)
self.validate_range(self.hole_width_range, "hole_width_range")

return self

def __init__(
self,
max_holes: int | None = None,
max_height: ScalarType | None = None,
max_width: ScalarType | None = None,
min_holes: int | None = None,
min_height: ScalarType | None = None,
min_width: ScalarType | None = None,
fill_value: DropoutFillValue | None = None,
mask_fill_value: ColorType | None = None,
num_holes_range: tuple[int, int] = (1, 1),
hole_height_range: tuple[ScalarType, ScalarType] = (8, 8),
hole_width_range: tuple[ScalarType, ScalarType] = (8, 8),
num_holes_range: tuple[int, int] = (1, 2),
hole_height_range: tuple[ScalarType, ScalarType] = (0.1, 0.2),
hole_width_range: tuple[ScalarType, ScalarType] = (0.1, 0.2),
fill: DropoutFillValue = 0,
fill_mask: ColorType | None = None,
p: float = 0.5,
Expand Down
46 changes: 0 additions & 46 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,52 +1223,6 @@ def test_image_compression_invalid_input(params):
A.ImageCompression(**params)


@pytest.mark.parametrize(
"params, expected",
[
# Default values
(
{},
{
"num_holes_range": (1, 1),
"hole_height_range": (8, 8),
"hole_width_range": (8, 8),
},
),
# Boundary values
({"num_holes_range": (2, 3)}, {"num_holes_range": (2, 3)}),
({"hole_height_range": (1, 12)}, {"hole_height_range": (1, 12)}),
({"hole_width_range": (1, 12)}, {"hole_width_range": (1, 12)}),
# Random fill value
({"fill": "random"}, {"fill": "random"}),
({"fill": (255, 255, 255)}, {"fill": (255, 255, 255)}),
# Deprecated values handling
({"min_holes": 1, "max_holes": 5}, {"num_holes_range": (1, 5)}),
({"min_height": 2, "max_height": 6}, {"hole_height_range": (2, 6)}),
({"min_width": 3, "max_width": 7}, {"hole_width_range": (3, 7)}),
],
)
def test_coarse_dropout_functionality(params, expected):
aug = A.CoarseDropout(**params, p=1)
aug_dict = aug.to_dict()["transform"]
for key, value in expected.items():
assert aug_dict[key] == value, f"Failed on {key} with value {value}"


@pytest.mark.parametrize(
"params",
[
({"num_holes_range": (5, 1)}), # Invalid range
({"num_holes_range": (0, 3)}), # Invalid range
({"hole_height_range": (2.1, 3)}), # Invalid type
({"hole_height_range": ("a", "b")}), # Invalid type
],
)
def test_coarse_dropout_invalid_input(params):
with pytest.raises(Exception):
_ = A.CoarseDropout(**params, p=1)


@pytest.mark.parametrize(
["augmentation_cls", "params"],
get_2d_transforms(
Expand Down

0 comments on commit 327a6b2

Please sign in to comment.