Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement pad fix template #4

Merged
merged 1 commit into from
Sep 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions meddlr/data/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
import math
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
Expand All @@ -11,7 +11,8 @@
import meddlr.ops as F
from meddlr.forward import SenseModel
from meddlr.ops import complex as cplx
from meddlr.transforms.gen.spatial import RandomAffine
from meddlr.transforms.gen.spatial import RandomAffine, RandomTranslation
from meddlr.transforms.transform_gen import TransformGen
from meddlr.utils import transforms as T

from .motion import MotionModel
Expand All @@ -26,7 +27,7 @@
def affine_transform(
image: torch.Tensor,
nshots: int,
translation: Optional[RandomAffine] = None,
transforms: Sequence[TransformGen],
trajectory: str = "blocked",
) -> torch.Tensor:
"""
Expand All @@ -46,9 +47,11 @@ def affine_transform(
image: The complex-valued image. Shape [..., height, width].
nshots: The number of shots in the image.
This should be equivalent to ceil(phase_encode_dim / echo_train_length).
translation: This is the translation to augment images in the image
domain. This is either 'None' or 'RandomAffine' for now.
trajectory: One of 'interleaved' or 'consecutive'.
transforms: A sequence of random transform generators. These transforms
will be used to augment images in the image domain. We recommend using
[RandomTranslation, RandomAffine] in that order. This matches the MRAugment
augmentation strategy.
trajectory: One of 'interleaved' or 'blocked'.

Returns:
A motion corrupted image.
Expand All @@ -57,7 +60,11 @@ def affine_transform(
offset = int(math.ceil(kspace.shape[-1] / nshots))

for shot in range(nshots):
motion_image = translation.get_transform(image).apply_image(image)
# Apply sequence of random transforms to the image.
motion_image = image
for tfm in transforms:
motion_image = tfm.get_transform(motion_image).apply_image(motion_image)

motion_kspace = F.fft2c(motion_image)
if trajectory == "blocked":
kspace[..., shot * offset : (shot + 1) * offset] = motion_kspace[
Expand Down Expand Up @@ -558,24 +565,33 @@ def __call__(self, kspace, maps, target, fname, slice_id, is_fixed, acceleration

add_motion = self.add_motion and self._is_test
if add_motion:
# Separate translation and affine transformations for proper padding.
translation = RandomTranslation(
p=1.0,
translate=self.translation,
pad_mode="reflect" if self.pad_like == "mraugment" else "constant",
pad_value=0.0,
ndim=2
)
affine = RandomAffine(
p=1.0, translate=None, angle=self.angle, pad_like=pad
)
transforms: List[TransformGen] = [translation, affine]
# Motion seed should not be different for each slice for now.
# TODO: Change this for 2D acquisitions.
for tfm_gen in transforms:
tfm_gen.seed(seed)

tfm_gen = RandomAffine(
p=1.0, translate=self.translation, angle=self.angle, pad_like=pad
)
tfm_gen.seed(seed)

image = image.permute(0, 3, 1, 2) # Shape: (B, 1, H, W)
image = image.permute(0, 3, 1, 2) # Shape: (B, 1, H, W)

motion_img = affine_transform(
image=image,
nshots=self.nshots,
translation=tfm_gen,
transforms=transforms,
trajectory=self.trajectory,
)

motion_img = motion_img.permute(0, 2, 3, 1) # Shape: (B, H, W, 1)
motion_img = motion_img.permute(0, 2, 3, 1) # Shape: (B, H, W, 1)
sense = SenseModel(maps)
kspace = sense(motion_img)

Expand Down