Merge pull request #4 from TerminalVelocityDPro/pad-fix-arjun
implement pad fix template
TerminalVelocityDPro authored Sep 7, 2022
2 parents 963d89c + bda889d commit a85d86e
Showing 1 changed file with 31 additions and 15 deletions.
meddlr/data/transforms/transform.py
@@ -2,7 +2,7 @@
 """
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
@@ -11,7 +11,8 @@
 import meddlr.ops as F
 from meddlr.forward import SenseModel
 from meddlr.ops import complex as cplx
-from meddlr.transforms.gen.spatial import RandomAffine
+from meddlr.transforms.gen.spatial import RandomAffine, RandomTranslation
+from meddlr.transforms.transform_gen import TransformGen
 from meddlr.utils import transforms as T
 
 from .motion import MotionModel
@@ -26,7 +27,7 @@
 def affine_transform(
     image: torch.Tensor,
     nshots: int,
-    translation: Optional[RandomAffine] = None,
+    transforms: Sequence[TransformGen],
     trajectory: str = "blocked",
 ) -> torch.Tensor:
     """
@@ -46,9 +47,11 @@ def affine_transform(
         image: The complex-valued image. Shape [..., height, width].
         nshots: The number of shots in the image.
             This should be equivalent to ceil(phase_encode_dim / echo_train_length).
-        translation: This is the translation to augment images in the image
-            domain. This is either 'None' or 'RandomAffine' for now.
-        trajectory: One of 'interleaved' or 'consecutive'.
+        transforms: A sequence of random transform generators. These transforms
+            will be used to augment images in the image domain. We recommend using
+            [RandomTranslation, RandomAffine] in that order. This matches the
+            MRAugment augmentation strategy.
+        trajectory: One of 'interleaved' or 'blocked'.
 
     Returns:
         A motion corrupted image.
@@ -57,7 +60,11 @@
     offset = int(math.ceil(kspace.shape[-1] / nshots))
 
     for shot in range(nshots):
-        motion_image = translation.get_transform(image).apply_image(image)
+        # Apply sequence of random transforms to the image.
+        motion_image = image
+        for tfm in transforms:
+            motion_image = tfm.get_transform(motion_image).apply_image(motion_image)
+
        motion_kspace = F.fft2c(motion_image)
        if trajectory == "blocked":
            kspace[..., shot * offset : (shot + 1) * offset] = motion_kspace[
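
Aside (not part of the commit): a minimal sketch of how the two trajectory modes could partition phase-encode columns across shots. The "blocked" branch mirrors the indexing in the hunk above; the "interleaved" branch and the helper name shot_columns are assumptions for illustration, not meddlr API.

import math

import torch

def shot_columns(num_cols: int, nshots: int, trajectory: str):
    """Illustrative only: which phase-encode columns each shot acquires."""
    offset = int(math.ceil(num_cols / nshots))
    cols = torch.arange(num_cols)
    if trajectory == "blocked":
        # Contiguous band per shot, matching kspace[..., shot*offset : (shot+1)*offset].
        return [cols[shot * offset : (shot + 1) * offset] for shot in range(nshots)]
    if trajectory == "interleaved":
        # Every nshots-th column per shot (assumed behavior, not shown in the diff).
        return [cols[shot::nshots] for shot in range(nshots)]
    raise ValueError(f"Unknown trajectory: {trajectory}")

# Example with 8 columns, 2 shots:
#   blocked     -> [0, 1, 2, 3], [4, 5, 6, 7]
#   interleaved -> [0, 2, 4, 6], [1, 3, 5, 7]
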
@@ -558,24 +565,33 @@ def __call__(self, kspace, maps, target, fname, slice_id, is_fixed, acceleration
 
         add_motion = self.add_motion and self._is_test
         if add_motion:
+            # Separate translation and affine transformations for proper padding.
+            translation = RandomTranslation(
+                p=1.0,
+                translate=self.translation,
+                pad_mode="reflect" if self.pad_like == "mraugment" else "constant",
+                pad_value=0.0,
+                ndim=2
+            )
+            affine = RandomAffine(
+                p=1.0, translate=None, angle=self.angle, pad_like=pad
+            )
+            transforms: List[TransformGen] = [translation, affine]
             # Motion seed should not be different for each slice for now.
             # TODO: Change this for 2D acquisitions.
-            tfm_gen = RandomAffine(
-                p=1.0, translate=self.translation, angle=self.angle, pad_like=pad
-            )
-            tfm_gen.seed(seed)
+            for tfm_gen in transforms:
+                tfm_gen.seed(seed)
 
             image = image.permute(0, 3, 1, 2)  # Shape: (B, 1, H, W)
 
             motion_img = affine_transform(
                 image=image,
                 nshots=self.nshots,
-                translation=tfm_gen,
+                transforms=transforms,
                 trajectory=self.trajectory,
             )
 
             motion_img = motion_img.permute(0, 2, 3, 1)  # Shape: (B, H, W, 1)
             sense = SenseModel(maps)
             kspace = sense(motion_img)