Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add perspective projection and inverse #195

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
'diffdrr.detector.Detector.__init__': ('api/detector.html#detector.__init__', 'diffdrr/detector.py'),
'diffdrr.detector.Detector._initialize_carm': ( 'api/detector.html#detector._initialize_carm',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.flip_xz': ('api/detector.html#detector.flip_xz', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.forward': ('api/detector.html#detector.forward', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.translate': ('api/detector.html#detector.translate', 'diffdrr/detector.py'),
'diffdrr.detector.diffdrr_to_deepdrr': ('api/detector.html#diffdrr_to_deepdrr', 'diffdrr/detector.py')},
'diffdrr.drr': { 'diffdrr.drr.DRR': ('api/drr.html#drr', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.__init__': ('api/drr.html#drr.__init__', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.set_bone_attenuation_multiplier': ( 'api/drr.html#drr.set_bone_attenuation_multiplier',
'diffdrr/drr.py'),
Expand Down
44 changes: 44 additions & 0 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
__all__ = ['Detector', 'diffdrr_to_deepdrr']

# %% ../notebooks/api/02_detector.ipynb 5
from .pose import RigidTransform
from .utils import make_intrinsic_matrix


class Detector(torch.nn.Module):
"""Construct a 6 DoF X-ray detector system. This model is based on a C-Arm."""

Expand Down Expand Up @@ -44,6 +48,46 @@ def __init__(
self.register_buffer("source", source)
self.register_buffer("target", target)

# Anatomy to world coordinates
flip_xz = torch.tensor(
[
[0.0, 0.0, -1.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
translate = torch.tensor(
[
[1.0, 0.0, 0.0, -self.sdr],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
self.register_buffer("_flip_xz", flip_xz)
self.register_buffer("_translate", translate)

@property
def intrinsic(self):
return make_intrinsic_matrix(
self.sdr,
self.delx,
self.dely,
self.height,
self.width,
self.x0,
self.y0,
).to(self._flip_xz)

@property
def flip_xz(self):
return RigidTransform(self._flip_xz)

@property
def translate(self):
return RigidTransform(self._translate)

# %% ../notebooks/api/02_detector.ipynb 6
@patch
def _initialize_carm(self: Detector):
Expand Down
46 changes: 45 additions & 1 deletion diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def reshape_subsampled_drr(
return drr

# %% ../notebooks/api/00_drr.ipynb 10
# from diffdrr.se3 import RigidTransform, convert
from .pose import convert


Expand Down Expand Up @@ -170,7 +169,52 @@ def set_intrinsics(
reverse_x_axis=self.detector.reverse_x_axis,
).to(self.volume)

# %% ../notebooks/api/00_drr.ipynb 13
from .pose import RigidTransform


@patch
def perspective_projection(
self: DRR,
pose: RigidTransform,
pts: torch.Tensor,
):
extrinsic = (
pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)
)
x = extrinsic(pts)
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
x = x / z
return x[..., :2]

# %% ../notebooks/api/00_drr.ipynb 14
from torch.nn.functional import pad


@patch
def inverse_projection(
self: DRR,
pose: RigidTransform,
pts: torch.Tensor,
):
extrinsic = (
self.detector.flip_xz.inverse()
.compose(self.detector.translate.inverse())
.compose(pose)
)
x = (
-2
* self.detector.sdr
* torch.einsum(
"ij, bnj -> bni",
self.detector.intrinsic.inverse(),
pad(pts, (0, 1), value=1), # Convert to homogenous coordinates
)
)
return extrinsic(x)

# %% ../notebooks/api/00_drr.ipynb 16
class Registration(nn.Module):
"""Perform automatic 2D-to-3D registration using differentiable rendering."""

Expand Down
11 changes: 7 additions & 4 deletions diffdrr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,22 @@ def parse_intrinsic_matrix(
return focal_length, x0, y0

# %% ../notebooks/api/07_utils.ipynb 7
import torch


def make_intrinsic_matrix(
sdr: float, # Source-to-detector radius (in units length)
delx: float, # X-direction spacing (in units length / pixel)
dely: float, # Y-direction spacing (in units length / pixel)
height: int, # Y-direction length (in units pixels)
width: int, # X-direction length (in units pixels)
delx: float, # X-direction spacing (in units length)
dely: float, # Y-direction spacing (in units length)
x0: float = 0.0, # Principal point x-coordinate (in units length)
y0: float = 0.0, # Principal point y-coordinate (in units length)
):
return torch.tensor(
[
[-2 * sdr / delx, 0.0, height / 2 - x0 / delx],
[0.0, -2 * sdr / dely, width / 2 - y0 / dely],
[2 * sdr / delx, 0.0, x0 / delx - height / 2],
[0.0, 2 * sdr / dely, y0 / dely - width / 2],
[0.0, 0.0, 1.0],
]
)
62 changes: 61 additions & 1 deletion notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@
"outputs": [],
"source": [
"#| export\n",
"# from diffdrr.se3 import RigidTransform, convert\n",
"from diffdrr.pose import convert\n",
"\n",
"\n",
Expand Down Expand Up @@ -306,6 +305,67 @@
" ).to(self.volume)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93a94ef3-5449-45dc-aa62-9fcf6fad643d",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from diffdrr.pose import RigidTransform\n",
"\n",
"\n",
"@patch\n",
"def perspective_projection(\n",
" self: DRR,\n",
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" extrinsic = (\n",
" pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)\n",
" )\n",
" x = extrinsic(pts)\n",
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
" z = x[..., -1].unsqueeze(-1).clone()\n",
" x = x / z\n",
" return x[..., :2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "802ba874-eef8-4524-be5c-bd250e5639d7",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from torch.nn.functional import pad\n",
"\n",
"\n",
"@patch\n",
"def inverse_projection(\n",
" self: DRR,\n",
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" extrinsic = (\n",
" self.detector.flip_xz.inverse()\n",
" .compose(self.detector.translate.inverse())\n",
" .compose(pose)\n",
" )\n",
" x = (\n",
" -2\n",
" * self.detector.sdr\n",
" * torch.einsum(\n",
" \"ij, bnj -> bni\",\n",
" self.detector.intrinsic.inverse(),\n",
" pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n",
" )\n",
" )\n",
" return extrinsic(x)"
]
},
{
"cell_type": "markdown",
"id": "10bc1b2b-a444-45dc-8430-56646d54f95f",
Expand Down
46 changes: 45 additions & 1 deletion notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
"outputs": [],
"source": [
"#| export\n",
"from diffdrr.pose import RigidTransform\n",
"from diffdrr.utils import make_intrinsic_matrix\n",
"\n",
"\n",
"class Detector(torch.nn.Module):\n",
" \"\"\"Construct a 6 DoF X-ray detector system. This model is based on a C-Arm.\"\"\"\n",
"\n",
Expand Down Expand Up @@ -97,7 +101,47 @@
" # Initialize the source and detector plane in default positions (along the x-axis)\n",
" source, target = self._initialize_carm()\n",
" self.register_buffer(\"source\", source)\n",
" self.register_buffer(\"target\", target)"
" self.register_buffer(\"target\", target)\n",
"\n",
" # Anatomy to world coordinates\n",
" flip_xz = torch.tensor(\n",
" [\n",
" [0.0, 0.0, -1.0, 0.0],\n",
" [0.0, 1.0, 0.0, 0.0],\n",
" [1.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 1.0],\n",
" ]\n",
" )\n",
" translate = torch.tensor(\n",
" [\n",
" [1.0, 0.0, 0.0, -self.sdr],\n",
" [0.0, 1.0, 0.0, 0.0],\n",
" [0.0, 0.0, 1.0, 0.0],\n",
" [0.0, 0.0, 0.0, 1.0],\n",
" ]\n",
" )\n",
" self.register_buffer(\"_flip_xz\", flip_xz)\n",
" self.register_buffer(\"_translate\", translate)\n",
"\n",
" @property\n",
" def intrinsic(self):\n",
" return make_intrinsic_matrix(\n",
" self.sdr,\n",
" self.delx,\n",
" self.dely,\n",
" self.height,\n",
" self.width,\n",
" self.x0,\n",
" self.y0,\n",
" ).to(self._flip_xz)\n",
"\n",
" @property\n",
" def flip_xz(self):\n",
" return RigidTransform(self._flip_xz)\n",
"\n",
" @property\n",
" def translate(self):\n",
" return RigidTransform(self._translate)"
]
},
{
Expand Down
7 changes: 7 additions & 0 deletions notebooks/api/06_pose.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,13 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
18 changes: 14 additions & 4 deletions notebooks/api/07_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,22 @@
"outputs": [],
"source": [
"#| export\n",
"import torch\n",
"\n",
"\n",
"def make_intrinsic_matrix(\n",
" sdr: float, # Source-to-detector radius (in units length)\n",
" delx: float, # X-direction spacing (in units length / pixel)\n",
" dely: float, # Y-direction spacing (in units length / pixel)\n",
" height: int, # Y-direction length (in units pixels)\n",
" width: int, # X-direction length (in units pixels)\n",
" delx: float, # X-direction spacing (in units length)\n",
" dely: float, # Y-direction spacing (in units length)\n",
" x0: float = 0.0, # Principal point x-coordinate (in units length)\n",
" y0: float = 0.0, # Principal point y-coordinate (in units length)\n",
"):\n",
" return torch.tensor(\n",
" [\n",
" [-2 * sdr / delx, 0.0, height / 2 - x0 / delx],\n",
" [0.0, -2 * sdr / dely, width / 2 - y0 / dely],\n",
" [2 * sdr / delx, 0.0, x0 / delx - height / 2],\n",
" [0.0, 2 * sdr / dely, y0 / dely - width / 2],\n",
" [0.0, 0.0, 1.0],\n",
" ]\n",
" )"
Expand Down Expand Up @@ -156,6 +159,13 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
Expand Down
Loading