Skip to content

Commit

Permalink
Make no_grad in Siddon switchable (#275)
Browse files Browse the repository at this point in the history
* Remove no_grad call

* Make  switchable

* Add comments on

* Bump version
  • Loading branch information
eigenvivek authored Jun 15, 2024
1 parent e04e1d8 commit 95fa604
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 84 deletions.
2 changes: 1 addition & 1 deletion diffdrr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.1"
__version__ = "0.4.2"
7 changes: 6 additions & 1 deletion diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ class Siddon(torch.nn.Module):
def __init__(
self,
mode="nearest",
stop_gradients_through_grid_sample=False,
eps=1e-8,
):
super().__init__()
self.mode = mode
self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample
self.eps = eps

def dims(self, volume):
Expand Down Expand Up @@ -49,7 +51,10 @@ def forward(
xyzs = _get_xyzs(alphamid, source, target, origin, spacing, dims, self.eps)

# Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel
with torch.no_grad():
if self.stop_gradients_through_grid_sample:
with torch.no_grad():
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
else:
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)

# Weight each intersected voxel by the length of the ray's intersection with the voxel
Expand Down
9 changes: 7 additions & 2 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@
" def __init__(\n",
" self,\n",
" mode=\"nearest\",\n",
" stop_gradients_through_grid_sample=False,\n",
" eps=1e-8,\n",
" ):\n",
" super().__init__()\n",
" self.mode = mode\n",
" self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample\n",
" self.eps = eps\n",
"\n",
" def dims(self, volume):\n",
Expand All @@ -145,11 +147,14 @@
" # These midpoints lie exclusively in a single voxel\n",
" alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2\n",
"\n",
" # Get the XYZ coordinate of each midpoint (normalized to [-1, +1]^3) \n",
" # Get the XYZ coordinate of each midpoint (normalized to [-1, +1]^3)\n",
" xyzs = _get_xyzs(alphamid, source, target, origin, spacing, dims, self.eps)\n",
"\n",
" # Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel\n",
" with torch.no_grad():\n",
" if self.stop_gradients_through_grid_sample:\n",
" with torch.no_grad():\n",
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
" else:\n",
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
"\n",
" # Weight each intersected voxel by the length of the ray's intersection with the voxel\n",
Expand Down
107 changes: 51 additions & 56 deletions notebooks/tutorials/optimizers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@
"\n",
"from diffdrr.data import load_example_ct\n",
"from diffdrr.drr import DRR\n",
"from diffdrr.metrics import (\n",
" MultiscaleNormalizedCrossCorrelation2d,\n",
" NormalizedCrossCorrelation2d,\n",
")\n",
"from diffdrr.metrics import NormalizedCrossCorrelation2d\n",
"from diffdrr.pose import convert\n",
"from diffdrr.registration import Registration\n",
"from diffdrr.visualization import plot_drr\n",
"\n",
Expand All @@ -87,30 +85,10 @@
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)\n",
"rotations = torch.tensor(\n",
" [\n",
" [\n",
" true_params[\"alpha\"],\n",
" true_params[\"beta\"],\n",
" true_params[\"gamma\"],\n",
" ]\n",
" ]\n",
").to(device)\n",
"translations = torch.tensor(\n",
" [\n",
" [\n",
" true_params[\"bx\"],\n",
" true_params[\"by\"],\n",
" true_params[\"bz\"],\n",
" ]\n",
" ]\n",
").to(device)\n",
"ground_truth = drr(\n",
" rotations,\n",
" translations,\n",
" parameterization=\"euler_angles\",\n",
" convention=\"ZXY\",\n",
")\n",
"rotations = torch.tensor([[true_params[\"alpha\"], true_params[\"beta\"], true_params[\"gamma\"]]])\n",
"translations = torch.tensor([[true_params[\"bx\"], true_params[\"by\"], true_params[\"bz\"]]])\n",
"gt_pose = convert(rotations, translations, parameterization=\"euler_angles\", convention=\"ZXY\").to(device)\n",
"ground_truth = drr(gt_pose)\n",
"\n",
"plot_drr(ground_truth)\n",
"plt.show()"
Expand All @@ -124,24 +102,6 @@
"## 2. Initialize a moving DRR from a random pose"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "596ea362-842d-4b19-8cbd-9b052e7f59f6",
"metadata": {},
"outputs": [],
"source": [
"from diffdrr.pose import convert\n",
"\n",
"\n",
"def pose_from_carm(sid, tx, ty, alpha, beta, gamma):\n",
" rot = torch.tensor([[alpha, beta, gamma]])\n",
" xyz = torch.tensor([[tx, sid, ty]])\n",
" return convert(rot, xyz, parameterization=\"euler_angles\", convention=\"ZXY\")\n",
"\n",
"gt_pose = convert(rotations, translations, parameterization=\"euler_angles\", convention=\"ZXY\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -176,6 +136,12 @@
"np.random.seed(5)\n",
"\n",
"\n",
"def pose_from_carm(sid, tx, ty, alpha, beta, gamma):\n",
" rot = torch.tensor([[alpha, beta, gamma]])\n",
" xyz = torch.tensor([[tx, sid, ty]])\n",
" return convert(rot, xyz, parameterization=\"euler_angles\", convention=\"ZXY\")\n",
"\n",
"\n",
"def get_initial_parameters(true_params):\n",
" alpha = true_params[\"alpha\"] + np.random.uniform(-np.pi / 4, np.pi / 4)\n",
" beta = true_params[\"beta\"] + np.random.uniform(-np.pi / 4, np.pi / 4)\n",
Expand Down Expand Up @@ -302,7 +268,7 @@
" optimizer.zero_grad()\n",
" estimate = reg()\n",
" loss = criterion(ground_truth, estimate)\n",
" loss.backward(retain_graph=True)\n",
" loss.backward()\n",
" optimizer.step()\n",
" losses.append(loss.item())\n",
" pbar.set_description(f\"{loss.item():06f}\")\n",
Expand All @@ -322,7 +288,36 @@
"id": "dbfde8a4-d5e1-46c6-8c0e-88c506b9de2c",
"metadata": {},
"source": [
"## Run the optimization algorithm"
"## 5. Run the optimization algorithm\n",
"\n",
"Compare different iterative optimization stratigies with gradient descent. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "000781a2-1280-4894-a93e-97bebefaa475",
"metadata": {},
"outputs": [],
"source": [
"# Keyword arguments for diffdrr.drr.DRR\n",
"kwargs = {\n",
" \"subject\": subject,\n",
" \"sdd\": SDD,\n",
" \"height\": HEIGHT,\n",
" \"delx\": DELX,\n",
" \"stop_gradients_through_grid_sample\": True,\n",
"}"
]
},
{
"cell_type": "raw",
"id": "5e6cf198-319f-4318-8269-ce3533379981",
"metadata": {},
"source": [
"::: {.callout-tip}\n",
"For 2D/3D registration with Siddon's method, we don't need gradients calculated through the `grid_sample` (which uses nearest neighbors). To avoid computing these gradients, which improves rendering speed, you can set `stop_gradients_through_grid_sample=True`.\n",
":::"
]
},
{
Expand All @@ -335,7 +330,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"0.999050: 76%|██████████████████████████████████████▉ | 191/250 [00:03<00:01, 48.92it/s]\n"
"0.999050: 76%|██████████████████████████████████████▉ | 191/250 [00:03<00:01, 50.76it/s]\n"
]
},
{
Expand All @@ -350,7 +345,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"0.999136: 28%|██████████████▊ | 71/250 [00:01<00:03, 59.33it/s]\n"
"0.999136: 28%|██████████████▊ | 71/250 [00:01<00:03, 57.59it/s]\n"
]
},
{
Expand All @@ -365,7 +360,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"0.999118: 31%|████████████████ | 77/250 [00:01<00:02, 59.62it/s]\n"
"0.999118: 31%|████████████████ | 77/250 [00:01<00:02, 65.47it/s]\n"
]
},
{
Expand All @@ -380,7 +375,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"0.999281: 19%|█████████▉ | 48/250 [00:00<00:03, 56.47it/s]"
"0.999281: 19%|█████████▉ | 48/250 [00:00<00:03, 55.75it/s]"
]
},
{
Expand All @@ -401,7 +396,7 @@
],
"source": [
"# Base SGD\n",
"drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)\n",
"drr = DRR(**kwargs).to(device)\n",
"reg = Registration(\n",
" drr,\n",
" rotations.clone(),\n",
Expand All @@ -414,7 +409,7 @@
"del drr\n",
"\n",
"# SGD + momentum\n",
"drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)\n",
"drr = DRR(**kwargs).to(device)\n",
"reg = Registration(\n",
" drr,\n",
" rotations.clone(),\n",
Expand All @@ -427,7 +422,7 @@
"del drr\n",
"\n",
"# SGD + momentum + dampening\n",
"drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)\n",
"drr = DRR(**kwargs).to(device)\n",
"reg = Registration(\n",
" drr,\n",
" rotations.clone(),\n",
Expand All @@ -440,7 +435,7 @@
"del drr\n",
"\n",
"# Adam\n",
"drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)\n",
"drr = DRR(**kwargs).to(device)\n",
"reg = Registration(\n",
" drr,\n",
" rotations.clone(),\n",
Expand Down
44 changes: 21 additions & 23 deletions notebooks/tutorials/reconstruction.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[DEFAULT]
repo = DiffDRR
lib_name = diffdrr
version = 0.4.1
version = 0.4.2
min_python = 3.7
license = mit
black_formatting = True
Expand Down

0 comments on commit 95fa604

Please sign in to comment.