Skip to content

Commit

Permalink
Flip height and width in the detector (#357)
Browse files Browse the repository at this point in the history
* Flip height and width; fix reorientations

* Fix orientations

* Rename f to sdd

* Handle errors in volume masking

* Fix perspective projection

* Fix x/y conventions

* Bump version
  • Loading branch information
eigenvivek authored Dec 20, 2024
1 parent e41c365 commit 11b1f0a
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 85 deletions.
2 changes: 1 addition & 1 deletion diffdrr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.5"
__version__ = "0.4.6"
47 changes: 27 additions & 20 deletions diffdrr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,35 +85,37 @@ def read(
# Frame-of-reference change
if orientation == "AP":
# Rotates the C-arm about the x-axis by 90 degrees
# Rotates the C-arm about the z-axis by -90 degrees
reorient = torch.tensor(
[
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, -1.0, 0.0],
[-1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
[1, 0, 0, 0],
[0, 0, -1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
elif orientation == "PA":
# Rotates the C-arm about the x-axis by 90 degrees
# Rotates the C-arm about the z-axis by 90 degrees
# Reverses the direction of the y-axis
reorient = torch.tensor(
[
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[-1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
elif orientation is None:
# Identity transform
reorient = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
else:
raise ValueError(f"Unrecognized orientation {orientation}")
Expand All @@ -122,6 +124,7 @@ def read(
subject = Subject(
volume=volume,
mask=mask,
orientation=orientation,
reorient=reorient,
density=density,
fiducials=fiducials,
Expand Down Expand Up @@ -161,9 +164,13 @@ def read(
dim=0,
)

subject.volume.data = subject.volume.data * mask
subject.mask.data = subject.mask.data * mask
subject.density.data = subject.density.data * mask
# Mask all volumes, unless error, then just mask the density
try:
subject.volume.data = subject.volume.data * mask
subject.mask.data = subject.mask.data * mask
subject.density.data = subject.density.data * mask
except:
subject.density.data = subject.density.data * mask

return subject

Expand Down
20 changes: 12 additions & 8 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(
"_calibration",
torch.tensor(
[
[dely, 0, 0, -y0],
[0, delx, 0, -x0],
[delx, 0, 0, x0],
[0, dely, 0, y0],
[0, 0, sdd, 0],
[0, 0, 0, 1],
]
Expand All @@ -65,19 +65,19 @@ def sdd(self):

@property
def delx(self):
return self._calibration[1, 1].item()
return self._calibration[0, 0].item()

@property
def dely(self):
return self._calibration[0, 0].item()
return self._calibration[1, 1].item()

@property
def x0(self):
return -self._calibration[1, -1].item()
return -self._calibration[0, -1].item()

@property
def y0(self):
return -self._calibration[0, -1].item()
return -self._calibration[1, -1].item()

@property
def reorient(self):
Expand Down Expand Up @@ -107,7 +107,7 @@ def _initialize_carm(self: Detector):
center = torch.tensor([[0.0, 0.0, 1.0]], device=device)

# Use the standard basis for the detector plane
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)
basis = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], device=device)

# Construct the detector plane with different offsets for even or odd heights
# These ensure that the detector plane is centered around (0, 0, 1)
Expand All @@ -117,8 +117,12 @@ def _initialize_carm(self: Detector):
# Construct equally spaced points along the basis vectors
t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
if self.reverse_x_axis:

t = -t
s = -s
if not self.reverse_x_axis:
s = -s

coefs = torch.cartesian_prod(t, s).reshape(-1, 2)
target = torch.einsum("cd,nc->nd", basis, coefs)
target += center
Expand Down
17 changes: 12 additions & 5 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def set_intrinsics_(
width if width is not None else self.detector.width,
delx if delx is not None else self.detector.delx,
dely if dely is not None else self.detector.dely,
x0 if x0 is not None else self.detector.x0,
y0 if y0 is not None else self.detector.y0,
x0 if x0 is not None else -self.detector.x0,
y0 if y0 is not None else -self.detector.y0,
self.subject.reorient,
n_subsample if n_subsample is not None else self.detector.n_subsample,
reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,
Expand All @@ -256,14 +256,21 @@ def perspective_projection(
pts: torch.Tensor,
):
"""Project points in world coordinates (3D) onto the pixel plane (2D)."""
# Poses in DiffDRR are world2camera, but perspective transforms use camera2world, so invert
extrinsic = (self.detector.reorient.compose(pose)).inverse()
x = extrinsic(pts)

# Project onto the detector plane
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
x = x / z

# Move origin to upper-left corner
x[..., 1] = self.detector.height - x[..., 1]
if self.detector.reverse_x_axis:
x[..., 1] = self.detector.width - x[..., 1]
return x[..., :2].flip(-1)
x[..., 0] = self.detector.width - x[..., 0]

return x[..., :2]

# %% ../notebooks/api/00_drr.ipynb 14
from torch.nn.functional import pad
Expand All @@ -276,7 +283,7 @@ def inverse_projection(
pts: torch.Tensor,
):
"""Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D)."""
pts = pts.flip(-1)
# pts = pts.flip(-1)
if self.detector.reverse_x_axis:
pts[..., 1] = self.detector.width - pts[..., 1]
x = self.detector.sdd * torch.einsum(
Expand Down
18 changes: 10 additions & 8 deletions diffdrr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def resample(

# %% ../notebooks/api/07_utils.ipynb 6
from kornia.geometry.camera.pinhole import PinholeCamera as KorniaPinholeCamera
from torchio import Subject

from diffdrr.detector import Detector

Expand All @@ -66,9 +67,11 @@ def __init__(
height: torch.Tensor,
width: torch.Tensor,
detector: Detector,
subject: Subject,
):
super().__init__(intrinsics, extrinsics, height, width)
self.f = detector.sdd
multiplier = -1 if subject.orientation == "PA" else 1
self.sdd = multiplier * detector.sdd
self.delx = detector.delx
self.dely = detector.dely
self.x0 = detector.x0
Expand All @@ -94,9 +97,9 @@ def pose(self):

from kornia.geometry.calibration import solve_pnp_dlt

from .detector import make_intrinsic_matrix
from .drr import DRR
from .pose import RigidTransform
from .detector import make_intrinsic_matrix


def get_pinhole_camera(
Expand All @@ -107,14 +110,12 @@ def get_pinhole_camera(
pose = deepcopy(pose).to(device="cpu", dtype=dtype)

# Make the intrinsic matrix (in pixels)
fx = drr.detector.sdd / drr.detector.delx
fy = drr.detector.sdd / drr.detector.dely
multiplier = -1 if drr.subject.orientation == "PA" else 1
fx = multiplier * drr.detector.sdd / drr.detector.delx
fy = multiplier * drr.detector.sdd / drr.detector.dely
u0 = drr.detector.x0 / drr.detector.delx + drr.detector.width / 2
v0 = drr.detector.y0 / drr.detector.dely + drr.detector.height / 2
intrinsics = torch.eye(4)[None]
intrinsics[0, :3, :3] = make_intrinsic_matrix(drr.detector)

torch.tensor(
intrinsics = torch.tensor(
[
[
[fx, 0.0, u0, 0.0],
Expand Down Expand Up @@ -156,6 +157,7 @@ def get_pinhole_camera(
torch.tensor([drr.detector.height]),
torch.tensor([drr.detector.width]),
drr.detector,
drr.subject,
)

return camera
17 changes: 12 additions & 5 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@
" width if width is not None else self.detector.width,\n",
" delx if delx is not None else self.detector.delx,\n",
" dely if dely is not None else self.detector.dely,\n",
" x0 if x0 is not None else self.detector.x0,\n",
" y0 if y0 is not None else self.detector.y0,\n",
" x0 if x0 is not None else -self.detector.x0,\n",
" y0 if y0 is not None else -self.detector.y0,\n",
" self.subject.reorient,\n",
" n_subsample if n_subsample is not None else self.detector.n_subsample,\n",
" reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,\n",
Expand Down Expand Up @@ -399,14 +399,21 @@
" pts: torch.Tensor,\n",
"):\n",
" \"\"\"Project points in world coordinates (3D) onto the pixel plane (2D).\"\"\"\n",
" # Poses in DiffDRR are world2camera, but perspective transforms use camera2world, so invert\n",
" extrinsic = (self.detector.reorient.compose(pose)).inverse()\n",
" x = extrinsic(pts)\n",
"\n",
" # Project onto the detector plane\n",
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
" z = x[..., -1].unsqueeze(-1).clone()\n",
" x = x / z\n",
"\n",
" # Move origin to upper-left corner\n",
" x[..., 1] = self.detector.height - x[..., 1]\n",
" if self.detector.reverse_x_axis:\n",
" x[..., 1] = self.detector.width - x[..., 1]\n",
" return x[..., :2].flip(-1)"
" x[..., 0] = self.detector.width - x[..., 0]\n",
" \n",
" return x[..., :2]"
]
},
{
Expand All @@ -427,7 +434,7 @@
" pts: torch.Tensor,\n",
"):\n",
" \"\"\"Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D).\"\"\"\n",
" pts = pts.flip(-1)\n",
" # pts = pts.flip(-1)\n",
" if self.detector.reverse_x_axis:\n",
" pts[..., 1] = self.detector.width - pts[..., 1]\n",
" x = self.detector.sdd * torch.einsum(\n",
Expand Down
20 changes: 12 additions & 8 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@
" \"_calibration\",\n",
" torch.tensor(\n",
" [\n",
" [dely, 0, 0, -y0],\n",
" [0, delx, 0, -x0],\n",
" [delx, 0, 0, x0],\n",
" [0, dely, 0, y0],\n",
" [0, 0, sdd, 0],\n",
" [0, 0, 0, 1],\n",
" ]\n",
Expand All @@ -120,19 +120,19 @@
"\n",
" @property\n",
" def delx(self):\n",
" return self._calibration[1, 1].item()\n",
" return self._calibration[0, 0].item()\n",
"\n",
" @property\n",
" def dely(self):\n",
" return self._calibration[0, 0].item()\n",
" return self._calibration[1, 1].item()\n",
"\n",
" @property\n",
" def x0(self):\n",
" return -self._calibration[1, -1].item()\n",
" return -self._calibration[0, -1].item()\n",
"\n",
" @property\n",
" def y0(self):\n",
" return -self._calibration[0, -1].item()\n",
" return -self._calibration[1, -1].item()\n",
"\n",
" @property\n",
" def reorient(self):\n",
Expand Down Expand Up @@ -170,7 +170,7 @@
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device)\n",
"\n",
" # Use the standard basis for the detector plane\n",
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
" basis = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], device=device)\n",
"\n",
" # Construct the detector plane with different offsets for even or odd heights\n",
" # These ensure that the detector plane is centered around (0, 0, 1)\n",
Expand All @@ -180,8 +180,12 @@
" # Construct equally spaced points along the basis vectors\n",
" t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
" s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
" if self.reverse_x_axis:\n",
"\n",
" t = -t\n",
" s = -s\n",
" if not self.reverse_x_axis:\n",
" s = -s\n",
"\n",
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
" target = torch.einsum(\"cd,nc->nd\", basis, coefs)\n",
" target += center\n",
Expand Down
Loading

0 comments on commit 11b1f0a

Please sign in to comment.