Skip to content

Commit

Permalink
Fix intrinsics and backprojection (drr.inverse_projection) (#270)
Browse files Browse the repository at this point in the history
* Remove commented code

* Fix variable name

* Fix order in intrinsic matrix for non-square detectors

* Reverse x-/y-axes in intrinsic

* Standardize axis convention in projection
  • Loading branch information
eigenvivek authored Jun 12, 2024
1 parent 14a348e commit 858bb1a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 44 deletions.
33 changes: 13 additions & 20 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def __init__(
"_calibration",
torch.tensor(
[
[delx, 0, 0, x0],
[0, dely, 0, y0],
[dely, 0, 0, -y0],
[0, delx, 0, x0],
[0, 0, sdd, 0],
[0, 0, 0, 1],
]
Expand All @@ -66,19 +66,19 @@ def sdd(self):

@property
def delx(self):
return self._calibration[0, 0].item()
return self._calibration[1, 1].item()

@property
def dely(self):
return self._calibration[1, 1].item()
return self._calibration[0, 0].item()

@property
def x0(self):
return self._calibration[0, -1].item()
return -self._calibration[1, -1].item()

@property
def y0(self):
return self._calibration[1, -1].item()
return -self._calibration[0, -1].item()

@property
def reorient(self):
Expand All @@ -96,10 +96,10 @@ def intrinsic(self):
self.sdd,
self.delx,
self.dely,
self.height,
self.width,
self.x0,
self.height,
self.y0,
self.x0,
).to(self.source)

# %% ../notebooks/api/02_detector.ipynb 6
Expand All @@ -113,36 +113,29 @@ def _initialize_carm(self: Detector):

# Initialize the source at the origin and the center of the detector plane on the positive z-axis
source = torch.tensor([[0.0, 0.0, 0.0]], device=device)
center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd
center = torch.tensor([[0.0, 0.0, 1.0]], device=device)

# Use the standard basis for the detector plane
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)

# Construct the detector plane with different offsets for even or odd heights
# These ensure that the detector plane is centered around (0, 0, 1)
h_off = 1.0 if self.height % 2 else 0.5
w_off = 1.0 if self.width % 2 else 0.5

# Construct equally spaced points along the basis vectors
t = (
torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
) # * self.delx
s = (
torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
) # * self.dely
t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
if self.reverse_x_axis:
s = -s
coefs = torch.cartesian_prod(t, s).reshape(-1, 2)
target = torch.einsum("cd,nc->nd", basis, coefs)
target += center

# Batch source and target
# Add a batch dimension to the source and target so multiple poses can be passed at once
source = source.unsqueeze(0)
target = target.unsqueeze(0)

# # Apply principal point offset
# target[..., 1] -= self.x0
# target[..., 0] -= self.y0

if self.n_subsample is not None:
sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]
target = target[:, sample, :]
Expand Down
9 changes: 7 additions & 2 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,15 @@ def perspective_projection(
pose: RigidTransform,
pts: torch.Tensor,
):
"""Project points in world coordinates (3D) onto the pixel plane (2D)."""
extrinsic = (self.detector.reorient.compose(pose)).inverse()
x = extrinsic(pts)
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
x = x / z
if self.detector.reverse_x_axis:
x[..., 1] = self.detector.width - x[..., 1]
return x[..., :2]
return x[..., :2].flip(-1)

# %% ../notebooks/api/00_drr.ipynb 13
from torch.nn.functional import pad
Expand All @@ -202,10 +203,14 @@ def inverse_projection(
pose: RigidTransform,
pts: torch.Tensor,
):
extrinsic = self.detector.reorient.compose(pose)
"""Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D)."""
pts = pts.flip(-1)
if self.detector.reverse_x_axis:
pts[..., 1] = self.detector.width - pts[..., 1]
x = self.detector.sdd * torch.einsum(
"ij, bnj -> bni",
self.detector.intrinsic.inverse(),
pad(pts, (0, 1), value=1), # Convert to homogenous coordinates
)
extrinsic = self.detector.reorient.compose(pose)
return extrinsic(x)
9 changes: 7 additions & 2 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,15 @@
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" \"\"\"Project points in world coordinates (3D) onto the pixel plane (2D).\"\"\"\n",
" extrinsic = (self.detector.reorient.compose(pose)).inverse()\n",
" x = extrinsic(pts)\n",
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
" z = x[..., -1].unsqueeze(-1).clone()\n",
" x = x / z\n",
" if self.detector.reverse_x_axis:\n",
" x[..., 1] = self.detector.width - x[..., 1]\n",
" return x[..., :2]"
" return x[..., :2].flip(-1)"
]
},
{
Expand All @@ -345,12 +346,16 @@
" pose: RigidTransform,\n",
" pts: torch.Tensor,\n",
"):\n",
" extrinsic = self.detector.reorient.compose(pose)\n",
" \"\"\"Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D).\"\"\"\n",
" pts = pts.flip(-1)\n",
" if self.detector.reverse_x_axis:\n",
" pts[..., 1] = self.detector.width - pts[..., 1]\n",
" x = self.detector.sdd * torch.einsum(\n",
" \"ij, bnj -> bni\",\n",
" self.detector.intrinsic.inverse(),\n",
" pad(pts, (0, 1), value=1), # Convert to homogenous coordinates\n",
" )\n",
" extrinsic = self.detector.reorient.compose(pose)\n",
" return extrinsic(x)"
]
},
Expand Down
33 changes: 13 additions & 20 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@
" \"_calibration\",\n",
" torch.tensor(\n",
" [\n",
" [delx, 0, 0, x0],\n",
" [0, dely, 0, y0],\n",
" [dely, 0, 0, -y0],\n",
" [0, delx, 0, x0],\n",
" [0, 0, sdd, 0],\n",
" [0, 0, 0, 1],\n",
" ]\n",
Expand All @@ -121,19 +121,19 @@
"\n",
" @property\n",
" def delx(self):\n",
" return self._calibration[0, 0].item()\n",
" return self._calibration[1, 1].item()\n",
"\n",
" @property\n",
" def dely(self):\n",
" return self._calibration[1, 1].item()\n",
" return self._calibration[0, 0].item()\n",
"\n",
" @property\n",
" def x0(self):\n",
" return self._calibration[0, -1].item()\n",
" return -self._calibration[1, -1].item()\n",
"\n",
" @property\n",
" def y0(self):\n",
" return self._calibration[1, -1].item()\n",
" return -self._calibration[0, -1].item()\n",
"\n",
" @property\n",
" def reorient(self):\n",
Expand All @@ -151,10 +151,10 @@
" self.sdd,\n",
" self.delx,\n",
" self.dely,\n",
" self.height,\n",
" self.width,\n",
" self.x0,\n",
" self.height,\n",
" self.y0,\n",
" self.x0,\n",
" ).to(self.source)"
]
},
Expand All @@ -176,36 +176,29 @@
"\n",
" # Initialize the source at the origin and the center of the detector plane on the positive z-axis\n",
" source = torch.tensor([[0.0, 0.0, 0.0]], device=device)\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device)\n",
"\n",
" # Use the standard basis for the detector plane\n",
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
"\n",
" # Construct the detector plane with different offsets for even or odd heights\n",
" # These ensure that the detector plane is centered around (0, 0, 1)\n",
" h_off = 1.0 if self.height % 2 else 0.5\n",
" w_off = 1.0 if self.width % 2 else 0.5\n",
"\n",
" # Construct equally spaced points along the basis vectors\n",
" t = (\n",
" torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
" ) # * self.delx\n",
" s = (\n",
" torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
" ) # * self.dely\n",
" t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
" s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
" if self.reverse_x_axis:\n",
" s = -s\n",
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
" target = torch.einsum(\"cd,nc->nd\", basis, coefs)\n",
" target += center\n",
"\n",
" # Batch source and target\n",
" # Add a batch dimension to the source and target so multiple poses can be passed at once\n",
" source = source.unsqueeze(0)\n",
" target = target.unsqueeze(0)\n",
"\n",
" # # Apply principal point offset\n",
" # target[..., 1] -= self.x0\n",
" # target[..., 0] -= self.y0\n",
"\n",
" if self.n_subsample is not None:\n",
" sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]\n",
" target = target[:, sample, :]\n",
Expand Down

0 comments on commit 858bb1a

Please sign in to comment.