Skip to content

Commit

Permalink
Updated RandomShadow (#2261)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus authored Jan 8, 2025
1 parent b05b828 commit 062665a
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 97 deletions.
47 changes: 0 additions & 47 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,8 +1378,6 @@ class InitSchema(BaseTransformInitSchema):
AfterValidator(check_range_bounds(1, None)),
AfterValidator(nondecreasing),
]
num_shadows_lower: int | None
num_shadows_upper: int | None
shadow_dimension: int = Field(ge=3)

shadow_intensity_range: Annotated[
Expand All @@ -1390,62 +1388,17 @@ class InitSchema(BaseTransformInitSchema):

@model_validator(mode="after")
def validate_shadows(self) -> Self:
if self.num_shadows_lower is not None:
warn(
"`num_shadows_lower` is deprecated. Use `num_shadows_limit` instead.",
DeprecationWarning,
stacklevel=2,
)

if self.num_shadows_upper is not None:
warn(
"`num_shadows_upper` is deprecated. Use `num_shadows_limit` instead.",
DeprecationWarning,
stacklevel=2,
)

if self.num_shadows_lower is not None or self.num_shadows_upper is not None:
num_shadows_lower = (
self.num_shadows_lower if self.num_shadows_lower is not None else self.num_shadows_limit[0]
)
num_shadows_upper = (
self.num_shadows_upper if self.num_shadows_upper is not None else self.num_shadows_limit[1]
)

self.num_shadows_limit = (num_shadows_lower, num_shadows_upper)
self.num_shadows_lower = None
self.num_shadows_upper = None

shadow_lower_x, shadow_lower_y, shadow_upper_x, shadow_upper_y = self.shadow_roi

if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
raise ValueError(f"Invalid shadow_roi. Got: {self.shadow_roi}")

if isinstance(self.shadow_intensity_range, float):
if not (0 <= self.shadow_intensity_range <= 1):
raise ValueError(
f"shadow_intensity_range value should be within [0, 1] range. "
f"Got: {self.shadow_intensity_range}",
)
elif isinstance(self.shadow_intensity_range, tuple):
if not (0 <= self.shadow_intensity_range[0] <= self.shadow_intensity_range[1] <= 1):
raise ValueError(
f"shadow_intensity_range values should be within [0, 1] range and increasing. "
f"Got: {self.shadow_intensity_range}",
)
else:
raise TypeError(
"shadow_intensity_range should be an float or a tuple of floats.",
)

return self

def __init__(
self,
shadow_roi: tuple[float, float, float, float] = (0, 0.5, 1, 1),
num_shadows_limit: tuple[int, int] = (1, 2),
num_shadows_lower: int | None = None,
num_shadows_upper: int | None = None,
shadow_dimension: int = 5,
shadow_intensity_range: tuple[float, float] = (0.5, 0.5),
p: float = 0.5,
Expand Down
50 changes: 0 additions & 50 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,56 +1085,6 @@ def test_non_rgb_transform_warning(augmentation, img_channels):
assert str(exc_info.value).startswith(message)


@pytest.mark.parametrize(
"num_shadows_limit, num_shadows_lower, num_shadows_upper, expected_warning",
[
((1, 1), None, None, None),
((1, 2), None, None, None),
((2, 3), None, None, None),
((1, 2), 1, None, DeprecationWarning),
((1, 2), None, 2, DeprecationWarning),
((1, 2), 1, 2, DeprecationWarning),
((2, 1), None, None, ValueError),
],
)
def test_deprecation_warnings_random_shadow(
num_shadows_limit: tuple[int, int],
num_shadows_lower: int | None,
num_shadows_upper: int | None,
expected_warning: Type[Warning] | None,
) -> None:
"""Test deprecation warnings for RandomShadow"""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Change the filter to capture all warnings
if expected_warning == ValueError:
with pytest.raises(ValueError):
A.RandomShadow(
num_shadows_limit=num_shadows_limit,
num_shadows_lower=num_shadows_lower,
num_shadows_upper=num_shadows_upper,
p=1,
)
elif expected_warning is DeprecationWarning:
A.RandomShadow(
num_shadows_limit=num_shadows_limit,
num_shadows_lower=num_shadows_lower,
num_shadows_upper=num_shadows_upper,
p=1,
)
for warning in w:
print(
f"Warning captured: {warning.category.__name__}, Message: '{warning.message}'"
)

if warning.category is DeprecationWarning:
print(f"Deprecation Warning: {warning.message}")
assert any(
issubclass(warning.category, DeprecationWarning) for warning in w
), "No DeprecationWarning found"
else:
assert not w, "Unexpected warnings raised"


@pytest.mark.parametrize("image", IMAGES)
@pytest.mark.parametrize(
"grid",
Expand Down

0 comments on commit 062665a

Please sign in to comment.