From f3fd98ae039b2df602c5164fe7f038293c712153 Mon Sep 17 00:00:00 2001 From: Vivek Gopalakrishnan Date: Wed, 21 Feb 2024 18:07:29 -0500 Subject: [PATCH] Add projection code --- diffdrr/_modidx.py | 5 +++ diffdrr/detector.py | 44 +++++++++++++++++++++++ diffdrr/drr.py | 46 +++++++++++++++++++++++- diffdrr/utils.py | 11 +++--- notebooks/api/00_drr.ipynb | 62 ++++++++++++++++++++++++++++++++- notebooks/api/02_detector.ipynb | 46 +++++++++++++++++++++++- notebooks/api/06_pose.ipynb | 7 ++++ notebooks/api/07_utils.ipynb | 18 +++++++--- 8 files changed, 228 insertions(+), 11 deletions(-) diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index 7c9bf2bb3..c2eff5ccf 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -11,11 +11,16 @@ 'diffdrr.detector.Detector.__init__': ('api/detector.html#detector.__init__', 'diffdrr/detector.py'), 'diffdrr.detector.Detector._initialize_carm': ( 'api/detector.html#detector._initialize_carm', 'diffdrr/detector.py'), + 'diffdrr.detector.Detector.flip_xz': ('api/detector.html#detector.flip_xz', 'diffdrr/detector.py'), 'diffdrr.detector.Detector.forward': ('api/detector.html#detector.forward', 'diffdrr/detector.py'), + 'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py'), + 'diffdrr.detector.Detector.translate': ('api/detector.html#detector.translate', 'diffdrr/detector.py'), 'diffdrr.detector.diffdrr_to_deepdrr': ('api/detector.html#diffdrr_to_deepdrr', 'diffdrr/detector.py')}, 'diffdrr.drr': { 'diffdrr.drr.DRR': ('api/drr.html#drr', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.__init__': ('api/drr.html#drr.__init__', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'), + 'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'), + 'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.set_bone_attenuation_multiplier': ( 'api/drr.html#drr.set_bone_attenuation_multiplier', 'diffdrr/drr.py'), diff --git a/diffdrr/detector.py b/diffdrr/detector.py index 8f290f87a..1e491bb11 100644 --- a/diffdrr/detector.py +++ b/diffdrr/detector.py @@ -11,6 +11,10 @@ __all__ = ['Detector', 'diffdrr_to_deepdrr'] # %% ../notebooks/api/02_detector.ipynb 5 +from .pose import RigidTransform +from .utils import make_intrinsic_matrix + + class Detector(torch.nn.Module): """Construct a 6 DoF X-ray detector system. This model is based on a C-Arm.""" @@ -44,6 +48,46 @@ def __init__( self.register_buffer("source", source) self.register_buffer("target", target) + # Anatomy to world coordinates + flip_xz = torch.tensor( + [ + [0.0, 0.0, -1.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + translate = torch.tensor( + [ + [1.0, 0.0, 0.0, -self.sdr], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + self.register_buffer("_flip_xz", flip_xz) + self.register_buffer("_translate", translate) + + @property + def intrinsic(self): + return make_intrinsic_matrix( + self.sdr, + self.delx, + self.dely, + self.height, + self.width, + self.x0, + self.y0, + ).to(self._flip_xz) + + @property + def flip_xz(self): + return RigidTransform(self._flip_xz) + + @property + def translate(self): + return RigidTransform(self._translate) + # %% ../notebooks/api/02_detector.ipynb 6 @patch def _initialize_carm(self: Detector): diff --git a/diffdrr/drr.py b/diffdrr/drr.py index 9d946c2a0..daa5086c3 100644 --- a/diffdrr/drr.py +++ b/diffdrr/drr.py @@ -100,7 +100,6 @@ def reshape_subsampled_drr( return drr # %% ../notebooks/api/00_drr.ipynb 10 -# from diffdrr.se3 import RigidTransform, convert from .pose import convert @@ -170,7 +169,52 @@ def set_intrinsics( reverse_x_axis=self.detector.reverse_x_axis, ).to(self.volume) +# %% ../notebooks/api/00_drr.ipynb 13 +from .pose import RigidTransform + + +@patch +def perspective_projection( + self: DRR, + pose: RigidTransform, + pts: torch.Tensor, +): + extrinsic = ( + pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz) + ) + x = extrinsic(pts) + x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x) + z = x[..., -1].unsqueeze(-1).clone() + x = x / z + return x[..., :2] + # %% ../notebooks/api/00_drr.ipynb 14 +from torch.nn.functional import pad + + +@patch +def inverse_projection( + self: DRR, + pose: RigidTransform, + pts: torch.Tensor, +): + extrinsic = ( + self.detector.flip_xz.inverse() + .compose(self.detector.translate.inverse()) + .compose(pose) + ) + x = ( + -2 + * self.detector.sdr + * torch.einsum( + "ij, bnj -> bni", + self.detector.intrinsic.inverse(), + pad(pts, (0, 1), value=1), # Convert to homogenous coordinates + ) + ) + return extrinsic(x) + +# %% ../notebooks/api/00_drr.ipynb 16 class Registration(nn.Module): """Perform automatic 2D-to-3D registration using differentiable rendering.""" diff --git a/diffdrr/utils.py b/diffdrr/utils.py index 6c49b641c..e9dde8e06 100644 --- a/diffdrr/utils.py +++ b/diffdrr/utils.py @@ -38,19 +38,22 @@ def parse_intrinsic_matrix( return focal_length, x0, y0 # %% ../notebooks/api/07_utils.ipynb 7 +import torch + + def make_intrinsic_matrix( sdr: float, # Source-to-detector radius (in units length) + delx: float, # X-direction spacing (in units length / pixel) + dely: float, # Y-direction spacing (in units length / pixel) height: int, # Y-direction length (in units pixels) width: int, # X-direction length (in units pixels) - delx: float, # X-direction spacing (in units length) - dely: float, # Y-direction spacing (in units length) x0: float = 0.0, # Principal point x-coordinate (in units length) y0: float = 0.0, # Principal point y-coordinate (in units length) ): return torch.tensor( [ - [-2 * sdr / delx, 0.0, height / 2 - x0 / delx], - [0.0, -2 * sdr / dely, width / 2 - y0 / dely], + [2 * sdr / delx, 0.0, x0 / delx - height / 2], + [0.0, 2 * sdr / dely, y0 / dely - width / 2], [0.0, 0.0, 1.0], ] ) diff --git a/notebooks/api/00_drr.ipynb b/notebooks/api/00_drr.ipynb index 2f7e4646b..b1fe00446 100644 --- a/notebooks/api/00_drr.ipynb +++ b/notebooks/api/00_drr.ipynb @@ -219,7 +219,6 @@ "outputs": [], "source": [ "#| export\n", - "# from diffdrr.se3 import RigidTransform, convert\n", "from diffdrr.pose import convert\n", "\n", "\n", @@ -306,6 +305,67 @@ " ).to(self.volume)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "93a94ef3-5449-45dc-aa62-9fcf6fad643d", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from diffdrr.pose import RigidTransform\n", + "\n", + "\n", + "@patch\n", + "def perspective_projection(\n", + " self: DRR,\n", + " pose: RigidTransform,\n", + " pts: torch.Tensor,\n", + "):\n", + " extrinsic = (\n", + " pose.inverse().compose(self.detector.translate).compose(self.detector.flip_xz)\n", + " )\n", + " x = extrinsic(pts)\n", + " x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n", + " z = x[..., -1].unsqueeze(-1).clone()\n", + " x = x / z\n", + " return x[..., :2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "802ba874-eef8-4524-be5c-bd250e5639d7", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from torch.nn.functional import pad\n", + "\n", + "\n", + "@patch\n", + "def inverse_projection(\n", + " self: DRR,\n", + " pose: RigidTransform,\n", + " pts: torch.Tensor,\n", + "):\n", + " extrinsic = (\n", + " self.detector.flip_xz.inverse()\n", + " .compose(self.detector.translate.inverse())\n", + " .compose(pose)\n", + " )\n", + " x = (\n", + " -2\n", + " * self.detector.sdr\n", + " * torch.einsum(\n", + " \"ij, bnj -> bni\",\n", + " self.detector.intrinsic.inverse(),\n", + " pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n", + " )\n", + " )\n", + " return extrinsic(x)" + ] + }, { "cell_type": "markdown", "id": "10bc1b2b-a444-45dc-8430-56646d54f95f", diff --git a/notebooks/api/02_detector.ipynb b/notebooks/api/02_detector.ipynb index 8c822f198..1d3522c2f 100644 --- a/notebooks/api/02_detector.ipynb +++ b/notebooks/api/02_detector.ipynb @@ -66,6 +66,10 @@ "outputs": [], "source": [ "#| export\n", + "from diffdrr.pose import RigidTransform\n", + "from diffdrr.utils import make_intrinsic_matrix\n", + "\n", + "\n", "class Detector(torch.nn.Module):\n", " \"\"\"Construct a 6 DoF X-ray detector system. This model is based on a C-Arm.\"\"\"\n", "\n", @@ -97,7 +101,47 @@ " # Initialize the source and detector plane in default positions (along the x-axis)\n", " source, target = self._initialize_carm()\n", " self.register_buffer(\"source\", source)\n", - " self.register_buffer(\"target\", target)" + " self.register_buffer(\"target\", target)\n", + "\n", + " # Anatomy to world coordinates\n", + " flip_xz = torch.tensor(\n", + " [\n", + " [0.0, 0.0, -1.0, 0.0],\n", + " [0.0, 1.0, 0.0, 0.0],\n", + " [1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " ]\n", + " )\n", + " translate = torch.tensor(\n", + " [\n", + " [1.0, 0.0, 0.0, -self.sdr],\n", + " [0.0, 1.0, 0.0, 0.0],\n", + " [0.0, 0.0, 1.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0],\n", + " ]\n", + " )\n", + " self.register_buffer(\"_flip_xz\", flip_xz)\n", + " self.register_buffer(\"_translate\", translate)\n", + "\n", + " @property\n", + " def intrinsic(self):\n", + " return make_intrinsic_matrix(\n", + " self.sdr,\n", + " self.delx,\n", + " self.dely,\n", + " self.height,\n", + " self.width,\n", + " self.x0,\n", + " self.y0,\n", + " ).to(self._flip_xz)\n", + "\n", + " @property\n", + " def flip_xz(self):\n", + " return RigidTransform(self._flip_xz)\n", + "\n", + " @property\n", + " def translate(self):\n", + " return RigidTransform(self._translate)" ] }, { diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index ed2132c56..1d67ec2fb 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -1535,6 +1535,13 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } } }, "nbformat": 4, diff --git a/notebooks/api/07_utils.ipynb b/notebooks/api/07_utils.ipynb index 903a57cf0..7fa225ed6 100644 --- a/notebooks/api/07_utils.ipynb +++ b/notebooks/api/07_utils.ipynb @@ -111,19 +111,22 @@ "outputs": [], "source": [ "#| export\n", + "import torch\n", + "\n", + "\n", "def make_intrinsic_matrix(\n", " sdr: float, # Source-to-detector radius (in units length)\n", + " delx: float, # X-direction spacing (in units length / pixel)\n", + " dely: float, # Y-direction spacing (in units length / pixel)\n", " height: int, # Y-direction length (in units pixels)\n", " width: int, # X-direction length (in units pixels)\n", - " delx: float, # X-direction spacing (in units length)\n", - " dely: float, # Y-direction spacing (in units length)\n", " x0: float = 0.0, # Principal point x-coordinate (in units length)\n", " y0: float = 0.0, # Principal point y-coordinate (in units length)\n", "):\n", " return torch.tensor(\n", " [\n", - " [-2 * sdr / delx, 0.0, height / 2 - x0 / delx],\n", - " [0.0, -2 * sdr / dely, width / 2 - y0 / dely],\n", + " [2 * sdr / delx, 0.0, x0 / delx - height / 2],\n", + " [0.0, 2 * sdr / dely, y0 / dely - width / 2],\n", " [0.0, 0.0, 1.0],\n", " ]\n", " )" @@ -156,6 +159,13 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } } }, "nbformat": 4,