Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated RandomShadow #2261

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,8 +1378,6 @@ class InitSchema(BaseTransformInitSchema):
AfterValidator(check_range_bounds(1, None)),
AfterValidator(nondecreasing),
]
num_shadows_lower: int | None
num_shadows_upper: int | None
shadow_dimension: int = Field(ge=3)

shadow_intensity_range: Annotated[
Expand All @@ -1390,62 +1388,17 @@ class InitSchema(BaseTransformInitSchema):

@model_validator(mode="after")
def validate_shadows(self) -> Self:
if self.num_shadows_lower is not None:
warn(
"`num_shadows_lower` is deprecated. Use `num_shadows_limit` instead.",
DeprecationWarning,
stacklevel=2,
)

if self.num_shadows_upper is not None:
warn(
"`num_shadows_upper` is deprecated. Use `num_shadows_limit` instead.",
DeprecationWarning,
stacklevel=2,
)

if self.num_shadows_lower is not None or self.num_shadows_upper is not None:
num_shadows_lower = (
self.num_shadows_lower if self.num_shadows_lower is not None else self.num_shadows_limit[0]
)
num_shadows_upper = (
self.num_shadows_upper if self.num_shadows_upper is not None else self.num_shadows_limit[1]
)

self.num_shadows_limit = (num_shadows_lower, num_shadows_upper)
self.num_shadows_lower = None
self.num_shadows_upper = None

shadow_lower_x, shadow_lower_y, shadow_upper_x, shadow_upper_y = self.shadow_roi

if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
raise ValueError(f"Invalid shadow_roi. Got: {self.shadow_roi}")

if isinstance(self.shadow_intensity_range, float):
if not (0 <= self.shadow_intensity_range <= 1):
raise ValueError(
f"shadow_intensity_range value should be within [0, 1] range. "
f"Got: {self.shadow_intensity_range}",
)
elif isinstance(self.shadow_intensity_range, tuple):
if not (0 <= self.shadow_intensity_range[0] <= self.shadow_intensity_range[1] <= 1):
raise ValueError(
f"shadow_intensity_range values should be within [0, 1] range and increasing. "
f"Got: {self.shadow_intensity_range}",
)
else:
raise TypeError(
"shadow_intensity_range should be an float or a tuple of floats.",
)

return self

def __init__(
self,
shadow_roi: tuple[float, float, float, float] = (0, 0.5, 1, 1),
num_shadows_limit: tuple[int, int] = (1, 2),
num_shadows_lower: int | None = None,
num_shadows_upper: int | None = None,
shadow_dimension: int = 5,
shadow_intensity_range: tuple[float, float] = (0.5, 0.5),
p: float = 0.5,
Expand Down
50 changes: 0 additions & 50 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,56 +1085,6 @@ def test_non_rgb_transform_warning(augmentation, img_channels):
assert str(exc_info.value).startswith(message)


@pytest.mark.parametrize(
"num_shadows_limit, num_shadows_lower, num_shadows_upper, expected_warning",
[
((1, 1), None, None, None),
((1, 2), None, None, None),
((2, 3), None, None, None),
((1, 2), 1, None, DeprecationWarning),
((1, 2), None, 2, DeprecationWarning),
((1, 2), 1, 2, DeprecationWarning),
((2, 1), None, None, ValueError),
],
)
def test_deprecation_warnings_random_shadow(
num_shadows_limit: tuple[int, int],
num_shadows_lower: int | None,
num_shadows_upper: int | None,
expected_warning: Type[Warning] | None,
) -> None:
"""Test deprecation warnings for RandomShadow"""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Change the filter to capture all warnings
if expected_warning == ValueError:
with pytest.raises(ValueError):
A.RandomShadow(
num_shadows_limit=num_shadows_limit,
num_shadows_lower=num_shadows_lower,
num_shadows_upper=num_shadows_upper,
p=1,
)
elif expected_warning is DeprecationWarning:
A.RandomShadow(
num_shadows_limit=num_shadows_limit,
num_shadows_lower=num_shadows_lower,
num_shadows_upper=num_shadows_upper,
p=1,
)
for warning in w:
print(
f"Warning captured: {warning.category.__name__}, Message: '{warning.message}'"
)

if warning.category is DeprecationWarning:
print(f"Deprecation Warning: {warning.message}")
assert any(
issubclass(warning.category, DeprecationWarning) for warning in w
), "No DeprecationWarning found"
else:
assert not w, "Unexpected warnings raised"


@pytest.mark.parametrize("image", IMAGES)
@pytest.mark.parametrize(
"grid",
Expand Down
Loading