Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding A.AtLeastOneBBoxRandomCrop #2207

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| Transform | Image | Mask | BBoxes | Keypoints | Volume | Mask3D |
| ------------------------------------------------------------------------------------------------ | :---: | :--: | :----: | :-------: | :----: | :----: |
| [Affine](https://explore.albumentations.ai/transform/Affine) | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
| [AtLeastOneBBoxRandomCrop](https://explore.albumentations.ai/transform/AtLeastOneBBoxRandomCrop) | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
| [BBoxSafeRandomCrop](https://explore.albumentations.ai/transform/BBoxSafeRandomCrop) | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
| [CenterCrop](https://explore.albumentations.ai/transform/CenterCrop) | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
| [CoarseDropout](https://explore.albumentations.ai/transform/CoarseDropout) | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
Expand Down
111 changes: 111 additions & 0 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from . import functional as fcrops

__all__ = [
"AtLeastOneBBoxRandomCrop",
"BBoxSafeRandomCrop",
"CenterCrop",
"Crop",
Expand Down Expand Up @@ -1989,3 +1990,113 @@ def get_params_dependent_on_data(

def get_transform_init_args_names(self) -> tuple[str, ...]:
    """Return the four crop-border argument names in left, right, top, bottom order."""
    init_arg_names = ("crop_left", "crop_right", "crop_top", "crop_bottom")
    return init_arg_names


class AtLeastOneBBoxRandomCrop(BaseCrop):
    """Crop an image to a fixed resolution while keeping at least one bounding box in the crop.

    One bounding box is picked at random as the reference box, and the crop position
    is then drawn so that the crop overlaps that box. The maximal erosion factor
    defines by how much the reference bounding box can be thinned out by the crop.
    For example, erosion_factor = 0.2 means that the bounding box dimensions can be
    thinned by up to 20%.

    If no bounding boxes are supplied, the crop is placed uniformly at random
    anywhere inside the image.

    Args:
        height: Height of the crop. Must be at least 1.
        width: Width of the crop. Must be at least 1.
        erosion_factor: Maximal erosion factor of the height and width of the
            reference bounding box, in the range [0.0, 1.0]. Default: 0.0.
        p: The probability of applying the transform. Default: 1.0.
        always_apply: Whether to apply the transform systematically. Default: None.

    Raises:
        CropSizeError: If the requested crop is taller or wider than the image.

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Image types:
        uint8, float32
    """

    _targets = ALL_TARGETS

    class InitSchema(BaseCrop.InitSchema):
        height: Annotated[int, Field(ge=1)]
        width: Annotated[int, Field(ge=1)]
        erosion_factor: Annotated[float, Field(ge=0.0, le=1.0)]

    def __init__(
        self,
        height: int,
        width: int,
        erosion_factor: float = 0.0,
        p: float = 1.0,
        always_apply: bool | None = None,
    ):
        super().__init__(p=p, always_apply=always_apply)
        self.height = height
        self.width = width
        self.erosion_factor = erosion_factor

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, tuple[int, int, int, int]]:
        """Draw the crop rectangle for the current sample.

        Args:
            params: Must contain ``"shape"``; its first two entries are read as
                the image height and width.
            data: Sample data; the optional ``"bboxes"`` entry supplies the
                candidate reference boxes.

        Returns:
            A dict with a single ``"crop_coords"`` entry holding
            ``(x_min, y_min, x_max, y_max)`` as integers.

        Raises:
            CropSizeError: If the crop size exceeds the image size.
        """
        image_height, image_width = params["shape"][:2]
        bboxes = data.get("bboxes", [])

        if self.height > image_height or self.width > image_width:
            raise CropSizeError(
                "Crop size (height, width) exceeds image dimensions (height, width):"
                f" {(self.height, self.width)} vs {(image_height, image_width)}",
            )

        # Largest top-left corner that still keeps the whole crop inside the image.
        max_x = image_width - self.width
        max_y = image_height - self.height

        if len(bboxes) > 0:
            # Pick a bbox amongst all possible as our reference bbox.
            bboxes = denormalize_bboxes(bboxes, image_shape=(image_height, image_width))
            # Select by index rather than py_random.choice(bboxes): some CPython
            # releases (3.11.0-3.11.1) test the truthiness of the sequence inside
            # random.choice, which raises ValueError for a multi-row NumPy array.
            # NOTE(review): assumes py_random is a random.Random instance - confirm.
            bbox = bboxes[self.py_random.randrange(len(bboxes))]

            x1, y1, x2, y2 = bbox[:4]

            # Width and height of the reference box once maximally eroded.
            keep = 1.0 - self.erosion_factor
            ew = (x2 - x1) * keep
            eh = (y2 - y1) * keep

            # Bounds for the crop's left edge: at one extreme the crop's right edge
            # sits at x1 + ew, at the other its left edge sits at x2 - ew. Both are
            # clamped so that the crop stays inside the image.
            ax1 = np.clip(x1 + ew - self.width, 0.0, max_x)
            bx1 = np.clip(x2 - ew, 0.0, max_x)

            # Same construction along the y-axis.
            ay1 = np.clip(y1 + eh - self.height, 0.0, max_y)
            by1 = np.clip(y2 - eh, 0.0, max_y)
        else:
            # If there are no bboxes, just crop anywhere in the image.
            ax1, bx1 = 0.0, max_x
            ay1, by1 = 0.0, max_y

        # Randomly draw the upper-left corner (truncated to whole pixels).
        x1 = int(self.py_random.uniform(a=ax1, b=bx1))
        y1 = int(self.py_random.uniform(a=ay1, b=by1))

        x2 = x1 + self.width
        y2 = y1 + self.height

        return {"crop_coords": (x1, y1, x2, y2)}

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Return the names of the init arguments stored on the instance."""
        return "height", "width", "erosion_factor"
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
],
[A.CenterCrop, {"height": 10, "width": 10}],
[A.RandomCrop, {"height": 10, "width": 10}],
[A.AtLeastOneBBoxRandomCrop, {"height": 10, "width": 10}],
[A.CropNonEmptyMaskIfExists, {"height": 10, "width": 10}],
[A.RandomSizedCrop, {"min_max_height": (4, 8), "height": 10, "width": 10}],
[A.Crop, {"x_max": 64, "y_max": 64}],
Expand Down
11 changes: 11 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test_image_only_augmentations(augmentation_cls, params):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -157,6 +158,7 @@ def test_dual_augmentations(augmentation_cls, params):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -211,6 +213,7 @@ def test_dual_augmentations_with_float_values(augmentation_cls, params):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -279,6 +282,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -372,6 +376,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params):
A.Crop,
A.CropNonEmptyMaskIfExists,
A.RandomCrop,
A.AtLeastOneBBoxRandomCrop,
A.RandomResizedCrop,
A.RandomSizedCrop,
A.CropAndPad,
Expand Down Expand Up @@ -462,6 +467,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, sha
A.Crop,
A.CropNonEmptyMaskIfExists,
A.RandomCrop,
A.AtLeastOneBBoxRandomCrop,
A.RandomResizedCrop,
A.RandomSizedCrop,
A.CropAndPad,
Expand Down Expand Up @@ -557,6 +563,7 @@ def test_mask_fill_value(augmentation_cls, params):
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -645,6 +652,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params):
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -726,6 +734,7 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params):
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -811,6 +820,7 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -1053,6 +1063,7 @@ def test_pad_if_needed_position(params, image_shape):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down
7 changes: 6 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def test_single_transform_compose(
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"size": (10, 10)},
A.RandomSizedCrop: {"min_max_height": (4, 8), "size": (10, 10)},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -1069,6 +1070,7 @@ def test_transform_always_apply_warning() -> None:
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"size": (10, 10)},
A.RandomSizedCrop: {"min_max_height": (4, 8), "size": (10, 10)},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -1160,7 +1162,7 @@ def test_images_as_target(augmentation_cls, params, as_array, shape):
if len(shape) == 3:
assert transformed["images"].shape[-1] == image.shape[2] # Channels match input

if augmentation_cls not in [A.RandomCrop, A.RandomResizedCrop, A.Resize, A.RandomSizedCrop, A.RandomSizedBBoxSafeCrop,
if augmentation_cls not in [A.RandomCrop, A.AtLeastOneBBoxRandomCrop, A.RandomResizedCrop, A.Resize, A.RandomSizedCrop, A.RandomSizedBBoxSafeCrop,
A.BBoxSafeRandomCrop, A.Transpose, A.RandomCropNearBBox, A.CenterCrop, A.Crop, A.CropAndPad,
A.LongestMaxSize, A.RandomScale, A.PadIfNeeded, A.SmallestMaxSize, A.RandomCropFromBorders,
A.RandomRotate90, A.D4]:
Expand Down Expand Up @@ -1198,6 +1200,7 @@ def test_images_as_target(augmentation_cls, params, as_array, shape):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -1292,6 +1295,7 @@ def test_non_contiguous_input_with_compose(augmentation_cls, params, bboxes):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"size": (10, 10)},
A.RandomSizedCrop: {"min_max_height": (4, 8), "size": (10, 10)},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -1378,6 +1382,7 @@ def test_masks_as_target(augmentation_cls, params, masks):
A.PixelDistributionAdaptation,
A.PadIfNeeded,
A.RandomCrop,
A.AtLeastOneBBoxRandomCrop,
A.Crop,
A.CenterCrop,
A.FDA,
Expand Down
4 changes: 4 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"size": (10, 10)},
A.RandomSizedCrop: {"min_max_height": (4, 8), "size": (10, 10)},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -221,6 +222,7 @@ def test_augmentations_serialization_to_file_with_custom_parameters(
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -289,6 +291,7 @@ def test_augmentations_for_bboxes_serialization(
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
Expand Down Expand Up @@ -817,6 +820,7 @@ def test_template_transform_serialization(
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"size": (10, 10)},
A.RandomSizedCrop: {"min_max_height": (4, 8), "size": (10, 10)},
A.CropAndPad: {"px": 10},
Expand Down
1 change: 1 addition & 0 deletions tests/test_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def test_image_only(augmentation_cls, params):
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.AtLeastOneBBoxRandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10},
Expand Down
Loading
Loading