Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide init_kwargs to every transform within another transform #1228

Merged
merged 7 commits into from
Jan 27, 2025
Merged
14 changes: 13 additions & 1 deletion src/torchio/transforms/augmentation/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def __getitem__(self, index) -> Transform:
def __repr__(self) -> str:
return f'{self.name}({self.transforms})'

def get_base_args(self) -> Dict:
init_args = super().get_base_args()
if 'parse_input' in init_args:
init_args.pop('parse_input')
return init_args

def apply_transform(self, subject: Subject) -> Subject:
for transform in self.transforms:
subject = transform(subject) # type: ignore[assignment]
Expand All @@ -66,7 +72,7 @@ def inverse(self, warn: bool = True) -> Compose:
message = f'Skipping {transform.name} as it is not invertible'
warnings.warn(message, RuntimeWarning, stacklevel=2)
transforms.reverse()
result = Compose(transforms)
result = Compose(transforms, **self.get_base_args())
if not transforms and warn:
warnings.warn(
'No invertible transforms found',
Expand Down Expand Up @@ -103,6 +109,12 @@ def __init__(self, transforms: TypeTransformsDict, **kwargs):
super().__init__(parse_input=False, **kwargs)
self.transforms_dict = self._get_transforms_dict(transforms)

def get_base_args(self) -> Dict:
init_args = super().get_base_args()
if 'parse_input' in init_args:
init_args.pop('parse_input')
return init_args

def apply_transform(self, subject: Subject) -> Subject:
weights = torch.Tensor(list(self.transforms_dict.values()))
index = torch.multinomial(weights, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def apply_transform(self, subject: Subject) -> Subject:
coefficients = self.get_params(self.order, self.coefficients_range)
arguments['coefficients'][image_name] = coefficients
arguments['order'][image_name] = self.order
transform = BiasField(**self.add_include_exclude(arguments))
transform = BiasField(**self.add_base_args(arguments))
transformed = transform(subject)
return transformed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def apply_transform(self, subject: Subject) -> Subject:
for name in images_dict:
std = self.get_params(self.std_ranges) # type: ignore[arg-type]
arguments['std'][name] = std
transform = Blur(**self.add_include_exclude(arguments))
transform = Blur(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def apply_transform(self, subject: Subject) -> Subject:
for name, image in images_dict.items():
gammas = [self.get_params(self.log_gamma_range) for _ in image.data]
arguments['gamma'][name] = gammas
transform = Gamma(**self.add_include_exclude(arguments))
transform = Gamma(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def apply_transform(self, subject: Subject) -> Subject:
arguments['axis'][name] = axis_param
arguments['intensity'][name] = intensity_param
arguments['restore'][name] = restore_param
transform = Ghosting(**self.add_include_exclude(arguments))
transform = Ghosting(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def apply_transform(self, subject: Subject) -> Subject:
means.append(mean)
stds.append(std)

transform = LabelsToImage(**self.add_include_exclude(arguments))
transform = LabelsToImage(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def apply_transform(self, subject: Subject) -> Subject:
arguments['degrees'][name] = degrees_params
arguments['translation'][name] = translation_params
arguments['image_interpolation'][name] = self.image_interpolation
transform = Motion(**self.add_include_exclude(arguments))
transform = Motion(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def apply_transform(self, subject: Subject) -> Subject:
arguments['mean'][image_name] = mean
arguments['std'][image_name] = std
arguments['seed'][image_name] = seed
transform = Noise(**self.add_include_exclude(arguments))
transform = Noise(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def apply_transform(self, subject: Subject) -> Subject:
)
arguments['spikes_positions'][image_name] = spikes_positions_param
arguments['intensity'][image_name] = intensity_param
transform = Spike(**self.add_include_exclude(arguments))
transform = Spike(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def apply_transform(self, subject: Subject) -> Subject:
)
arguments['locations'][name] = locations
arguments['patch_size'][name] = self.patch_size
transform = Swap(**self.add_include_exclude(arguments))
transform = Swap(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
5 changes: 0 additions & 5 deletions src/torchio/transforms/augmentation/random_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ class RandomTransform(Transform):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def add_include_exclude(self, kwargs):
kwargs['include'] = self.include
kwargs['exclude'] = self.exclude
return kwargs

def parse_degrees(
self,
degrees: TypeRangeFloat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def apply_transform(self, subject: Subject) -> Subject:
'label_interpolation': self.label_interpolation,
'check_shape': self.check_shape,
}
transform = Affine(**self.add_include_exclude(arguments))
transform = Affine(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def apply_transform(self, subject: Subject) -> Subject:
}

sx, sy, sz = target_spacing # for mypy
downsample = Resample(
target=(sx, sy, sz), **self.add_include_exclude(arguments)
)
downsample = Resample(target=(sx, sy, sz), **self.add_base_args(arguments))
downsampled = downsample(subject)
image = subject.get_first_image()
target = image.spatial_shape, image.affine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def apply_transform(self, subject: Subject) -> Subject:
'label_interpolation': self.label_interpolation,
}

transform = ElasticDeformation(**self.add_include_exclude(arguments))
transform = ElasticDeformation(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
2 changes: 1 addition & 1 deletion src/torchio/transforms/augmentation/spatial/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def apply_transform(self, subject: Subject) -> Subject:
return subject

arguments = {'axes': axes_list}
transform = Flip(**self.add_include_exclude(arguments))
transform = Flip(**self.add_base_args(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ def apply_transform(self, subject):
remapping = {
unique_labels[i].item(): i for i in range(0, len(unique_labels))
}
init_kwargs = self.get_base_args()
init_kwargs['include'] = [name]

transform = RemapLabels(
remapping=remapping,
masking_method=self.masking_method,
include=[name],
**init_kwargs,
)
subject = transform(subject)
return subject
4 changes: 2 additions & 2 deletions src/torchio/transforms/preprocessing/spatial/crop_or_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ def apply_transform(self, subject: Subject) -> Subject:
padding_params, cropping_params = self.compute_crop_or_pad(subject)
padding_kwargs = {'padding_mode': self.padding_mode}
if padding_params is not None:
pad = Pad(padding_params, **padding_kwargs)
pad = Pad(padding_params, **self.get_base_args(), **padding_kwargs)
subject = pad(subject) # type: ignore[assignment]
if cropping_params is not None:
crop = Crop(cropping_params)
crop = Crop(cropping_params, **self.get_base_args())
subject = crop(subject) # type: ignore[assignment]
return subject
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def apply_transform(self, subject: Subject) -> Subject:
integer_ratio = function(source_shape / self.target_multiple)
target_shape = integer_ratio * self.target_multiple
target_shape = np.maximum(target_shape, 1)
transform = CropOrPad(target_shape.astype(int))
transform = CropOrPad(target_shape.astype(int), **self.get_base_args())
subject = transform(subject) # type: ignore[assignment]
return subject
3 changes: 2 additions & 1 deletion src/torchio/transforms/preprocessing/spatial/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def apply_transform(self, subject: Subject) -> Subject:
spacing_out,
image_interpolation=self.image_interpolation,
label_interpolation=self.label_interpolation,
**self.get_base_args(),
)
resampled = resample(subject)
assert isinstance(resampled, Subject)
Expand All @@ -72,7 +73,7 @@ def apply_transform(self, subject: Subject) -> Subject:
f' != target shape {tuple(shape_out)}. Fixing with CropOrPad'
)
warnings.warn(message, RuntimeWarning, stacklevel=2)
crop_pad = CropOrPad(shape_out) # type: ignore[arg-type]
crop_pad = CropOrPad(shape_out, **self.get_base_args()) # type: ignore[arg-type]
resampled = crop_pad(resampled)
assert isinstance(resampled, Subject)
return resampled
32 changes: 32 additions & 0 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,38 @@
else:
return super().__repr__()

def get_base_args(self) -> dict:
r"""Provides easy access to the arguments used to instantiate the base class
(:class:`~torchio.transforms.transform.Transform`) of any transform.

This method is particularly useful when a new transform can be represented as a variant
of an existing transform (e.g. all random transforms), allowing for seamless instantiation
of the existing transform with the same arguments as the new transform during `apply_transform`.

Note: The `p` argument (probability of applying the transform) is excluded to avoid
multiplying the probability of both existing and new transform.
"""
return {
'copy': self.copy,
'include': self.include,
'exclude': self.exclude,
'keep': self.keep,
'parse_input': self.parse_input,
'label_keys': self.label_keys,
}

def add_base_args(
self,
arguments,
overwrite_on_existing: bool = False,
):
"""Add the init args to existing arguments"""
for key, value in self.get_base_args().items():
if key in arguments and not overwrite_on_existing:
continue

Check warning on line 217 in src/torchio/transforms/transform.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/transform.py#L217

Added line #L217 was not covered by tests
arguments[key] = value
return arguments

@property
def name(self):
return self.__class__.__name__
Expand Down
23 changes: 23 additions & 0 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,26 @@ def test_bad_keys_type(self):
# From https://github.com/fepegar/torchio/issues/923
with self.assertRaises(ValueError):
tio.RandomAffine(include='t1')

def test_init_args(self):
transform = tio.Compose([tio.RandomNoise()])
base_args = transform.get_base_args()
assert 'parse_input' not in base_args

transform = tio.OneOf([tio.RandomNoise()])
base_args = transform.get_base_args()
assert 'parse_input' not in base_args

transform = tio.RandomNoise()
base_args = transform.get_base_args()
assert all(
arg in base_args
for arg in [
'copy',
'include',
'exclude',
'keep',
'parse_input',
'label_keys',
]
)
Loading