Skip to content

Commit

Permalink
Update SE(3) backend
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Feb 25, 2024
1 parent 4329ce2 commit 0833ac3
Show file tree
Hide file tree
Showing 19 changed files with 928 additions and 696 deletions.
24 changes: 3 additions & 21 deletions diffpose/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,7 @@
'doc_host': 'https://vivekg.dev',
'git_url': 'https://github.com/eigenvivek/DiffPose',
'lib_path': 'diffpose'},
'syms': { 'diffpose.calibration': { 'diffpose.calibration.RigidTransform': ( 'api/calibration.html#rigidtransform',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.__init__': ( 'api/calibration.html#rigidtransform.__init__',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.clone': ( 'api/calibration.html#rigidtransform.clone',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.compose': ( 'api/calibration.html#rigidtransform.compose',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.get_rotation': ( 'api/calibration.html#rigidtransform.get_rotation',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.get_se3_log': ( 'api/calibration.html#rigidtransform.get_se3_log',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.get_translation': ( 'api/calibration.html#rigidtransform.get_translation',
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.inverse': ( 'api/calibration.html#rigidtransform.inverse',
'diffpose/calibration.py'),
'diffpose.calibration.convert': ('api/calibration.html#convert', 'diffpose/calibration.py'),
'diffpose.calibration.perspective_projection': ( 'api/calibration.html#perspective_projection',
'syms': { 'diffpose.calibration': { 'diffpose.calibration.perspective_projection': ( 'api/calibration.html#perspective_projection',
'diffpose/calibration.py')},
'diffpose.deepfluoro': { 'diffpose.deepfluoro.DeepFluoroDataset': ( 'api/deepfluoro.html#deepfluorodataset',
'diffpose/deepfluoro.py'),
Expand Down Expand Up @@ -72,10 +55,9 @@
'diffpose/jacobians.py'),
'diffpose.jacobians.JacobianDRR.permute': ( 'api/jacobians.html#jacobiandrr.permute',
'diffpose/jacobians.py'),
'diffpose.jacobians.gradient_matching': ( 'api/jacobians.html#gradient_matching',
'diffpose/jacobians.py'),
'diffpose.jacobians.plot_img_jacobian': ( 'api/jacobians.html#plot_img_jacobian',
'diffpose/jacobians.py')},
'diffpose/jacobians.py'),
'diffpose.jacobians.preconditioner': ('api/jacobians.html#preconditioner', 'diffpose/jacobians.py')},
'diffpose.ljubljana': { 'diffpose.ljubljana.Evaluator': ('api/ljubljana.html#evaluator', 'diffpose/ljubljana.py'),
'diffpose.ljubljana.Evaluator.__call__': ( 'api/ljubljana.html#evaluator.__call__',
'diffpose/ljubljana.py'),
Expand Down
113 changes: 3 additions & 110 deletions diffpose/calibration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/02_calibration.ipynb.

# %% auto 0
__all__ = ['RigidTransform', 'convert', 'perspective_projection']
__all__ = ['perspective_projection']

# %% ../notebooks/api/02_calibration.ipynb 4
import torch
Expand All @@ -10,124 +10,17 @@
from typing import Optional

from beartype import beartype
from diffdrr.utils import Transform3d
from diffdrr.utils import convert as convert_so3
from diffdrr.utils import se3_exp_map, se3_log_map
from diffdrr.pose import RigidTransform
from jaxtyping import Float, jaxtyped

# %% ../notebooks/api/02_calibration.ipynb 7
@beartype
class RigidTransform(Transform3d):
"""Wrapper of pytorch3d.transforms.Transform3d with extra functionalities."""

@jaxtyped(typechecker=beartype)
def __init__(
self,
R: Float[torch.Tensor, "..."],
t: Float[torch.Tensor, "... 3"],
parameterization: str = "matrix",
convention: Optional[str] = None,
device=None,
dtype=torch.float32,
):
if device is None and (R.device == t.device):
device = R.device

R = convert_so3(R, parameterization, "matrix", convention)
if R.dim() == 2 and t.dim() == 1:
R = R.unsqueeze(0)
t = t.unsqueeze(0)
assert (batch_size := len(R)) == len(t), "R and t need same batch size"

matrix = torch.zeros(batch_size, 4, 4, device=device, dtype=dtype)
matrix[..., :3, :3] = R.transpose(-1, -2)
matrix[..., 3, :3] = t
matrix[..., 3, 3] = 1

super().__init__(matrix=matrix, device=device, dtype=dtype)

def get_rotation(self, parameterization=None, convention=None):
R = self.get_matrix()[..., :3, :3].transpose(-1, -2)
if parameterization is not None:
R = convert_so3(R, "matrix", parameterization, None, convention)
return R

def get_translation(self):
return self.get_matrix()[..., 3, :3]

def inverse(self):
"""Closed-form inverse for rigid transforms."""
R = self.get_rotation().transpose(-1, -2)
t = self.get_translation()
t = -torch.einsum("bij,bj->bi", R, t)
return RigidTransform(R, t, device=self.device, dtype=self.dtype)

def compose(self, other):
T = super().compose(other)
R = T.get_matrix()[..., :3, :3].transpose(-1, -2)
t = T.get_matrix()[..., 3, :3]
return RigidTransform(R, t, device=self.device, dtype=self.dtype)

def clone(self):
R = self.get_matrix()[..., :3, :3].transpose(-1, -2).clone()
t = self.get_matrix()[..., 3, :3].clone()
return RigidTransform(R, t, device=self.device, dtype=self.dtype)

def get_se3_log(self):
return se3_log_map(self.get_matrix())

# %% ../notebooks/api/02_calibration.ipynb 8
def convert(
transform,
input_parameterization,
output_parameterization,
input_convention=None,
output_convention=None,
):
"""Convert between representations of SE(3)."""

# Convert any input parameterization to a RigidTransform
if input_parameterization == "se3_log_map":
transform = torch.concat([transform[1], transform[0]], axis=-1)
matrix = se3_exp_map(transform).transpose(-1, -2)
transform = RigidTransform(
R=matrix[..., :3, :3],
t=matrix[..., :3, 3],
device=matrix.device,
dtype=matrix.dtype,
)
elif input_parameterization == "se3_exp_map":
pass
else:
transform = RigidTransform(
R=transform[0],
t=transform[1],
parameterization=input_parameterization,
convention=input_convention,
)

# Convert the RigidTransform to any output
if output_parameterization == "se3_exp_map":
return transform
elif output_parameterization == "se3_log_map":
se3_log = transform.get_se3_log()
log_t_vee = se3_log[..., :3]
log_R_vee = se3_log[..., 3:]
return log_R_vee, log_t_vee
else:
return (
transform.get_rotation(output_parameterization, output_convention),
transform.get_translation(),
)

# %% ../notebooks/api/02_calibration.ipynb 10
@jaxtyped(typechecker=beartype)
def perspective_projection(
extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)
intrinsic: Float[torch.Tensor, "3 3"], # Intrinsic camera matrix (camera to image)
x: Float[torch.Tensor, "b n 3"], # World coordinates
) -> Float[torch.Tensor, "b n 2"]:
x = extrinsic.transform_points(x)
x = extrinsic(x)
x = torch.einsum("ij, bnj -> bni", intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
x = x / z
Expand Down
71 changes: 51 additions & 20 deletions diffpose/deepfluoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import numpy as np
import torch
from beartype import beartype
from diffdrr.pose import RigidTransform, convert, make_matrix

from .calibration import RigidTransform, perspective_projection
from .calibration import perspective_projection

# %% ../notebooks/api/00_deepfluoro.ipynb 5
@beartype
Expand Down Expand Up @@ -49,25 +50,46 @@ def __init__(
isocenter_rot = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]])
isocenter_xyz = torch.tensor(self.volume.shape) * self.spacing / 2
isocenter_xyz = isocenter_xyz.unsqueeze(0)
self.isocenter_pose = RigidTransform(
isocenter_rot, isocenter_xyz, "euler_angles", "ZYX"
self.isocenter_pose = convert(
isocenter_rot,
isocenter_xyz,
parameterization="euler_angles",
convention="ZYX",
)

# Camera matrices and fiducials for the specimen
self.fiducials = get_3d_fiducials(self.specimen)

# Miscellaneous transformation matrices for wrangling SE(3) poses
self.flip_xz = RigidTransform(
torch.tensor([[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]),
torch.zeros(3),
torch.tensor(
[
[0.0, 0.0, -1.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
)
self.translate = RigidTransform(
torch.eye(3),
torch.tensor([-self.focal_len / 2, 0.0, 0.0]),
torch.tensor(
[
[1.0, 0.0, 0.0, -self.focal_len / 2],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
)
self.flip_180 = RigidTransform(
torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]),
torch.zeros(3),
torch.tensor(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, -1.0, 0.0, 0.0],
[0.0, 0.0, -1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
)

def __len__(self):
Expand All @@ -87,7 +109,7 @@ def __getitem__(self, idx):
projection = self.projections[f"{idx:03d}"]
img = torch.from_numpy(projection["image/pixels"][:])
world2volume = torch.from_numpy(projection["gt-poses/cam-to-pelvis-vol"][:])
world2volume = RigidTransform(world2volume[:3, :3], world2volume[:3, 3])
world2volume = RigidTransform(world2volume)
pose = convert_deepfluoro_to_diffdrr(self, world2volume)

# Handle rotations in the imaging dataset
Expand Down Expand Up @@ -186,7 +208,7 @@ def project(self, pose):
extrinsic = (
self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)
)
return extrinsic.transform_points(x)
return extrinsic(x)

def __call__(self, pose):
pred_projected_fiducials = self.project(pose)
Expand Down Expand Up @@ -255,20 +277,29 @@ def load_deepfluoro_dataset(id_number, filename):

def parse_volume(specimen):
# Parse the volume
spacing = specimen["vol/spacing"][:].flatten()
spacing = specimen["vol/spacing"][:].flatten().astype(np.float32)
volume = specimen["vol/pixels"][:].astype(np.float32)
volume = np.swapaxes(volume, 0, 2)[::-1].copy()

# Parse the translation matrix from LPS coordinates to volume coordinates
origin = torch.from_numpy(specimen["vol/origin"][:].flatten())
lps2volume = RigidTransform(torch.eye(3), origin)
# Parse the translation matrix from anatomical coordinates to world coordinates
origin = torch.from_numpy(specimen["vol/origin"][:].flatten()).to(torch.float32)
lps2volume = RigidTransform(
torch.tensor(
[
[1.0, 0.0, 0.0, origin[0]],
[0.0, 1.0, 0.0, origin[1]],
[0.0, 0.0, 1.0, origin[2]],
[0.0, 0.0, 0.0, 1.0],
]
)
)
return volume, spacing, lps2volume


def parse_proj_params(f):
proj_params = f["proj-params"]
extrinsic = torch.from_numpy(proj_params["extrinsic"][:])
extrinsic = RigidTransform(extrinsic[..., :3, :3], extrinsic[:3, 3])
extrinsic = RigidTransform(extrinsic)
intrinsic = torch.from_numpy(proj_params["intrinsic"][:])
num_cols = float(proj_params["num-cols"][()])
num_rows = float(proj_params["num-rows"][()])
Expand Down Expand Up @@ -306,7 +337,7 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
return img

# %% ../notebooks/api/00_deepfluoro.ipynb 26
from .calibration import RigidTransform, convert
from diffdrr.pose import RigidTransform, convert


@beartype
Expand All @@ -320,9 +351,9 @@ def get_random_offset(batch_size: int, device) -> RigidTransform:
log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)
log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)
return convert(
[log_R_vee, log_t_vee],
"se3_log_map",
"se3_exp_map",
log_R_vee,
log_t_vee,
parameterization="se3_log_map",
)

# %% ../notebooks/api/00_deepfluoro.ipynb 32
Expand Down
42 changes: 33 additions & 9 deletions diffpose/jacobians.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/06_jacobians.ipynb.

# %% auto 0
__all__ = ['JacobianDRR', 'gradient_matching', 'plot_img_jacobian']
__all__ = ['JacobianDRR', 'plot_img_jacobian', 'preconditioner']

# %% ../notebooks/api/06_jacobians.ipynb 3
import torch
from diffdrr.pose import convert


class JacobianDRR(torch.nn.Module):
"""Computes the Jacobian of a DRR wrt pose parameters."""

def __init__(self, drr, rotation, translation, parameterization, convention=None):
def __init__(
self,
drr,
rotation,
translation,
parameterization,
convention=None,
):
super().__init__()
self.drr = drr
self.rotation = torch.nn.Parameter(rotation.clone())
Expand All @@ -30,18 +38,18 @@ def forward(self):
return I, J

def cast(self, rotation, translation):
return self.drr(rotation, translation, self.parameterization, self.convention)
pose = convert(
rotation,
translation,
parameterization=self.parameterization,
convention=self.convention,
)
return self.drr(pose)

def permute(self, x):
return x.permute(-1, 0, 2, 3, 1, 4)[..., 0, 0]

# %% ../notebooks/api/06_jacobians.ipynb 4
def gradient_matching(J0, J1):
J0 /= J0.norm(dim=[-1, -2], keepdim=True)
J1 /= J1.norm(dim=[-1, -2], keepdim=True)
return (J0 - J1).norm()

# %% ../notebooks/api/06_jacobians.ipynb 5
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

Expand Down Expand Up @@ -100,3 +108,19 @@ def fmt(x, pos):
plt.axis("off")
plt.colorbar()
plt.show()

# %% ../notebooks/api/06_jacobians.ipynb 5
def preconditioner(jacobian, p=-0.5, λ=1e-1, μ=1e-8):
"""Calculate the inverse preconditioning matrix."""
# Calculate the covariance matrix
vecj = jacobian.flatten(start_dim=1)
cov = vecj @ vecj.T

# Apply a small dampening parameter
eyelike = lambda A: torch.eye(A.shape[0]).to(A)
cov += λ * torch.diag(torch.diag(cov)) + μ * eyelike(cov)

# Calculate the preconditioning matrix
eigenvalues, V = torch.linalg.eigh(cov)
Dexp = torch.diag(eigenvalues.pow(p))
return V @ Dexp @ V.T
Loading

0 comments on commit 0833ac3

Please sign in to comment.