Skip to content

Commit

Permalink
More cleanup (#2265)
Browse files Browse the repository at this point in the history
* More cleanup

* More cleanup
  • Loading branch information
ternaus authored Jan 8, 2025
1 parent 033f08f commit 24e5386
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 1,026 deletions.
45 changes: 12 additions & 33 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
denormalize_bboxes,
normalize_bboxes,
)
from albumentations.core.keypoints_utils import KeypointsProcessor
from albumentations.core.pydantic import (
InterpolationType,
NonNegativeFloatRangeType,
Expand Down Expand Up @@ -3208,9 +3207,7 @@ def __init__(
super().__init__(p=p)

self.name = name
self.custom_apply_fns = {
target_name: fmain.noop for target_name in ("image", "mask", "keypoints", "bboxes", "global_label")
}
self.custom_apply_fns = {target_name: fmain.noop for target_name in ("image", "mask", "keypoints", "bboxes")}
for target_name, custom_apply_fn in {
"image": image,
"mask": mask,
Expand Down Expand Up @@ -4165,18 +4162,11 @@ class RingingOvershoot(ImageOnlyTransform):

class InitSchema(BlurInitSchema):
blur_limit: ScaleIntType
cutoff: Annotated[tuple[float, float], nondecreasing]

@field_validator("cutoff")
@classmethod
def check_cutoff(
cls,
v: tuple[float, float],
info: ValidationInfo,
) -> tuple[float, float]:
bounds = 0, np.pi
check_range(v, *bounds, info.field_name)
return v
cutoff: Annotated[
tuple[float, float],
AfterValidator(check_range_bounds(0, np.pi)),
AfterValidator(nondecreasing),
]

def __init__(
self,
Expand Down Expand Up @@ -4482,15 +4472,7 @@ def apply_to_keypoints(
drop_mask: np.ndarray | None,
**params: Any,
) -> np.ndarray:
if drop_mask is None or self.per_channel:
return keypoints

processor = cast(KeypointsProcessor, self.get_processor("keypoints"))

if processor is None or not processor.params.remove_invisible:
return keypoints

return fdropout.mask_dropout_keypoints(keypoints, drop_mask)
return keypoints

def get_params_dependent_on_data(
self,
Expand Down Expand Up @@ -4600,11 +4582,11 @@ class Spatter(ImageOnlyTransform):
"""

class InitSchema(BaseTransformInitSchema):
mean: ZeroOneRangeType = (0.65, 0.65)
std: ZeroOneRangeType = (0.3, 0.3)
gauss_sigma: NonNegativeFloatRangeType = (2, 2)
cutout_threshold: ZeroOneRangeType = (0.68, 0.68)
intensity: ZeroOneRangeType = (0.6, 0.6)
mean: ZeroOneRangeType
std: ZeroOneRangeType
gauss_sigma: NonNegativeFloatRangeType
cutout_threshold: ZeroOneRangeType
intensity: ZeroOneRangeType
mode: SpatterMode | Sequence[SpatterMode]
color: Sequence[int] | dict[str, Sequence[int]] | None = None

Expand Down Expand Up @@ -6427,9 +6409,6 @@ class AutoContrast(ImageOnlyTransform):
uint8, float32
"""

class InitSchema(BaseTransformInitSchema):
pass

def __init__(
self,
p: float = 0.5,
Expand Down
19 changes: 10 additions & 9 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@
[
A.RandomRain,
{
"slant_lower": -5,
"slant_upper": 5,
"slant_range": (-5, 5),
"drop_length": 15,
"drop_width": 2,
"drop_color": (100, 100, 100),
Expand All @@ -57,7 +56,7 @@
"rain_type": "heavy",
},
],
[A.RandomFog, {"fog_coef_lower": 0.2, "fog_coef_upper": 0.8, "alpha_coef": 0.11}],
[A.RandomFog, {"fog_coef_range": (0.2, 0.8), "alpha_coef": 0.11}],
[
A.RandomSunFlare,
{
Expand Down Expand Up @@ -174,7 +173,7 @@
[A.SmallestMaxSize, {"max_size": 64, "interpolation": cv2.INTER_NEAREST}],
[A.LongestMaxSize, {"max_size": 128, "interpolation": cv2.INTER_NEAREST}],
[A.RandomGridShuffle, {"grid": (4, 4)}],
[A.Solarize, {"threshold": 32}],
[A.Solarize, {"threshold_range": [0.5, 0.5]}],
[A.Posterize, {"num_bits": (3, 5)}],
[A.Equalize, {"mode": "pil", "by_channels": False}],
[
Expand Down Expand Up @@ -264,12 +263,14 @@
"interpolation": cv2.INTER_AREA,
"mask_interpolation": cv2.INTER_NEAREST,
"absolute_scale": True,
"keypoints_threshold": 0.1,
},
],
[A.ChannelDropout, dict(channel_drop_range=(1, 2), fill=1)],
[A.ChannelShuffle, {}],
[A.Downscale, dict(scale_min=0.5, scale_max=0.75, interpolation=cv2.INTER_LINEAR)],
[A.Downscale, dict(scale_range=[0.5, 0.75], interpolation_pair={
"downscale": cv2.INTER_LINEAR,
"upscale": cv2.INTER_LINEAR,
})],
[A.FromFloat, dict(dtype="uint8", max_value=1)],
[A.HorizontalFlip, {}],
[A.ISONoise, dict(color_shift=(0.2, 0.3), intensity=(0.7, 0.9))],
Expand Down Expand Up @@ -354,7 +355,7 @@
A.GridDropout,
dict(
ratio=0.75,
unit_size_range=(2, 10),
holes_number_xy=(2, 10),
shift_xy=(10, 20),
random_offset=True,
fill=10,
Expand All @@ -370,8 +371,8 @@
A.TextImage,
dict(
font_path="./tests/files/LiberationSerif-Bold.ttf",
font_size_range=(0.8, 0.9),
color="red",
font_size_fraction_range=(0.8, 0.9),
font_color="red",
stopwords=[
"a",
"the",
Expand Down
55 changes: 21 additions & 34 deletions tests/files/transform_serialization_v2_with_totensor.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,27 @@
{
"__class_fullname__": "CoarseDropout",
"p": 0.5,
"max_holes": 8,
"max_height": 8,
"max_width": 8,
"min_holes": 8,
"min_height": 8,
"min_width": 8,
"num_holes_range": [8, 8],
"hole_height_range": [0.1, 0.3],
"hole_width_range": [0.1, 0.3],
"fill": 0,
"fill_mask": null
},
{
"__class_fullname__": "Downscale",
"p": 0.5,
"scale_min": 0.25,
"scale_max": 0.25,
"interpolation": 0
"scale_range": [0.25, 0.25],
"interpolation_pair": {
"downscale": 0,
"upscale": 0
}
},
{
"__class_fullname__": "ElasticTransform",
"p": 0.5,
"alpha": 1,
"sigma": 50,
"interpolation": 1,
"border_mode": 4,
"fill": null,
"fill_mask": null,
"approximate": false
},
{
Expand All @@ -81,9 +77,9 @@
{
"__class_fullname__": "GaussNoise",
"p": 0.5,
"var_limit": [
10.0,
50.0
"std_range": [
0.2,
0.44
]
},
{
Expand Down Expand Up @@ -113,10 +109,7 @@
-0.3,
0.3
],
"interpolation": 1,
"border_mode": 4,
"fill": null,
"fill_mask": null
"interpolation": 1
},
{
"__class_fullname__": "HorizontalFlip",
Expand Down Expand Up @@ -153,8 +146,7 @@
{
"__class_fullname__": "ImageCompression",
"p": 0.5,
"quality_lower": 99,
"quality_upper": 100,
"quality_range": [99, 100],
"compression_type": "jpeg"
},
{
Expand Down Expand Up @@ -258,8 +250,7 @@
{
"__class_fullname__": "RandomFog",
"p": 0.5,
"fog_coef_lower": 0.3,
"fog_coef_upper": 1,
"fog_coef_range": [0.3, 1],
"alpha_coef": 0.08
},
{
Expand All @@ -281,8 +272,7 @@
{
"__class_fullname__": "RandomRain",
"p": 0.5,
"slant_lower": -10,
"slant_upper": 10,
"slant_range": [-10, 10],
"drop_length": 20,
"drop_width": 1,
"drop_color": [
Expand Down Expand Up @@ -325,8 +315,7 @@
{
"__class_fullname__": "RandomSnow",
"p": 0.5,
"snow_point_lower": 0.1,
"snow_point_upper": 0.3,
"snow_point_range": [0.1, 0.3],
"brightness_coeff": 2.5
},
{
Expand All @@ -338,10 +327,8 @@
1,
0.5
],
"angle_lower": 0,
"angle_upper": 1,
"num_flare_circles_lower": 6,
"num_flare_circles_upper": 10,
"angle_range": [0, 1],
"num_flare_circles_range": [6, 10],
"src_radius": 400,
"src_color": [
255,
Expand Down Expand Up @@ -394,9 +381,9 @@
{
"__class_fullname__": "Solarize",
"p": 0.5,
"threshold": [
128,
128
"threshold_range": [
0.5,
0.5
]
},
{
Expand Down
Loading

0 comments on commit 24e5386

Please sign in to comment.