Skip to content

Commit

Permalink
Merge pull request #230 from eigenvivek/trilinear-mask
Browse files Browse the repository at this point in the history
Make  compatible with trilinear renderer
  • Loading branch information
eigenvivek authored Apr 28, 2024
2 parents a787a0b + f5b8f53 commit fdb6392
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 29 deletions.
37 changes: 29 additions & 8 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def forward(self, volume, origin, spacing, source, target, mask=None):
step_length = torch.diff(alphas, dim=-1)
weighted_voxels = voxels * step_length

# Handle optional mapping
# Handle optional masking
if mask is None:
img = torch.nansum(weighted_voxels, dim=-1)
img = img.unsqueeze(1)
Expand Down Expand Up @@ -163,10 +163,6 @@ def forward(
align_corners=True,
mask=None,
):
# Ensure not using mask_to_channels
if mask is not None:
raise ValueErro("mask_to_channels can only be True if renderer=='Siddon'")

# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1)
source = source[:, None, :, None, :] - origin
Expand All @@ -180,11 +176,36 @@ def forward(

# Reorder array to match torch conventions
volume = volume.permute(2, 1, 0)
if mask is not None:
mask = mask.permute(2, 1, 0)

# Render the DRR
batch_size = len(rays)
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)
img = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
img = img[:, 0, 0].sum(dim=-1)
img = grid_sample(
volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1),
rays,
mode=self.mode,
align_corners=align_corners,
)[:, 0, 0]

# Handle optional masking
if mask is None:
img = img.sum(dim=-1).unsqueeze(1)
else:
B, D, N = img.shape
C = mask.max().item() + 1
channels = grid_sample(
mask[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1).float(),
rays,
mode="nearest",
align_corners=align_corners,
).long()[:, 0, 0]
img = (
torch.zeros(B, C, D)
.to(volume)
.scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))
)

# Multiply by raylength
img *= raylength / n_points
return img
55 changes: 44 additions & 11 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
" step_length = torch.diff(alphas, dim=-1)\n",
" weighted_voxels = voxels * step_length\n",
"\n",
" # Handle optional mapping\n",
" # Handle optional masking\n",
" if mask is None:\n",
" img = torch.nansum(weighted_voxels, dim=-1)\n",
" img = img.unsqueeze(1)\n",
Expand All @@ -149,8 +149,12 @@
" weighted_voxels = weighted_voxels.nan_to_num()\n",
" B, D, N = weighted_voxels.shape\n",
" C = mask.max().item() + 1\n",
" img = torch.zeros(B, C, D).to(volume).scatter_add_(\n",
" 1, channels.transpose(-1, -2), weighted_voxels.transpose(-1, -2)\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
" .to(volume)\n",
" .scatter_add_(\n",
" 1, channels.transpose(-1, -2), weighted_voxels.transpose(-1, -2)\n",
" )\n",
" )\n",
"\n",
" # Finish rendering the DRR\n",
Expand Down Expand Up @@ -280,12 +284,16 @@
" return torch.tensor(volume.shape).to(volume) - 1\n",
"\n",
" def forward(\n",
" self, volume, origin, spacing, source, target, n_points=500, align_corners=True, mask=None,\n",
" self,\n",
" volume,\n",
" origin,\n",
" spacing,\n",
" source,\n",
" target,\n",
" n_points=500,\n",
" align_corners=True,\n",
" mask=None,\n",
" ):\n",
" # Ensure not using mask_to_channels\n",
" if mask is not None:\n",
" raise ValueErro(\"mask_to_channels can only be True if renderer=='Siddon'\")\n",
" \n",
" # Get the raylength and reshape sources\n",
" raylength = (source - target + self.eps).norm(dim=-1)\n",
" source = source[:, None, :, None, :] - origin\n",
Expand All @@ -299,12 +307,37 @@
"\n",
" # Reorder array to match torch conventions\n",
" volume = volume.permute(2, 1, 0)\n",
" if mask is not None:\n",
" mask = mask.permute(2, 1, 0)\n",
"\n",
" # Render the DRR\n",
" batch_size = len(rays)\n",
" vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)\n",
" img = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)\n",
" img = img[:, 0, 0].sum(dim=-1)\n",
" img = grid_sample(\n",
" volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1),\n",
" rays,\n",
" mode=self.mode,\n",
" align_corners=align_corners,\n",
" )[:, 0, 0]\n",
"\n",
" # Handle optional masking\n",
" if mask is None:\n",
" img = img.sum(dim=-1).unsqueeze(1)\n",
" else:\n",
" B, D, N = img.shape\n",
" C = mask.max().item() + 1\n",
" channels = grid_sample(\n",
" mask[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1).float(),\n",
" rays,\n",
" mode=\"nearest\",\n",
" align_corners=align_corners,\n",
" ).long()[:, 0, 0]\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
" .to(volume)\n",
" .scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))\n",
" )\n",
"\n",
" # Multiply by raylength\n",
" img *= raylength / n_points\n",
" return img"
]
Expand Down
10 changes: 0 additions & 10 deletions notebooks/tutorials/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,6 @@
"The first way to do this is to set `mask_to_channels=True` in `DRR.forward`, which will create a new channel for every structure. "
]
},
{
"cell_type": "raw",
"id": "b9e0c8dd-9e79-4554-a1f5-ccaaa56548e4",
"metadata": {},
"source": [
"::: {.callout-tip}\n",
"Note `mask_to_channels` is only an option for the `Siddon` renderer (which is the default option).\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit fdb6392

Please sign in to comment.