Skip to content

Commit

Permalink
Updated TemplateTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Jan 7, 2025
1 parent f6e16c3 commit ffd8c31
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 31 deletions.
25 changes: 11 additions & 14 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,21 +1681,18 @@ def _get_px_params(self) -> list[int]:
raise ValueError(msg)

if isinstance(self.px, int):
params = [self.px] * 4
elif len(self.px) == PAIR:
return [self.px] * 4
if len(self.px) == PAIR:
if self.sample_independently:
params = [self.py_random.randrange(*self.px) for _ in range(4)]
else:
px = self.py_random.randrange(*self.px)
params = [px] * 4
elif isinstance(self.px[0], int):
params = self.px
elif len(self.px[0]) == PAIR:
params = [self.py_random.randrange(*i) for i in self.px]
else:
params = [self.py_random.choice(i) for i in self.px]

return params
return [self.py_random.randrange(*self.px) for _ in range(4)]
px = self.py_random.randrange(*self.px)
return [px] * 4
if isinstance(self.px[0], int):
return self.px
if len(self.px[0]) == PAIR:
return [self.py_random.randrange(*i) for i in self.px]

return [self.py_random.choice(i) for i in self.px]

def _get_percent_params(self) -> list[float]:
if self.percent is None:
Expand Down
6 changes: 1 addition & 5 deletions albumentations/augmentations/domain_adaptation/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import cv2
import numpy as np
from albucore import add_weighted, get_num_channels
from pydantic import AfterValidator, Field, field_validator
from pydantic import AfterValidator, field_validator

import albumentations.augmentations.geometric.functional as fgeometric
from albumentations.augmentations.domain_adaptation.functional import (
Expand Down Expand Up @@ -439,9 +439,6 @@ class TemplateTransform(ImageOnlyTransform):
class InitSchema(BaseTransformInitSchema):
templates: np.ndarray | Sequence[np.ndarray]
img_weight: ZeroOneRangeType
template_weight: ZeroOneRangeType | None = Field(
deprecated="Template_weight is deprecated. Computed automatically as (1 - img_weight)",
)
template_transform: Compose | BasicTransform | None = None
name: str | None

Expand All @@ -462,7 +459,6 @@ def __init__(
self,
templates: np.ndarray | list[np.ndarray],
img_weight: ScaleFloatType = (0.5, 0.5),
template_weight: None = None,
template_transform: Compose | BasicTransform | None = None,
name: str | None = None,
p: float = 0.5,
Expand Down
21 changes: 9 additions & 12 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,28 +696,25 @@ def test_smallest_max_size_list():
@pytest.mark.parametrize(
[
"img_weight",
"template_weight",
"template_transform",
"image_size",
"template_size",
],
[
(
0.5,
0.5,
A.RandomSizedCrop((50, 200), size=(513, 450), p=1.0),
(513, 450),
(224, 224),
),
(0.3, 0.5, A.RandomResizedCrop(size=(513, 450), p=1.0), (513, 450), (224, 224)),
(1.0, 0.5, A.CenterCrop(500, 450, p=1.0), (500, 450, 3), (512, 512, 3)),
(0.5, 0.8, A.Resize(513, 450, p=1.0), (513, 450), (512, 512)),
(0.5, 0.2, A.NoOp(), (224, 224), (224, 224)),
(0.5, 0.9, A.NoOp(), (512, 512, 3), (512, 512, 3)),
(0.5, 0.5, None, (512, 512), (512, 512)),
(0.8, 0.7, None, (512, 512, 3), (512, 512, 3)),
(0.3, A.RandomResizedCrop(size=(513, 450), p=1.0), (513, 450), (224, 224)),
(1.0, A.CenterCrop(500, 450, p=1.0), (500, 450, 3), (512, 512, 3)),
(0.5, A.Resize(513, 450, p=1.0), (513, 450), (512, 512)),
(0.5, A.NoOp(), (224, 224), (224, 224)),
(0.5, A.NoOp(), (512, 512, 3), (512, 512, 3)),
(0.5, None, (512, 512), (512, 512)),
(0.8, None, (512, 512, 3), (512, 512, 3)),
(
0.5,
0.5,
A.Compose(
[
Expand All @@ -732,12 +729,12 @@ def test_smallest_max_size_list():
],
)
def test_template_transform(
img_weight, template_weight, template_transform, image_size, template_size
img_weight, template_transform, image_size, template_size
):
img = np.random.randint(0, 256, image_size, np.uint8)
template = np.random.randint(0, 256, template_size, np.uint8)

aug = A.TemplateTransform(template, img_weight, template_weight, template_transform)
aug = A.TemplateTransform(template, img_weight, template_transform)
result = aug(image=img)["image"]

assert result.shape == img.shape
Expand Down

0 comments on commit ffd8c31

Please sign in to comment.