Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated GaussNoise and deleted traget_as_params #2268

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 1 addition & 47 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collections.abc import Sequence
from types import LambdaType
from typing import Annotated, Any, Callable, Union, cast
from warnings import warn

import albucore
import cv2
Expand Down Expand Up @@ -2219,9 +2218,6 @@ class GaussNoise(ImageOnlyTransform):
- The noise parameters (std_range and mean_range) are normalized to [0, 1] range:
* For uint8 images, they are multiplied by 255
* For float32 images, they are used directly
- The behavior differs between old and new parameters:
* When using var_limit (deprecated): samples variance uniformly and takes sqrt to get std dev
* When using std_range: samples standard deviation directly (aligned with torchvision/kornia)
- Setting per_channel=False is faster but applies the same noise to all channels
- The noise_scale_factor parameter allows for a trade-off between transform speed and noise granularity

Expand All @@ -2233,15 +2229,9 @@ class GaussNoise(ImageOnlyTransform):
>>> # Apply Gaussian noise with normalized std_range
>>> transform = A.GaussNoise(std_range=(0.1, 0.2), p=1.0) # 10-20% of max value
>>> noisy_image = transform(image=image)['image']
>>>
>>> # Using deprecated var_limit (will be converted to std_range)
>>> transform = A.GaussNoise(var_limit=(50.0, 100.0), mean=10, p=1.0)
>>> noisy_image = transform(image=image)['image']
"""

class InitSchema(BaseTransformInitSchema):
var_limit: ScaleFloatType | None
mean: float | None
std_range: Annotated[
tuple[float, float],
AfterValidator(check_range_bounds(0, 1)),
Expand All @@ -2255,36 +2245,8 @@ class InitSchema(BaseTransformInitSchema):
per_channel: bool
noise_scale_factor: float = Field(gt=0, le=1)

@model_validator(mode="after")
def check_range(self) -> Self:
if self.var_limit is not None:
warnings.warn("`var_limit` deprecated. Use `std_range` instead.", DeprecationWarning, stacklevel=2)
self.var_limit = to_tuple(self.var_limit, 0)
if self.var_limit[1] > 1:
# Convert legacy uint8 variance to normalized std dev
self.std_range = (math.sqrt(10 / 255), math.sqrt(50 / 255))
else:
# Already normalized variance, convert to std dev
self.std_range = (
math.sqrt(self.var_limit[0]),
math.sqrt(self.var_limit[1]),
)

if self.mean is not None:
warn("`mean` deprecated. Use `mean_range` instead.", DeprecationWarning, stacklevel=2)
if self.mean >= 1:
# Convert legacy uint8 mean to normalized range
self.mean_range = (self.mean / 255, self.mean / 255)
else:
# Already normalized mean
self.mean_range = (self.mean, self.mean)

return self

def __init__(
self,
var_limit: ScaleFloatType | None = None,
mean: float | None = None,
std_range: tuple[float, float] = (0.2, 0.44), # sqrt(10 / 255), sqrt(50 / 255)
mean_range: tuple[float, float] = (0.0, 0.0),
per_channel: bool = True,
Expand All @@ -2297,8 +2259,6 @@ def __init__(
self.per_channel = per_channel
self.noise_scale_factor = noise_scale_factor

self.var_limit = var_limit

def apply(
self,
img: np.ndarray,
Expand All @@ -2315,13 +2275,7 @@ def get_params_dependent_on_data(
image = data["image"] if "image" in data else data["images"][0]
max_value = MAX_VALUES_BY_DTYPE[image.dtype]

if self.var_limit is not None:
# Legacy behavior: sample variance uniformly then take sqrt
var = self.py_random.uniform(self.std_range[0] ** 2, self.std_range[1] ** 2)
sigma = math.sqrt(var)
else:
# New behavior: sample std dev directly (aligned with torchvision/kornia)
sigma = self.py_random.uniform(*self.std_range)
sigma = self.py_random.uniform(*self.std_range)

mean = self.py_random.uniform(*self.mean_range)

Expand Down
15 changes: 0 additions & 15 deletions albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,6 @@ def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
params_dependent_on_data = self.get_params_dependent_on_data(params=params, data=kwargs)
params.update(params_dependent_on_data)

if self.targets_as_params: # this block will be removed after removing `get_params_dependent_on_targets`
targets_as_params = {k: kwargs.get(k) for k in self.targets_as_params}
if missing_keys: # here we expecting case when missing_keys == {"image"} and "images" in kwargs
targets_as_params["image"] = kwargs["images"][0]
params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
params.update(params_dependent_on_targets)

# Store the final params
self.params = params

Expand Down Expand Up @@ -337,14 +330,6 @@ def targets_as_params(self) -> list[str]:
"""
return []

def get_params_dependent_on_targets(self, params: dict[str, Any]) -> dict[str, Any]:
"""This method is deprecated.
Use `get_params_dependent_on_data` instead.
Returns parameters dependent on targets.
Dependent target is defined in `self.targets_as_params`
"""
return {}

@classmethod
def get_class_fullname(cls) -> str:
return get_shortest_class_fullname(cls)
Expand Down
Loading