Merge pull request #216 from eigenvivek/efficient-masked-rendering
Implement efficient rendering of segmentation masks
eigenvivek authored Apr 16, 2024
2 parents 8ec723f + 3b81d5c commit 0656177
Showing 8 changed files with 418 additions and 224 deletions.
39 changes: 22 additions & 17 deletions diffdrr/data.py
@@ -15,18 +15,19 @@
__all__ = ['load_example_ct', 'read']

# %% ../notebooks/api/03_data.ipynb 4
def load_example_ct() -> Subject:
def load_example_ct(labels=None) -> Subject:
"""Load an example chest CT for demonstration purposes."""
datadir = Path(__file__).resolve().parent / "data"
filename = datadir / "cxr.nii.gz"
labelmap = datadir / "mask.nii.gz"
structures = pd.read_csv(datadir / "structures.csv")
return read(filename, labelmap, structures=structures)
return read(filename, labelmap, labels, structures=structures)

# %% ../notebooks/api/03_data.ipynb 5
def read(
filename: str | Path, # Path to CT volume
labelmap: str | Path = None, # Path to a labelmap for the CT volume
labels: int | list = None, # Labels from the mask of structures to render
**kwargs, # Any additional information to be stored in the torchio.Subject
) -> Subject:
"""
@@ -55,6 +56,16 @@ def read(
# Canonicalize the images by converting to RAS+ and moving the
# Subject's isocenter to the origin in world coordinates
subject = canonicalize(subject)

# Apply mask
if labels is not None:
if isinstance(labels, int):
labels = [labels]
mask = torch.any(
torch.stack([mask.data.squeeze() == idx for idx in labels]), dim=0
)
subject.density = subject.density * mask

return subject

# %% ../notebooks/api/03_data.ipynb 6
@@ -63,24 +74,18 @@ def canonicalize(subject):
subject = ToCanonical()(subject)

# Move the Subject's isocenter to the origin in world coordinates
isocenter = subject.volume.get_center()
Tinv = np.array(
[
[1.0, 0.0, 0.0, -isocenter[0]],
[0.0, 1.0, 0.0, -isocenter[1]],
[0.0, 0.0, 1.0, -isocenter[2]],
[0.0, 0.0, 0.0, 1.0],
]
)
for image in subject.get_images(intensity_only=False):
isocenter = image.get_center()
Tinv = np.array(
[
[1.0, 0.0, 0.0, -isocenter[0]],
[0.0, 1.0, 0.0, -isocenter[1]],
[0.0, 0.0, 1.0, -isocenter[2]],
[0.0, 0.0, 0.0, 1.0],
]
)
image.affine = Tinv.dot(image.affine)

# Need to manually change the affine matrix of the labelmap
try:
subject.mask.affine = subject.volume.affine
except AttributeError:
pass

return subject

# %% ../notebooks/api/03_data.ipynb 7
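Structure selection now happens at load time: `read()` and `load_example_ct()` accept a `labels` argument and zero the density outside the selected label indices. A minimal usage sketch, assuming the package is importable; the label indices below are placeholders, with the real indices listed in the bundled structures.csv:

```python
import torch

from diffdrr.data import load_example_ct

# Keep only the structures with the selected label indices in the density
# volume (placeholder indices for illustration).
subject = load_example_ct(labels=[1, 2])

# read() builds the same boolean mask over the labelmap that the old
# forward-pass code rebuilt on every call, but only once, at load time:
labels = [1, 2]
keep = torch.any(
    torch.stack([subject.mask.data.squeeze() == idx for idx in labels]), dim=0
)
print(f"Fraction of voxels retained: {keep.float().mean():.1%}")
```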
40 changes: 22 additions & 18 deletions diffdrr/drr.py
@@ -70,7 +70,7 @@ def __init__(
"origin", torch.tensor(subject.volume.origin, dtype=torch.float32)
)
if subject.mask is not None:
self.register_buffer("mask", subject.mask.data.squeeze())
self.register_buffer("mask", subject.mask.data[0].to(torch.int64))

# Initialize the renderer
if renderer == "siddon":
@@ -87,7 +87,9 @@ def __init__(
def reshape_transform(self, img, batch_size):
if self.reshape:
if self.detector.n_subsample is None:
img = img.view(-1, 1, self.detector.height, self.detector.width)
img = img.view(
batch_size, -1, self.detector.height, self.detector.width
)
else:
img = reshape_subsampled_drr(img, self.detector, batch_size)
return img
@@ -110,7 +112,7 @@ def forward(
*args, # Some batched representation of SE(3)
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
labels: list = None, # Labels from the mask of structures to render
mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels
**kwargs, # Passed to the renderer
):
"""Generate DRR with rotational and translational parameters."""
@@ -121,30 +123,32 @@ def forward(
pose = convert(*args, parameterization=parameterization, convention=convention)
source, target = self.detector(pose)

# Apply mask
if labels is not None:
if isinstance(labels, int):
labels = [labels]
mask = torch.any(torch.stack([self.mask == idx for idx in labels]), dim=0)
density = self.density * mask
else:
density = self.density

# Render the DRR
if self.patch_size is not None:
kwargs["mask"] = self.mask if mask_to_channels else None
if self.patch_size is None:
img = self.renderer(
self.density,
self.origin,
self.spacing,
source,
target,
**kwargs,
)
else:
n_points = target.shape[1] // self.n_patches
img = []
for idx in range(self.n_patches):
t = target[:, idx * n_points : (idx + 1) * n_points]
partial = self.renderer(
density, self.origin, self.spacing, source, t, **kwargs
self.density,
self.origin,
self.spacing,
source,
t,
**kwargs,
)
img.append(partial)
img = torch.cat(img, dim=1)
else:
img = self.renderer(
density, self.origin, self.spacing, source, target, **kwargs
)
return self.reshape_transform(img, batch_size=len(pose))

# %% ../notebooks/api/00_drr.ipynb 11
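With the per-call `labels` argument removed, per-structure output is requested through `mask_to_channels` instead: when `True`, the (default) Siddon renderer returns one channel per structure label in a single forward pass. A minimal sketch of the new call; the detector settings and pose values follow the library's usual example configuration and are assumptions, not part of this diff:

```python
import torch

from diffdrr.data import load_example_ct
from diffdrr.drr import DRR

# The example CT ships with a labelmap, so the integer "mask" buffer is
# registered on the module. Detector settings here are assumed example values.
subject = load_example_ct()
drr = DRR(subject, sdd=1020.0, height=200, delx=2.0)

rotations = torch.tensor([[torch.pi, 0.0, torch.pi / 2]])
translations = torch.tensor([[0.0, 850.0, 0.0]])

# Composite DRR, shape (1, 1, 200, 200)
img = drr(rotations, translations, parameterization="euler_angles", convention="ZYX")

# One channel per structure label, shape (1, C, 200, 200), where C is one more
# than the largest label index encountered along the rays (channel 0 = background)
channels = drr(
    rotations,
    translations,
    parameterization="euler_angles",
    convention="ZYX",
    mask_to_channels=True,
)
```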
53 changes: 42 additions & 11 deletions diffdrr/renderers.py
@@ -20,14 +20,14 @@ def dims(self, volume):
def maxidx(self, volume):
return volume.numel() - 1

def forward(self, volume, origin, spacing, source, target):
def forward(self, volume, origin, spacing, source, target, mask=None):
dims = self.dims(volume)
maxidx = self.maxidx(volume)
origin = origin.to(torch.float64)

alphas = _get_alphas(source, target, origin, spacing, dims, self.eps)
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
voxels = _get_voxel(
voxels, idxs = _get_voxel(
alphamid, source, target, volume, origin, spacing, dims, maxidx, self.eps
)

@@ -36,10 +36,29 @@ def forward(self, volume, origin, spacing, source, target):
step_length = torch.diff(alphas, dim=-1)
weighted_voxels = voxels * step_length

drr = torch.nansum(weighted_voxels, dim=-1)
# Handle optional mapping
if mask is None:
img = torch.nansum(weighted_voxels, dim=-1)
img = img.unsqueeze(1)
else:
# Thanks to @Ivan for the clutch assist w/ pytorch tensor ops
# https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614
channels = torch.take(mask, idxs) # B D N
weighted_voxels = weighted_voxels.nan_to_num()
B, D, N = weighted_voxels.shape
C = channels.max().item() + 1
img = (
torch.zeros(B, C, D)
.to(volume)
.scatter_add_(
1, channels.transpose(-1, -2), weighted_voxels.transpose(-1, -2)
)
)

# Finish rendering the DRR
raylength = (target - source + self.eps).norm(dim=-1)
drr *= raylength
return drr
img *= raylength.unsqueeze(1)
return img

# %% ../notebooks/api/01_renderers.ipynb 8
def _get_alphas(source, target, origin, spacing, dims, eps):
@@ -91,7 +110,7 @@ def _get_alpha_minmax(sdd, source, target, origin, spacing, dims):

def _get_voxel(alpha, source, target, volume, origin, spacing, dims, maxidx, eps):
idxs = _get_index(alpha, source, target, origin, spacing, dims, maxidx, eps)
return torch.take(volume, idxs)
return torch.take(volume, idxs), idxs


def _get_index(alpha, source, target, origin, spacing, dims, maxidx, eps):
@@ -134,8 +153,20 @@ def dims(self, volume):
return torch.tensor(volume.shape).to(volume) - 1

def forward(
self, volume, origin, spacing, source, target, n_points=250, align_corners=True
self,
volume,
origin,
spacing,
source,
target,
n_points=250,
align_corners=True,
mask=None,
):
# Ensure not using mask_to_channels
if mask is not None:
            raise ValueError("mask_to_channels can only be True if renderer=='Siddon'")

# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1)
source = source[:, None, :, None, :] - origin
@@ -153,7 +184,7 @@ def forward(
# Render the DRR
batch_size = len(rays)
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)
drr = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
drr = drr[:, 0, 0].sum(dim=-1)
drr *= raylength / n_points
return drr
img = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
img = img[:, 0, 0].sum(dim=-1)
img *= raylength / n_points
return img
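The heart of the change is the `scatter_add_` call in the Siddon renderer: each sample's step-length-weighted density is routed into the channel given by the structure label of the voxel it falls in, so every structure is rendered in one pass instead of one render per label (the trilinear renderer simply rejects `mask`). A self-contained toy reproduction of that tensor operation, with made-up shapes and random values:

```python
import torch

B, D, N = 2, 5, 7  # batch size, rays (detector pixels), samples per ray
C = 3              # number of structure labels, including background (0)

weighted_voxels = torch.rand(B, D, N)      # step-length-weighted densities
channels = torch.randint(0, C, (B, D, N))  # label of the voxel at each sample

# img[b, c, d] accumulates every sample along ray d that lies in structure c
img = torch.zeros(B, C, D).scatter_add_(
    1, channels.transpose(-1, -2), weighted_voxels.transpose(-1, -2)
)

# Summing the channels recovers the single-channel line integral per ray
assert torch.allclose(img.sum(dim=1), weighted_voxels.sum(dim=-1))
```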
39 changes: 21 additions & 18 deletions notebooks/api/00_drr.ipynb
@@ -112,6 +112,7 @@
"#| export\n",
"from torchio import Subject\n",
"\n",
"\n",
"class DRR(nn.Module):\n",
" \"\"\"PyTorch module that computes differentiable digitally reconstructed radiographs.\"\"\"\n",
"\n",
@@ -164,7 +165,7 @@
" \"origin\", torch.tensor(subject.volume.origin, dtype=torch.float32)\n",
" )\n",
" if subject.mask is not None:\n",
" self.register_buffer(\"mask\", subject.mask.data.squeeze())\n",
" self.register_buffer(\"mask\", subject.mask.data[0].to(torch.int64))\n",
"\n",
" # Initialize the renderer\n",
" if renderer == \"siddon\":\n",
@@ -181,7 +182,7 @@
" def reshape_transform(self, img, batch_size):\n",
" if self.reshape:\n",
" if self.detector.n_subsample is None:\n",
" img = img.view(-1, 1, self.detector.height, self.detector.width)\n",
" img = img.view(batch_size, -1, self.detector.height, self.detector.width)\n",
" else:\n",
" img = reshape_subsampled_drr(img, self.detector, batch_size)\n",
" return img"
@@ -228,7 +229,7 @@
" *args, # Some batched representation of SE(3)\n",
" parameterization: str = None, # Specifies the representation of the rotation\n",
" convention: str = None, # If parameterization is Euler angles, specify convention\n",
" labels: list = None, # Labels from the mask of structures to render\n",
" mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels\n",
" **kwargs, # Passed to the renderer\n",
"):\n",
" \"\"\"Generate DRR with rotational and translational parameters.\"\"\"\n",
@@ -239,30 +240,32 @@
" pose = convert(*args, parameterization=parameterization, convention=convention)\n",
" source, target = self.detector(pose)\n",
"\n",
" # Apply mask\n",
" if labels is not None:\n",
" if isinstance(labels, int):\n",
" labels = [labels]\n",
" mask = torch.any(torch.stack([self.mask == idx for idx in labels]), dim=0)\n",
" density = self.density * mask\n",
" else:\n",
" density = self.density\n",
"\n",
" # Render the DRR\n",
" if self.patch_size is not None:\n",
" kwargs[\"mask\"] = self.mask if mask_to_channels else None\n",
" if self.patch_size is None:\n",
" img = self.renderer(\n",
" self.density,\n",
" self.origin,\n",
" self.spacing,\n",
" source,\n",
" target,\n",
" **kwargs,\n",
" )\n",
" else:\n",
" n_points = target.shape[1] // self.n_patches\n",
" img = []\n",
" for idx in range(self.n_patches):\n",
" t = target[:, idx * n_points : (idx + 1) * n_points]\n",
" partial = self.renderer(\n",
" density, self.origin, self.spacing, source, t, **kwargs\n",
" self.density,\n",
" self.origin,\n",
" self.spacing,\n",
" source,\n",
" t,\n",
" **kwargs,\n",
" )\n",
" img.append(partial)\n",
" img = torch.cat(img, dim=1)\n",
" else:\n",
" img = self.renderer(\n",
" density, self.origin, self.spacing, source, target, **kwargs\n",
" )\n",
" return self.reshape_transform(img, batch_size=len(pose))"
]
},