-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement separate regression submodule
- Loading branch information
1 parent
acd0c67
commit 180645d
Showing
9 changed files
with
421 additions
and
370 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/08_registration.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['Registration', 'PoseRegressor'] | ||
|
||
# %% ../notebooks/api/08_registration.ipynb 4 | ||
import torch | ||
import torch.nn as nn | ||
|
||
from .drr import DRR | ||
|
||
|
||
class Registration(nn.Module): | ||
"""Perform automatic 2D-to-3D registration using differentiable rendering.""" | ||
|
||
def __init__( | ||
self, | ||
drr: DRR, # Preinitialized DRR module | ||
rotation: torch.Tensor, # Initial guess for rotations | ||
translation: torch.Tensor, # Initial guess for translations | ||
parameterization: str, # Specifies the representation of the rotation | ||
convention: str = None, # If `parameterization` is `euler_angles`, specify convention | ||
): | ||
super().__init__() | ||
self.drr = drr | ||
self._rotation = nn.Parameter(rotation) | ||
self._translation = nn.Parameter(translation) | ||
self.parameterization = parameterization | ||
self.convention = convention | ||
|
||
def forward(self, **kwargs): | ||
return self.drr(self.pose, **kwargs) | ||
|
||
@property | ||
def pose(self): | ||
R = convert( | ||
self._rotation, | ||
torch.tensor([[0.0, 0.0, 0.0]]).to(self._rotation), | ||
parameterization=self.parameterization, | ||
convention=self.convention, | ||
) | ||
t = convert( | ||
torch.tensor([[0.0, 0.0, 0.0]]).to(self._translation), | ||
self._translation, | ||
parameterization=self.parameterization, | ||
convention=self.convention, | ||
) | ||
return t.compose(R) | ||
|
||
@property | ||
def rotation(self): | ||
return ( | ||
self.pose.convert(self.parameterization, self.convention)[0] | ||
.clone() | ||
.detach() | ||
.cpu() | ||
) | ||
|
||
@property | ||
def translation(self): | ||
return ( | ||
self.pose.convert(self.parameterization, self.convention)[1] | ||
.clone() | ||
.detach() | ||
.cpu() | ||
) | ||
|
||
# %% ../notebooks/api/08_registration.ipynb 6 | ||
import timm | ||
|
||
from .pose import RigidTransform, convert | ||
|
||
|
||
class PoseRegressor(torch.nn.Module): | ||
""" | ||
A PoseRegressor is comprised of a pretrained backbone model that extracts features | ||
from an input X-ray and two linear layers that decode these features into rotational | ||
and translational camera pose parameters, respectively. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name, | ||
parameterization, | ||
convention=None, | ||
pretrained=False, | ||
**kwargs, | ||
): | ||
super().__init__() | ||
|
||
self.parameterization = parameterization | ||
self.convention = convention | ||
n_angular_components = N_ANGULAR_COMPONENTS[parameterization] | ||
|
||
# Get the size of the output from the backbone | ||
self.backbone = timm.create_model( | ||
model_name, | ||
pretrained, | ||
num_classes=0, | ||
in_chans=1, | ||
**kwargs, | ||
) | ||
output = self.backbone(torch.randn(1, 1, 256, 256)).shape[-1] | ||
self.xyz_regression = torch.nn.Linear(output, 3) | ||
self.rot_regression = torch.nn.Linear(output, n_angular_components) | ||
|
||
def forward(self, x): | ||
x = self.backbone(x) | ||
rot = self.rot_regression(x) | ||
xyz = self.xyz_regression(x) | ||
return convert( | ||
rot, | ||
xyz, | ||
parameterization=self.parameterization, | ||
convention=self.convention, | ||
) | ||
|
||
# %% ../notebooks/api/08_registration.ipynb 7 | ||
N_ANGULAR_COMPONENTS = { | ||
"axis_angle": 3, | ||
"euler_angles": 3, | ||
"se3_log_map": 3, | ||
"quaternion": 4, | ||
"rotation_6d": 6, | ||
"rotation_10d": 10, | ||
"quaternion_adjugate": 10, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.