From ffd8c317b42e0f2ed165fa3d93ba5e426e6699a6 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Tue, 7 Jan 2025 08:30:25 -0800 Subject: [PATCH] Updated TemplateTransform --- .../augmentations/crops/transforms.py | 25 ++++++++----------- .../domain_adaptation/transforms.py | 6 +---- tests/test_transforms.py | 21 +++++++--------- 3 files changed, 21 insertions(+), 31 deletions(-) diff --git a/albumentations/augmentations/crops/transforms.py b/albumentations/augmentations/crops/transforms.py index f6b0a7a83..8bd1ac18c 100644 --- a/albumentations/augmentations/crops/transforms.py +++ b/albumentations/augmentations/crops/transforms.py @@ -1681,21 +1681,18 @@ def _get_px_params(self) -> list[int]: raise ValueError(msg) if isinstance(self.px, int): - params = [self.px] * 4 - elif len(self.px) == PAIR: + return [self.px] * 4 + if len(self.px) == PAIR: if self.sample_independently: - params = [self.py_random.randrange(*self.px) for _ in range(4)] - else: - px = self.py_random.randrange(*self.px) - params = [px] * 4 - elif isinstance(self.px[0], int): - params = self.px - elif len(self.px[0]) == PAIR: - params = [self.py_random.randrange(*i) for i in self.px] - else: - params = [self.py_random.choice(i) for i in self.px] - - return params + return [self.py_random.randrange(*self.px) for _ in range(4)] + px = self.py_random.randrange(*self.px) + return [px] * 4 + if isinstance(self.px[0], int): + return self.px + if len(self.px[0]) == PAIR: + return [self.py_random.randrange(*i) for i in self.px] + + return [self.py_random.choice(i) for i in self.px] def _get_percent_params(self) -> list[float]: if self.percent is None: diff --git a/albumentations/augmentations/domain_adaptation/transforms.py b/albumentations/augmentations/domain_adaptation/transforms.py index b36b41d28..16dcd0252 100644 --- a/albumentations/augmentations/domain_adaptation/transforms.py +++ b/albumentations/augmentations/domain_adaptation/transforms.py @@ -6,7 +6,7 @@ import cv2 import numpy as np from albucore import add_weighted, get_num_channels -from pydantic import AfterValidator, Field, field_validator +from pydantic import AfterValidator, field_validator import albumentations.augmentations.geometric.functional as fgeometric from albumentations.augmentations.domain_adaptation.functional import ( @@ -439,9 +439,6 @@ class TemplateTransform(ImageOnlyTransform): class InitSchema(BaseTransformInitSchema): templates: np.ndarray | Sequence[np.ndarray] img_weight: ZeroOneRangeType - template_weight: ZeroOneRangeType | None = Field( - deprecated="Template_weight is deprecated. Computed automatically as (1 - img_weight)", - ) template_transform: Compose | BasicTransform | None = None name: str | None @@ -462,7 +459,6 @@ def __init__( self, templates: np.ndarray | list[np.ndarray], img_weight: ScaleFloatType = (0.5, 0.5), - template_weight: None = None, template_transform: Compose | BasicTransform | None = None, name: str | None = None, p: float = 0.5, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e46894320..6cd5c73a5 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -696,28 +696,25 @@ def test_smallest_max_size_list(): @pytest.mark.parametrize( [ "img_weight", - "template_weight", "template_transform", "image_size", "template_size", ], [ ( - 0.5, 0.5, A.RandomSizedCrop((50, 200), size=(513, 450), p=1.0), (513, 450), (224, 224), ), - (0.3, 0.5, A.RandomResizedCrop(size=(513, 450), p=1.0), (513, 450), (224, 224)), - (1.0, 0.5, A.CenterCrop(500, 450, p=1.0), (500, 450, 3), (512, 512, 3)), - (0.5, 0.8, A.Resize(513, 450, p=1.0), (513, 450), (512, 512)), - (0.5, 0.2, A.NoOp(), (224, 224), (224, 224)), - (0.5, 0.9, A.NoOp(), (512, 512, 3), (512, 512, 3)), - (0.5, 0.5, None, (512, 512), (512, 512)), - (0.8, 0.7, None, (512, 512, 3), (512, 512, 3)), + (0.3, A.RandomResizedCrop(size=(513, 450), p=1.0), (513, 450), (224, 224)), + (1.0, A.CenterCrop(500, 450, p=1.0), (500, 450, 3), (512, 512, 3)), + (0.5, A.Resize(513, 450, p=1.0), (513, 450), (512, 512)), + (0.5, A.NoOp(), (224, 224), (224, 224)), + (0.5, A.NoOp(), (512, 512, 3), (512, 512, 3)), + (0.5, None, (512, 512), (512, 512)), + (0.8, None, (512, 512, 3), (512, 512, 3)), ( - 0.5, 0.5, A.Compose( [ @@ -732,12 +729,12 @@ def test_smallest_max_size_list(): ], ) def test_template_transform( - img_weight, template_weight, template_transform, image_size, template_size + img_weight, template_transform, image_size, template_size ): img = np.random.randint(0, 256, image_size, np.uint8) template = np.random.randint(0, 256, template_size, np.uint8) - aug = A.TemplateTransform(template, img_weight, template_weight, template_transform) + aug = A.TemplateTransform(template, img_weight, template_transform) result = aug(image=img)["image"] assert result.shape == img.shape