Skip to content

Commit

Permalink
Implement separate regression submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Apr 17, 2024
1 parent acd0c67 commit 180645d
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 370 deletions.
24 changes: 18 additions & 6 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.set_intrinsics': ('api/drr.html#drr.set_intrinsics', 'diffdrr/drr.py'),
'diffdrr.drr.Registration': ('api/drr.html#registration', 'diffdrr/drr.py'),
'diffdrr.drr.Registration.__init__': ('api/drr.html#registration.__init__', 'diffdrr/drr.py'),
'diffdrr.drr.Registration.forward': ('api/drr.html#registration.forward', 'diffdrr/drr.py'),
'diffdrr.drr.Registration.pose': ('api/drr.html#registration.pose', 'diffdrr/drr.py'),
'diffdrr.drr.Registration.rotation': ('api/drr.html#registration.rotation', 'diffdrr/drr.py'),
'diffdrr.drr.Registration.translation': ('api/drr.html#registration.translation', 'diffdrr/drr.py'),
'diffdrr.drr.reshape_subsampled_drr': ('api/drr.html#reshape_subsampled_drr', 'diffdrr/drr.py')},
'diffdrr.metrics': { 'diffdrr.metrics.GradientNormalizedCrossCorrelation2d': ( 'api/metrics.html#gradientnormalizedcrosscorrelation2d',
'diffdrr/metrics.py'),
Expand Down Expand Up @@ -108,6 +102,24 @@
'diffdrr.pose.so3_relative_angle': ('api/pose.html#so3_relative_angle', 'diffdrr/pose.py'),
'diffdrr.pose.so3_rotation_angle': ('api/pose.html#so3_rotation_angle', 'diffdrr/pose.py'),
'diffdrr.pose.standardize_quaternion': ('api/pose.html#standardize_quaternion', 'diffdrr/pose.py')},
'diffdrr.registration': { 'diffdrr.registration.PoseRegressor': ( 'api/registration.html#poseregressor',
'diffdrr/registration.py'),
'diffdrr.registration.PoseRegressor.__init__': ( 'api/registration.html#poseregressor.__init__',
'diffdrr/registration.py'),
'diffdrr.registration.PoseRegressor.forward': ( 'api/registration.html#poseregressor.forward',
'diffdrr/registration.py'),
'diffdrr.registration.Registration': ( 'api/registration.html#registration',
'diffdrr/registration.py'),
'diffdrr.registration.Registration.__init__': ( 'api/registration.html#registration.__init__',
'diffdrr/registration.py'),
'diffdrr.registration.Registration.forward': ( 'api/registration.html#registration.forward',
'diffdrr/registration.py'),
'diffdrr.registration.Registration.pose': ( 'api/registration.html#registration.pose',
'diffdrr/registration.py'),
'diffdrr.registration.Registration.rotation': ( 'api/registration.html#registration.rotation',
'diffdrr/registration.py'),
'diffdrr.registration.Registration.translation': ( 'api/registration.html#registration.translation',
'diffdrr/registration.py')},
'diffdrr.renderers': { 'diffdrr.renderers.Siddon': ('api/renderers.html#siddon', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.__init__': ('api/renderers.html#siddon.__init__', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.dims': ('api/renderers.html#siddon.dims', 'diffdrr/renderers.py'),
Expand Down
58 changes: 1 addition & 57 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .renderers import Siddon, Trilinear

# %% auto 0
__all__ = ['DRR', 'Registration']
__all__ = ['DRR']

# %% ../notebooks/api/00_drr.ipynb 7
from torchio import Subject
Expand Down Expand Up @@ -207,59 +207,3 @@ def inverse_projection(
pad(pts, (0, 1), value=1), # Convert to homogenous coordinates
)
return extrinsic(x)

# %% ../notebooks/api/00_drr.ipynb 15
class Registration(nn.Module):
"""Perform automatic 2D-to-3D registration using differentiable rendering."""

def __init__(
self,
drr: DRR,
rotation: torch.Tensor,
translation: torch.Tensor,
parameterization: str,
convention: str = None,
):
super().__init__()
self.drr = drr
self._rotation = nn.Parameter(rotation)
self._translation = nn.Parameter(translation)
self.parameterization = parameterization
self.convention = convention

def forward(self):
return self.drr(self.pose)

@property
def pose(self):
R = convert(
self._rotation,
torch.tensor([[0.0, 0.0, 0.0]]).to(self._rotation),
parameterization=self.parameterization,
convention=self.convention,
)
t = convert(
torch.tensor([[0.0, 0.0, 0.0]]).to(self._translation),
self._translation,
parameterization=self.parameterization,
convention=self.convention,
)
return t.compose(R)

@property
def rotation(self):
return (
self.pose.convert(self.parameterization, self.convention)[0]
.clone()
.detach()
.cpu()
)

@property
def translation(self):
return (
self.pose.convert(self.parameterization, self.convention)[1]
.clone()
.detach()
.cpu()
)
127 changes: 127 additions & 0 deletions diffdrr/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/08_registration.ipynb.

# %% auto 0
__all__ = ['Registration', 'PoseRegressor']

# %% ../notebooks/api/08_registration.ipynb 4
import torch
import torch.nn as nn

from .drr import DRR


class Registration(nn.Module):
"""Perform automatic 2D-to-3D registration using differentiable rendering."""

def __init__(
self,
drr: DRR, # Preinitialized DRR module
rotation: torch.Tensor, # Initial guess for rotations
translation: torch.Tensor, # Initial guess for translations
parameterization: str, # Specifies the representation of the rotation
convention: str = None, # If `parameterization` is `euler_angles`, specify convention
):
super().__init__()
self.drr = drr
self._rotation = nn.Parameter(rotation)
self._translation = nn.Parameter(translation)
self.parameterization = parameterization
self.convention = convention

def forward(self, **kwargs):
return self.drr(self.pose, **kwargs)

@property
def pose(self):
R = convert(
self._rotation,
torch.tensor([[0.0, 0.0, 0.0]]).to(self._rotation),
parameterization=self.parameterization,
convention=self.convention,
)
t = convert(
torch.tensor([[0.0, 0.0, 0.0]]).to(self._translation),
self._translation,
parameterization=self.parameterization,
convention=self.convention,
)
return t.compose(R)

@property
def rotation(self):
return (
self.pose.convert(self.parameterization, self.convention)[0]
.clone()
.detach()
.cpu()
)

@property
def translation(self):
return (
self.pose.convert(self.parameterization, self.convention)[1]
.clone()
.detach()
.cpu()
)

# %% ../notebooks/api/08_registration.ipynb 6
import timm

from .pose import RigidTransform, convert


class PoseRegressor(torch.nn.Module):
"""
A PoseRegressor is comprised of a pretrained backbone model that extracts features
from an input X-ray and two linear layers that decode these features into rotational
and translational camera pose parameters, respectively.
"""

def __init__(
self,
model_name,
parameterization,
convention=None,
pretrained=False,
**kwargs,
):
super().__init__()

self.parameterization = parameterization
self.convention = convention
n_angular_components = N_ANGULAR_COMPONENTS[parameterization]

# Get the size of the output from the backbone
self.backbone = timm.create_model(
model_name,
pretrained,
num_classes=0,
in_chans=1,
**kwargs,
)
output = self.backbone(torch.randn(1, 1, 256, 256)).shape[-1]
self.xyz_regression = torch.nn.Linear(output, 3)
self.rot_regression = torch.nn.Linear(output, n_angular_components)

def forward(self, x):
x = self.backbone(x)
rot = self.rot_regression(x)
xyz = self.xyz_regression(x)
return convert(
rot,
xyz,
parameterization=self.parameterization,
convention=self.convention,
)

# %% ../notebooks/api/08_registration.ipynb 7
N_ANGULAR_COMPONENTS = {
"axis_angle": 3,
"euler_angles": 3,
"se3_log_map": 3,
"quaternion": 4,
"rotation_6d": 6,
"rotation_10d": 10,
"quaternion_adjugate": 10,
}
5 changes: 3 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ dependencies:
- einops
- matplotlib
- seaborn
- tqdm
- imageio
- fastcore
- tqdm
- scipy
- pip:
- pyvista
- vtk
- timm
- torchio
- vtk
74 changes: 0 additions & 74 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -350,80 +350,6 @@
" return extrinsic(x)"
]
},
{
"cell_type": "markdown",
"id": "10bc1b2b-a444-45dc-8430-56646d54f95f",
"metadata": {},
"source": [
"## Registration\n",
"\n",
"The `Registration` module uses the `DRR` module to perform differentiable 2D-to-3D registration. Initial guesses for the pose parameters are as stored as `nn.Parameters` of the module. This allows the pose parameters to be optimized with any PyTorch optimizer. Furthermore, this design choice allows `DRR` to be used purely as a differentiable renderer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "868fd7b9-e83d-43dd-89e2-3d49e0434da9",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class Registration(nn.Module):\n",
" \"\"\"Perform automatic 2D-to-3D registration using differentiable rendering.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" drr: DRR,\n",
" rotation: torch.Tensor,\n",
" translation: torch.Tensor,\n",
" parameterization: str,\n",
" convention: str = None,\n",
" ):\n",
" super().__init__()\n",
" self.drr = drr\n",
" self._rotation = nn.Parameter(rotation)\n",
" self._translation = nn.Parameter(translation)\n",
" self.parameterization = parameterization\n",
" self.convention = convention\n",
"\n",
" def forward(self):\n",
" return self.drr(self.pose)\n",
"\n",
" @property\n",
" def pose(self):\n",
" R = convert(\n",
" self._rotation,\n",
" torch.tensor([[0.0, 0.0, 0.0]]).to(self._rotation),\n",
" parameterization=self.parameterization,\n",
" convention=self.convention,\n",
" )\n",
" t = convert(\n",
" torch.tensor([[0.0, 0.0, 0.0]]).to(self._translation),\n",
" self._translation,\n",
" parameterization=self.parameterization,\n",
" convention=self.convention,\n",
" )\n",
" return t.compose(R)\n",
"\n",
" @property\n",
" def rotation(self):\n",
" return (\n",
" self.pose.convert(self.parameterization, self.convention)[0]\n",
" .clone()\n",
" .detach()\n",
" .cpu()\n",
" )\n",
"\n",
" @property\n",
" def translation(self):\n",
" return (\n",
" self.pose.convert(self.parameterization, self.convention)[1]\n",
" .clone()\n",
" .detach()\n",
" .cpu()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading

0 comments on commit 180645d

Please sign in to comment.