Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement efficient rendering of segmentation masks #216

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions diffdrr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
__all__ = ['load_example_ct', 'read']

# %% ../notebooks/api/03_data.ipynb 4
def load_example_ct() -> Subject:
def load_example_ct(labels=None) -> Subject:
"""Load an example chest CT for demonstration purposes."""
datadir = Path(__file__).resolve().parent / "data"
filename = datadir / "cxr.nii.gz"
labelmap = datadir / "mask.nii.gz"
structures = pd.read_csv(datadir / "structures.csv")
return read(filename, labelmap, structures=structures)
return read(filename, labelmap, labels, structures=structures)

# %% ../notebooks/api/03_data.ipynb 5
def read(
filename: str | Path, # Path to CT volume
labelmap: str | Path = None, # Path to a labelmap for the CT volume
labels: int | list = None, # Labels from the mask of structures to render
**kwargs, # Any additional information to be stored in the torchio.Subject
) -> Subject:
"""
Expand Down Expand Up @@ -55,6 +56,16 @@ def read(
# Canonicalize the images by converting to RAS+ and moving the
# Subject's isocenter to the origin in world coordinates
subject = canonicalize(subject)

# Apply mask
if labels is not None:
if isinstance(labels, int):
labels = [labels]
mask = torch.any(
torch.stack([mask.data.squeeze() == idx for idx in labels]), dim=0
)
subject.density = subject.density * mask

return subject

# %% ../notebooks/api/03_data.ipynb 6
Expand All @@ -63,24 +74,18 @@ def canonicalize(subject):
subject = ToCanonical()(subject)

# Move the Subject's isocenter to the origin in world coordinates
isocenter = subject.volume.get_center()
Tinv = np.array(
[
[1.0, 0.0, 0.0, -isocenter[0]],
[0.0, 1.0, 0.0, -isocenter[1]],
[0.0, 0.0, 1.0, -isocenter[2]],
[0.0, 0.0, 0.0, 1.0],
]
)
for image in subject.get_images(intensity_only=False):
isocenter = image.get_center()
Tinv = np.array(
[
[1.0, 0.0, 0.0, -isocenter[0]],
[0.0, 1.0, 0.0, -isocenter[1]],
[0.0, 0.0, 1.0, -isocenter[2]],
[0.0, 0.0, 0.0, 1.0],
]
)
image.affine = Tinv.dot(image.affine)

# Need to manually change the affine matrix of the labelmap
try:
subject.mask.affine = subject.volume.affine
except AttributeError:
pass

return subject

# %% ../notebooks/api/03_data.ipynb 7
Expand Down
40 changes: 22 additions & 18 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
"origin", torch.tensor(subject.volume.origin, dtype=torch.float32)
)
if subject.mask is not None:
self.register_buffer("mask", subject.mask.data.squeeze())
self.register_buffer("mask", subject.mask.data[0].to(torch.int64))

# Initialize the renderer
if renderer == "siddon":
Expand All @@ -87,7 +87,9 @@ def __init__(
def reshape_transform(self, img, batch_size):
if self.reshape:
if self.detector.n_subsample is None:
img = img.view(-1, 1, self.detector.height, self.detector.width)
img = img.view(
batch_size, -1, self.detector.height, self.detector.width
)
else:
img = reshape_subsampled_drr(img, self.detector, batch_size)
return img
Expand All @@ -110,7 +112,7 @@ def forward(
*args, # Some batched representation of SE(3)
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
labels: list = None, # Labels from the mask of structures to render
mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels
**kwargs, # Passed to the renderer
):
"""Generate DRR with rotational and translational parameters."""
Expand All @@ -121,30 +123,32 @@ def forward(
pose = convert(*args, parameterization=parameterization, convention=convention)
source, target = self.detector(pose)

# Apply mask
if labels is not None:
if isinstance(labels, int):
labels = [labels]
mask = torch.any(torch.stack([self.mask == idx for idx in labels]), dim=0)
density = self.density * mask
else:
density = self.density

# Render the DRR
if self.patch_size is not None:
kwargs["mask"] = self.mask if mask_to_channels else None
if self.patch_size is None:
img = self.renderer(
self.density,
self.origin,
self.spacing,
source,
target,
**kwargs,
)
else:
n_points = target.shape[1] // self.n_patches
img = []
for idx in range(self.n_patches):
t = target[:, idx * n_points : (idx + 1) * n_points]
partial = self.renderer(
density, self.origin, self.spacing, source, t, **kwargs
self.density,
self.origin,
self.spacing,
source,
t,
**kwargs,
)
img.append(partial)
img = torch.cat(img, dim=1)
else:
img = self.renderer(
density, self.origin, self.spacing, source, target, **kwargs
)
return self.reshape_transform(img, batch_size=len(pose))

# %% ../notebooks/api/00_drr.ipynb 11
Expand Down
53 changes: 42 additions & 11 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def dims(self, volume):
def maxidx(self, volume):
return volume.numel() - 1

def forward(self, volume, origin, spacing, source, target):
def forward(self, volume, origin, spacing, source, target, mask=None):
dims = self.dims(volume)
maxidx = self.maxidx(volume)
origin = origin.to(torch.float64)

alphas = _get_alphas(source, target, origin, spacing, dims, self.eps)
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
voxels = _get_voxel(
voxels, idxs = _get_voxel(
alphamid, source, target, volume, origin, spacing, dims, maxidx, self.eps
)

Expand All @@ -36,10 +36,29 @@ def forward(self, volume, origin, spacing, source, target):
step_length = torch.diff(alphas, dim=-1)
weighted_voxels = voxels * step_length

drr = torch.nansum(weighted_voxels, dim=-1)
# Handle optional mapping
if mask is None:
img = torch.nansum(weighted_voxels, dim=-1)
img = img.unsqueeze(1)
else:
# Thanks to @Ivan for the clutch assist w/ pytorch tensor ops
# https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614
channels = torch.take(mask, idxs) # B D N
weighted_voxels = weighted_voxels.nan_to_num()
B, D, N = weighted_voxels.shape
C = channels.max().item() + 1
img = (
torch.zeros(B, C, D)
.to(volume)
.scatter_add_(
1, channels.transpose(-1, -2), weighted_voxels.transpose(-1, -2)
)
)

# Finish rendering the DRR
raylength = (target - source + self.eps).norm(dim=-1)
drr *= raylength
return drr
img *= raylength.unsqueeze(1)
return img

# %% ../notebooks/api/01_renderers.ipynb 8
def _get_alphas(source, target, origin, spacing, dims, eps):
Expand Down Expand Up @@ -91,7 +110,7 @@ def _get_alpha_minmax(sdd, source, target, origin, spacing, dims):

def _get_voxel(alpha, source, target, volume, origin, spacing, dims, maxidx, eps):
idxs = _get_index(alpha, source, target, origin, spacing, dims, maxidx, eps)
return torch.take(volume, idxs)
return torch.take(volume, idxs), idxs


def _get_index(alpha, source, target, origin, spacing, dims, maxidx, eps):
Expand Down Expand Up @@ -134,8 +153,20 @@ def dims(self, volume):
return torch.tensor(volume.shape).to(volume) - 1

def forward(
self, volume, origin, spacing, source, target, n_points=250, align_corners=True
self,
volume,
origin,
spacing,
source,
target,
n_points=250,
align_corners=True,
mask=None,
):
# Ensure not using mask_to_channels
if mask is not None:
raise ValueErro("mask_to_channels can only be True if renderer=='Siddon'")

# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1)
source = source[:, None, :, None, :] - origin
Expand All @@ -153,7 +184,7 @@ def forward(
# Render the DRR
batch_size = len(rays)
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)
drr = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
drr = drr[:, 0, 0].sum(dim=-1)
drr *= raylength / n_points
return drr
img = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
img = img[:, 0, 0].sum(dim=-1)
img *= raylength / n_points
return img
39 changes: 21 additions & 18 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
"#| export\n",
"from torchio import Subject\n",
"\n",
"\n",
"class DRR(nn.Module):\n",
" \"\"\"PyTorch module that computes differentiable digitally reconstructed radiographs.\"\"\"\n",
"\n",
Expand Down Expand Up @@ -164,7 +165,7 @@
" \"origin\", torch.tensor(subject.volume.origin, dtype=torch.float32)\n",
" )\n",
" if subject.mask is not None:\n",
" self.register_buffer(\"mask\", subject.mask.data.squeeze())\n",
" self.register_buffer(\"mask\", subject.mask.data[0].to(torch.int64))\n",
"\n",
" # Initialize the renderer\n",
" if renderer == \"siddon\":\n",
Expand All @@ -181,7 +182,7 @@
" def reshape_transform(self, img, batch_size):\n",
" if self.reshape:\n",
" if self.detector.n_subsample is None:\n",
" img = img.view(-1, 1, self.detector.height, self.detector.width)\n",
" img = img.view(batch_size, -1, self.detector.height, self.detector.width)\n",
" else:\n",
" img = reshape_subsampled_drr(img, self.detector, batch_size)\n",
" return img"
Expand Down Expand Up @@ -228,7 +229,7 @@
" *args, # Some batched representation of SE(3)\n",
" parameterization: str = None, # Specifies the representation of the rotation\n",
" convention: str = None, # If parameterization is Euler angles, specify convention\n",
" labels: list = None, # Labels from the mask of structures to render\n",
" mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels\n",
" **kwargs, # Passed to the renderer\n",
"):\n",
" \"\"\"Generate DRR with rotational and translational parameters.\"\"\"\n",
Expand All @@ -239,30 +240,32 @@
" pose = convert(*args, parameterization=parameterization, convention=convention)\n",
" source, target = self.detector(pose)\n",
"\n",
" # Apply mask\n",
" if labels is not None:\n",
" if isinstance(labels, int):\n",
" labels = [labels]\n",
" mask = torch.any(torch.stack([self.mask == idx for idx in labels]), dim=0)\n",
" density = self.density * mask\n",
" else:\n",
" density = self.density\n",
"\n",
" # Render the DRR\n",
" if self.patch_size is not None:\n",
" kwargs[\"mask\"] = self.mask if mask_to_channels else None\n",
" if self.patch_size is None:\n",
" img = self.renderer(\n",
" self.density,\n",
" self.origin,\n",
" self.spacing,\n",
" source,\n",
" target,\n",
" **kwargs,\n",
" )\n",
" else:\n",
" n_points = target.shape[1] // self.n_patches\n",
" img = []\n",
" for idx in range(self.n_patches):\n",
" t = target[:, idx * n_points : (idx + 1) * n_points]\n",
" partial = self.renderer(\n",
" density, self.origin, self.spacing, source, t, **kwargs\n",
" self.density,\n",
" self.origin,\n",
" self.spacing,\n",
" source,\n",
" t,\n",
" **kwargs,\n",
" )\n",
" img.append(partial)\n",
" img = torch.cat(img, dim=1)\n",
" else:\n",
" img = self.renderer(\n",
" density, self.origin, self.spacing, source, target, **kwargs\n",
" )\n",
" return self.reshape_transform(img, batch_size=len(pose))"
]
},
Expand Down
Loading
Loading