Skip to content

Commit

Permalink
Merge pull request #195 from eigenvivek/camera-matrix
Browse files Browse the repository at this point in the history
Add perspective projection and inverse
  • Loading branch information
eigenvivek authored Feb 22, 2024
2 parents 3caa6eb + 8246d9d commit 72b348b
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 11 deletions.
5 changes: 5 additions & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
'diffdrr.detector.Detector.__init__': ('api/detector.html#detector.__init__', 'diffdrr/detector.py'),
'diffdrr.detector.Detector._initialize_carm': ( 'api/detector.html#detector._initialize_carm',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.flip_xz': ('api/detector.html#detector.flip_xz', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.forward': ('api/detector.html#detector.forward', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.translate': ('api/detector.html#detector.translate', 'diffdrr/detector.py'),
'diffdrr.detector.diffdrr_to_deepdrr': ('api/detector.html#diffdrr_to_deepdrr', 'diffdrr/detector.py')},
'diffdrr.drr': { 'diffdrr.drr.DRR': ('api/drr.html#drr', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.__init__': ('api/drr.html#drr.__init__', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.set_bone_attenuation_multiplier': ( 'api/drr.html#drr.set_bone_attenuation_multiplier',
'diffdrr/drr.py'),
Expand Down
44 changes: 44 additions & 0 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
__all__ = ['Detector', 'diffdrr_to_deepdrr']

# %% ../notebooks/api/02_detector.ipynb 5
from .pose import RigidTransform
from .utils import make_intrinsic_matrix


class Detector(torch.nn.Module):
"""Construct a 6 DoF X-ray detector system. This model is based on a C-Arm."""

Expand Down Expand Up @@ -44,6 +48,46 @@ def __init__(
self.register_buffer("source", source)
self.register_buffer("target", target)

# Anatomy to world coordinates
flip_xz = torch.tensor(
[
[0.0, 0.0, -1.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
translate = torch.tensor(
[
[1.0, 0.0, 0.0, -self.sdr],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
self.register_buffer("_flip_xz", flip_xz)
self.register_buffer("_translate", translate)

@property
def intrinsic(self):
return make_intrinsic_matrix(
self.sdr,
self.delx,
self.dely,
self.height,
self.width,
self.x0,
self.y0,
).to(self._flip_xz)

@property
def flip_xz(self):
return RigidTransform(self._flip_xz)

@property
def translate(self):
return RigidTransform(self._translate)

# %% ../notebooks/api/02_detector.ipynb 6
@patch
def _initialize_carm(self: Detector):
Expand Down
46 changes: 45 additions & 1 deletion diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def reshape_subsampled_drr(
return drr

# %% ../notebooks/api/00_drr.ipynb 10
# from diffdrr.se3 import RigidTransform, convert
from .pose import convert


Expand Down Expand Up @@ -170,7 +169,52 @@ def set_intrinsics(
reverse_x_axis=self.detector.reverse_x_axis,
).to(self.volume)

# %% ../notebooks/api/00_drr.ipynb 13
from .pose import RigidTransform


@patch
def perspective_projection(
self: DRR,
pose: RigidTransform,
pts: torch.Tensor,
):
extrinsic = (
pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)
)
x = extrinsic(pts)
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
x = x / z
return x[..., :2]

# %% ../notebooks/api/00_drr.ipynb 14
from torch.nn.functional import pad


@patch
def inverse_projection(
self: DRR,
pose: RigidTransform,
pts: torch.Tensor,
):
extrinsic = (
self.detector.flip_xz.inverse()
.compose(self.detector.translate.inverse())
.compose(pose)
)
x = (
-2
* self.detector.sdr
* torch.einsum(
"ij, bnj -> bni",
self.detector.intrinsic.inverse(),
pad(pts, (0, 1), value=1), # Convert to homogenous coordinates
)
)
return extrinsic(x)

# %% ../notebooks/api/00_drr.ipynb 16
class Registration(nn.Module):
"""Perform automatic 2D-to-3D registration using differentiable rendering."""

Expand Down
11 changes: 7 additions & 4 deletions diffdrr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,22 @@ def parse_intrinsic_matrix(
return focal_length, x0, y0

# %% ../notebooks/api/07_utils.ipynb 7
import torch


def make_intrinsic_matrix(
sdr: float, # Source-to-detector radius (in units length)
delx: float, # X-direction spacing (in units length / pixel)
dely: float, # Y-direction spacing (in units length / pixel)
height: int, # Y-direction length (in units pixels)
width: int, # X-direction length (in units pixels)
delx: float, # X-direction spacing (in units length)
dely: float, # Y-direction spacing (in units length)
x0: float = 0.0, # Principal point x-coordinate (in units length)
y0: float = 0.0, # Principal point y-coordinate (in units length)
):
return torch.tensor(
[
[-2 * sdr / delx, 0.0, height / 2 - x0 / delx],
[0.0, -2 * sdr / dely, width / 2 - y0 / dely],
[2 * sdr / delx, 0.0, x0 / delx - height / 2],
[0.0, 2 * sdr / dely, y0 / dely - width / 2],
[0.0, 0.0, 1.0],
]
)
62 changes: 61 additions & 1 deletion notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@
"outputs": [],
"source": [
"#| export\n",
"# from diffdrr.se3 import RigidTransform, convert\n",
"from diffdrr.pose import convert\n",
"\n",
"\n",
Expand Down Expand Up @@ -306,6 +305,67 @@
" ).to(self.volume)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93a94ef3-5449-45dc-aa62-9fcf6fad643d",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from diffdrr.pose import RigidTransform\n",
"\n",
"\n",
"@patch\n",
"def perspective_projection(\n",
" self: DRR,\n",
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" extrinsic = (\n",
" pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)\n",
" )\n",
" x = extrinsic(pts)\n",
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
" z = x[..., -1].unsqueeze(-1).clone()\n",
" x = x / z\n",
" return x[..., :2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "802ba874-eef8-4524-be5c-bd250e5639d7",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from torch.nn.functional import pad\n",
"\n",
"\n",
"@patch\n",
"def inverse_projection(\n",
" self: DRR,\n",
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" extrinsic = (\n",
" self.detector.flip_xz.inverse()\n",
" .compose(self.detector.translate.inverse())\n",
" .compose(pose)\n",
" )\n",
" x = (\n",
" -2\n",
" * self.detector.sdr\n",
" * torch.einsum(\n",
" \"ij, bnj -> bni\",\n",
" self.detector.intrinsic.inverse(),\n",
" pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n",
" )\n",
" )\n",
" return extrinsic(x)"
]
},
{
"cell_type": "markdown",
"id": "10bc1b2b-a444-45dc-8430-56646d54f95f",
Expand Down
46 changes: 45 additions & 1 deletion notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
"outputs": [],
"source": [
"#| export\n",
"from diffdrr.pose import RigidTransform\n",
"from diffdrr.utils import make_intrinsic_matrix\n",
"\n",
"\n",
"class Detector(torch.nn.Module):\n",
" \"\"\"Construct a 6 DoF X-ray detector system. This model is based on a C-Arm.\"\"\"\n",
"\n",
Expand Down Expand Up @@ -97,7 +101,47 @@
" # Initialize the source and detector plane in default positions (along the x-axis)\n",
" source, target = self._initialize_carm()\n",
" self.register_buffer(\"source\", source)\n",
" self.register_buffer(\"target\", target)"
" self.register_buffer(\"target\", target)\n",
"\n",
" # Anatomy to world coordinates\n",
" flip_xz = torch.tensor(\n",
" [\n",
" [0.0, 0.0, -1.0, 0.0],\n",
" [0.0, 1.0, 0.0, 0.0],\n",
" [1.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 1.0],\n",
" ]\n",
" )\n",
" translate = torch.tensor(\n",
" [\n",
" [1.0, 0.0, 0.0, -self.sdr],\n",
" [0.0, 1.0, 0.0, 0.0],\n",
" [0.0, 0.0, 1.0, 0.0],\n",
" [0.0, 0.0, 0.0, 1.0],\n",
" ]\n",
" )\n",
" self.register_buffer(\"_flip_xz\", flip_xz)\n",
" self.register_buffer(\"_translate\", translate)\n",
"\n",
" @property\n",
" def intrinsic(self):\n",
" return make_intrinsic_matrix(\n",
" self.sdr,\n",
" self.delx,\n",
" self.dely,\n",
" self.height,\n",
" self.width,\n",
" self.x0,\n",
" self.y0,\n",
" ).to(self._flip_xz)\n",
"\n",
" @property\n",
" def flip_xz(self):\n",
" return RigidTransform(self._flip_xz)\n",
"\n",
" @property\n",
" def translate(self):\n",
" return RigidTransform(self._translate)"
]
},
{
Expand Down
7 changes: 7 additions & 0 deletions notebooks/api/06_pose.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,13 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
18 changes: 14 additions & 4 deletions notebooks/api/07_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,22 @@
"outputs": [],
"source": [
"#| export\n",
"import torch\n",
"\n",
"\n",
"def make_intrinsic_matrix(\n",
" sdr: float, # Source-to-detector radius (in units length)\n",
" delx: float, # X-direction spacing (in units length / pixel)\n",
" dely: float, # Y-direction spacing (in units length / pixel)\n",
" height: int, # Y-direction length (in units pixels)\n",
" width: int, # X-direction length (in units pixels)\n",
" delx: float, # X-direction spacing (in units length)\n",
" dely: float, # Y-direction spacing (in units length)\n",
" x0: float = 0.0, # Principal point x-coordinate (in units length)\n",
" y0: float = 0.0, # Principal point y-coordinate (in units length)\n",
"):\n",
" return torch.tensor(\n",
" [\n",
" [-2 * sdr / delx, 0.0, height / 2 - x0 / delx],\n",
" [0.0, -2 * sdr / dely, width / 2 - y0 / dely],\n",
" [2 * sdr / delx, 0.0, x0 / delx - height / 2],\n",
" [0.0, 2 * sdr / dely, y0 / dely - width / 2],\n",
" [0.0, 0.0, 1.0],\n",
" ]\n",
" )"
Expand Down Expand Up @@ -156,6 +159,13 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down

0 comments on commit 72b348b

Please sign in to comment.