diff --git a/meddlr/data/transforms/transform.py b/meddlr/data/transforms/transform.py
index f6ef4afe..1b4c6131 100644
--- a/meddlr/data/transforms/transform.py
+++ b/meddlr/data/transforms/transform.py
@@ -2,7 +2,7 @@
 """
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
@@ -11,7 +11,8 @@
 import meddlr.ops as F
 from meddlr.forward import SenseModel
 from meddlr.ops import complex as cplx
-from meddlr.transforms.gen.spatial import RandomAffine
+from meddlr.transforms.gen.spatial import RandomAffine, RandomTranslation
+from meddlr.transforms.transform_gen import TransformGen
 from meddlr.utils import transforms as T
 
 from .motion import MotionModel
@@ -26,7 +27,7 @@
 def affine_transform(
     image: torch.Tensor,
     nshots: int,
-    translation: Optional[RandomAffine] = None,
+    transforms: Sequence[TransformGen],
     trajectory: str = "blocked",
 ) -> torch.Tensor:
     """
@@ -46,9 +47,11 @@
         image: The complex-valued image. Shape [..., height, width].
         nshots: The number of shots in the image. This should be equivalent to
             ceil(phase_encode_dim / echo_train_length).
-        translation: This is the translation to augment images in the image
-            domain. This is either 'None' or 'RandomAffine' for now.
-        trajectory: One of 'interleaved' or 'consecutive'.
+        transforms: A sequence of random transform generators used to augment
+            images in the image domain. We recommend [RandomTranslation,
+            RandomAffine], in that order, to match the MRAugment augmentation
+            strategy.
+        trajectory: One of 'interleaved' or 'blocked'.
 
     Returns:
         A motion corrupted image.
@@ -57,7 +60,11 @@
     offset = int(math.ceil(kspace.shape[-1] / nshots))
 
     for shot in range(nshots):
-        motion_image = translation.get_transform(image).apply_image(image)
+        # Apply the sequence of random transforms to the image.
+        motion_image = image
+        for tfm in transforms:
+            motion_image = tfm.get_transform(motion_image).apply_image(motion_image)
+
         motion_kspace = F.fft2c(motion_image)
         if trajectory == "blocked":
             kspace[..., shot * offset : (shot + 1) * offset] = motion_kspace[
@@ -558,24 +565,33 @@ def __call__(self, kspace, maps, target, fname, slice_id, is_fixed, acceleration
         add_motion = self.add_motion and self._is_test
         if add_motion:
+            # Separate translation and affine transformations for proper padding.
+            translation = RandomTranslation(
+                p=1.0,
+                translate=self.translation,
+                pad_mode="reflect" if self.pad_like == "mraugment" else "constant",
+                pad_value=0.0,
+                ndim=2,
+            )
+            affine = RandomAffine(
+                p=1.0, translate=None, angle=self.angle, pad_like=pad
+            )
+            transforms: List[TransformGen] = [translation, affine]
             # Motion seed should not be different for each slice for now.
             # TODO: Change this for 2D acquisitions.
-            tfm_gen = RandomAffine(
-                p=1.0, translate=self.translation, angle=self.angle, pad_like=pad
-            )
-            tfm_gen.seed(seed)
-
-            image = image.permute(0, 3, 1, 2) # Shape: (B, 1, H, W)
+            for tfm_gen in transforms:
+                tfm_gen.seed(seed)
+            image = image.permute(0, 3, 1, 2)  # Shape: (B, 1, H, W)
             motion_img = affine_transform(
                 image=image,
                 nshots=self.nshots,
-                translation=tfm_gen,
+                transforms=transforms,
                 trajectory=self.trajectory,
             )
-            motion_img = motion_img.permute(0, 2, 3, 1) # Shape: (B, H, W, 1)
+            motion_img = motion_img.permute(0, 2, 3, 1)  # Shape: (B, H, W, 1)
 
             sense = SenseModel(maps)
             kspace = sense(motion_img)
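For reviewers, below is a minimal usage sketch of the new `transforms` argument to `affine_transform`. It assumes only the constructor keywords shown in this diff (`p`, `translate`, `pad_mode`, `pad_value`, `ndim`, `angle`); the translate/angle magnitudes, image size, and seed are illustrative placeholders, not values from this PR.

```python
import torch

from meddlr.data.transforms.transform import affine_transform
from meddlr.transforms.gen.spatial import RandomAffine, RandomTranslation

# A synthetic single-channel complex image. Shape: (B, 1, H, W).
image = torch.randn(1, 1, 320, 320, dtype=torch.complex64)

# Translation first, then affine, per the ordering recommended in the
# docstring (MRAugment-style). Magnitudes here are placeholders.
translation = RandomTranslation(
    p=1.0, translate=0.1, pad_mode="reflect", pad_value=0.0, ndim=2
)
affine = RandomAffine(p=1.0, translate=None, angle=10.0)
for tfm in (translation, affine):
    tfm.seed(42)  # fixed seed -> reproducible simulated motion

# Simulate shot-to-shot motion: each of the 4 shots sees an independently
# transformed copy of the image, and its k-space columns are stitched in.
corrupted = affine_transform(
    image=image, nshots=4, transforms=[translation, affine], trajectory="blocked"
)
```

Splitting translation out of `RandomAffine` lets the translation step use reflect padding (as in MRAugment) while the rotation keeps its own padding behavior, and seeding every generator preserves the previous deterministic per-example motion.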