diff --git a/diffdrr/detector.py b/diffdrr/detector.py index fd2343f0f..33941f9f0 100644 --- a/diffdrr/detector.py +++ b/diffdrr/detector.py @@ -52,8 +52,8 @@ def __init__( "_calibration", torch.tensor( [ - [delx, 0, 0, x0], - [0, dely, 0, y0], + [dely, 0, 0, -y0], + [0, delx, 0, x0], [0, 0, sdd, 0], [0, 0, 0, 1], ] @@ -66,19 +66,19 @@ def sdd(self): @property def delx(self): - return self._calibration[0, 0].item() + return self._calibration[1, 1].item() @property def dely(self): - return self._calibration[1, 1].item() + return self._calibration[0, 0].item() @property def x0(self): - return self._calibration[0, -1].item() + return -self._calibration[1, -1].item() @property def y0(self): - return self._calibration[1, -1].item() + return -self._calibration[0, -1].item() @property def reorient(self): @@ -96,10 +96,10 @@ def intrinsic(self): self.sdd, self.delx, self.dely, - self.height, self.width, - self.x0, + self.height, self.y0, + self.x0, ).to(self.source) # %% ../notebooks/api/02_detector.ipynb 6 @@ -113,36 +113,29 @@ def _initialize_carm(self: Detector): # Initialize the source at the origin and the center of the detector plane on the positive z-axis source = torch.tensor([[0.0, 0.0, 0.0]], device=device) - center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd + center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # Use the standard basis for the detector plane basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device) # Construct the detector plane with different offsets for even or odd heights + # These ensure that the detector plane is centered around (0, 0, 1) h_off = 1.0 if self.height % 2 else 0.5 w_off = 1.0 if self.width % 2 else 0.5 # Construct equally spaced points along the basis vectors - t = ( - torch.arange(-self.height // 2, self.height // 2, device=device) + h_off - ) # * self.delx - s = ( - torch.arange(-self.width // 2, self.width // 2, device=device) + w_off - ) # * self.dely + t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off + s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off if self.reverse_x_axis: s = -s coefs = torch.cartesian_prod(t, s).reshape(-1, 2) target = torch.einsum("cd,nc->nd", basis, coefs) target += center - # Batch source and target + # Add a batch dimension to the source and target so multiple poses can be passed at once source = source.unsqueeze(0) target = target.unsqueeze(0) - # # Apply principal point offset - # target[..., 1] -= self.x0 - # target[..., 0] -= self.y0 - if self.n_subsample is not None: sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)] target = target[:, sample, :] diff --git a/diffdrr/drr.py b/diffdrr/drr.py index 4df725f50..97b650f97 100644 --- a/diffdrr/drr.py +++ b/diffdrr/drr.py @@ -183,6 +183,7 @@ def perspective_projection( pose: RigidTransform, pts: torch.Tensor, ): + """Project points in world coordinates (3D) onto the pixel plane (2D).""" extrinsic = (self.detector.reorient.compose(pose)).inverse() x = extrinsic(pts) x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x) @@ -190,7 +191,7 @@ def perspective_projection( x = x / z if self.detector.reverse_x_axis: x[..., 1] = self.detector.width - x[..., 1] - return x[..., :2] + return x[..., :2].flip(-1) # %% ../notebooks/api/00_drr.ipynb 13 from torch.nn.functional import pad @@ -202,10 +203,14 @@ def inverse_projection( pose: RigidTransform, pts: torch.Tensor, ): - extrinsic = self.detector.reorient.compose(pose) + """Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D).""" + pts = pts.flip(-1) + if self.detector.reverse_x_axis: + pts[..., 1] = self.detector.width - pts[..., 1] x = self.detector.sdd * torch.einsum( "ij, bnj -> bni", self.detector.intrinsic.inverse(), pad(pts, (0, 1), value=1), # Convert to homogenous coordinates ) + extrinsic = self.detector.reorient.compose(pose) return extrinsic(x) diff --git a/notebooks/api/00_drr.ipynb b/notebooks/api/00_drr.ipynb index 8c5cf2fa9..694838c35 100644 --- a/notebooks/api/00_drr.ipynb +++ b/notebooks/api/00_drr.ipynb @@ -318,6 +318,7 @@ " pose: RigidTransform,\n", " pts: torch.Tensor,\n", "):\n", + " \"\"\"Project points in world coordinates (3D) onto the pixel plane (2D).\"\"\"\n", " extrinsic = (self.detector.reorient.compose(pose)).inverse()\n", " x = extrinsic(pts)\n", " x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n", @@ -325,7 +326,7 @@ " x = x / z\n", " if self.detector.reverse_x_axis:\n", " x[..., 1] = self.detector.width - x[..., 1]\n", - " return x[..., :2]" + " return x[..., :2].flip(-1)" ] }, { @@ -345,12 +346,16 @@ " pose: RigidTransform,\n", " pts: torch.Tensor,\n", "):\n", - " extrinsic = self.detector.reorient.compose(pose)\n", + " \"\"\"Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D).\"\"\"\n", + " pts = pts.flip(-1)\n", + " if self.detector.reverse_x_axis:\n", + " pts[..., 1] = self.detector.width - pts[..., 1]\n", " x = self.detector.sdd * torch.einsum(\n", " \"ij, bnj -> bni\",\n", " self.detector.intrinsic.inverse(),\n", " pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n", " )\n", + " extrinsic = self.detector.reorient.compose(pose)\n", " return extrinsic(x)" ] }, diff --git a/notebooks/api/02_detector.ipynb b/notebooks/api/02_detector.ipynb index 4c0fcd2f8..6f7d94472 100644 --- a/notebooks/api/02_detector.ipynb +++ b/notebooks/api/02_detector.ipynb @@ -107,8 +107,8 @@ " \"_calibration\",\n", " torch.tensor(\n", " [\n", - " [delx, 0, 0, x0],\n", - " [0, dely, 0, y0],\n", + " [dely, 0, 0, -y0],\n", + " [0, delx, 0, x0],\n", " [0, 0, sdd, 0],\n", " [0, 0, 0, 1],\n", " ]\n", @@ -121,19 +121,19 @@ "\n", " @property\n", " def delx(self):\n", - " return self._calibration[0, 0].item()\n", + " return self._calibration[1, 1].item()\n", "\n", " @property\n", " def dely(self):\n", - " return self._calibration[1, 1].item()\n", + " return self._calibration[0, 0].item()\n", "\n", " @property\n", " def x0(self):\n", - " return self._calibration[0, -1].item()\n", + " return -self._calibration[1, -1].item()\n", "\n", " @property\n", " def y0(self):\n", - " return self._calibration[1, -1].item()\n", + " return -self._calibration[0, -1].item()\n", "\n", " @property\n", " def reorient(self):\n", @@ -151,10 +151,10 @@ " self.sdd,\n", " self.delx,\n", " self.dely,\n", - " self.height,\n", " self.width,\n", - " self.x0,\n", + " self.height,\n", " self.y0,\n", + " self.x0,\n", " ).to(self.source)" ] }, @@ -176,36 +176,29 @@ "\n", " # Initialize the source at the origin and the center of the detector plane on the positive z-axis\n", " source = torch.tensor([[0.0, 0.0, 0.0]], device=device)\n", - " center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd\n", + " center = torch.tensor([[0.0, 0.0, 1.0]], device=device)\n", "\n", " # Use the standard basis for the detector plane\n", " basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n", "\n", " # Construct the detector plane with different offsets for even or odd heights\n", + " # These ensure that the detector plane is centered around (0, 0, 1)\n", " h_off = 1.0 if self.height % 2 else 0.5\n", " w_off = 1.0 if self.width % 2 else 0.5\n", "\n", " # Construct equally spaced points along the basis vectors\n", - " t = (\n", - " torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n", - " ) # * self.delx\n", - " s = (\n", - " torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n", - " ) # * self.dely\n", + " t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n", + " s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n", " if self.reverse_x_axis:\n", " s = -s\n", " coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n", " target = torch.einsum(\"cd,nc->nd\", basis, coefs)\n", " target += center\n", "\n", - " # Batch source and target\n", + " # Add a batch dimension to the source and target so multiple poses can be passed at once\n", " source = source.unsqueeze(0)\n", " target = target.unsqueeze(0)\n", "\n", - " # # Apply principal point offset\n", - " # target[..., 1] -= self.x0\n", - " # target[..., 0] -= self.y0\n", - "\n", " if self.n_subsample is not None:\n", " sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]\n", " target = target[:, sample, :]\n",