Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flip height and width in the detector #357

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diffdrr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.5"
__version__ = "0.4.6"
47 changes: 27 additions & 20 deletions diffdrr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,35 +85,37 @@ def read(
# Frame-of-reference change
if orientation == "AP":
# Rotates the C-arm about the x-axis by 90 degrees
# Rotates the C-arm about the z-axis by -90 degrees
reorient = torch.tensor(
[
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, -1.0, 0.0],
[-1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
[1, 0, 0, 0],
[0, 0, -1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
elif orientation == "PA":
# Rotates the C-arm about the x-axis by 90 degrees
# Rotates the C-arm about the z-axis by 90 degrees
# Reverses the direction of the y-axis
reorient = torch.tensor(
[
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[-1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
elif orientation is None:
# Identity transform
reorient = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
else:
raise ValueError(f"Unrecognized orientation {orientation}")
Expand All @@ -122,6 +124,7 @@ def read(
subject = Subject(
volume=volume,
mask=mask,
orientation=orientation,
reorient=reorient,
density=density,
fiducials=fiducials,
Expand Down Expand Up @@ -161,9 +164,13 @@ def read(
dim=0,
)

subject.volume.data = subject.volume.data * mask
subject.mask.data = subject.mask.data * mask
subject.density.data = subject.density.data * mask
# Mask all volumes, unless error, then just mask the density
try:
subject.volume.data = subject.volume.data * mask
subject.mask.data = subject.mask.data * mask
subject.density.data = subject.density.data * mask
except:
subject.density.data = subject.density.data * mask

return subject

Expand Down
20 changes: 12 additions & 8 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(
"_calibration",
torch.tensor(
[
[dely, 0, 0, -y0],
[0, delx, 0, -x0],
[delx, 0, 0, x0],
[0, dely, 0, y0],
[0, 0, sdd, 0],
[0, 0, 0, 1],
]
Expand All @@ -65,19 +65,19 @@ def sdd(self):

@property
def delx(self):
return self._calibration[1, 1].item()
return self._calibration[0, 0].item()

@property
def dely(self):
return self._calibration[0, 0].item()
return self._calibration[1, 1].item()

@property
def x0(self):
return -self._calibration[1, -1].item()
return -self._calibration[0, -1].item()

@property
def y0(self):
return -self._calibration[0, -1].item()
return -self._calibration[1, -1].item()

@property
def reorient(self):
Expand Down Expand Up @@ -107,7 +107,7 @@ def _initialize_carm(self: Detector):
center = torch.tensor([[0.0, 0.0, 1.0]], device=device)

# Use the standard basis for the detector plane
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)
basis = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], device=device)

# Construct the detector plane with different offsets for even or odd heights
# These ensure that the detector plane is centered around (0, 0, 1)
Expand All @@ -117,8 +117,12 @@ def _initialize_carm(self: Detector):
# Construct equally spaced points along the basis vectors
t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
if self.reverse_x_axis:

t = -t
s = -s
if not self.reverse_x_axis:
s = -s

coefs = torch.cartesian_prod(t, s).reshape(-1, 2)
target = torch.einsum("cd,nc->nd", basis, coefs)
target += center
Expand Down
17 changes: 12 additions & 5 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def set_intrinsics_(
width if width is not None else self.detector.width,
delx if delx is not None else self.detector.delx,
dely if dely is not None else self.detector.dely,
x0 if x0 is not None else self.detector.x0,
y0 if y0 is not None else self.detector.y0,
x0 if x0 is not None else -self.detector.x0,
y0 if y0 is not None else -self.detector.y0,
self.subject.reorient,
n_subsample if n_subsample is not None else self.detector.n_subsample,
reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,
Expand All @@ -256,14 +256,21 @@ def perspective_projection(
pts: torch.Tensor,
):
"""Project points in world coordinates (3D) onto the pixel plane (2D)."""
# Poses in DiffDRR are world2camera, but perspective transforms use camera2world, so invert
extrinsic = (self.detector.reorient.compose(pose)).inverse()
x = extrinsic(pts)

# Project onto the detector plane
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
z = x[..., -1].unsqueeze(-1).clone()
x = x / z

# Move origin to upper-left corner
x[..., 1] = self.detector.height - x[..., 1]
if self.detector.reverse_x_axis:
x[..., 1] = self.detector.width - x[..., 1]
return x[..., :2].flip(-1)
x[..., 0] = self.detector.width - x[..., 0]

return x[..., :2]

# %% ../notebooks/api/00_drr.ipynb 14
from torch.nn.functional import pad
Expand All @@ -276,7 +283,7 @@ def inverse_projection(
pts: torch.Tensor,
):
"""Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D)."""
pts = pts.flip(-1)
# pts = pts.flip(-1)
if self.detector.reverse_x_axis:
pts[..., 1] = self.detector.width - pts[..., 1]
x = self.detector.sdd * torch.einsum(
Expand Down
18 changes: 10 additions & 8 deletions diffdrr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def resample(

# %% ../notebooks/api/07_utils.ipynb 6
from kornia.geometry.camera.pinhole import PinholeCamera as KorniaPinholeCamera
from torchio import Subject

from diffdrr.detector import Detector

Expand All @@ -66,9 +67,11 @@ def __init__(
height: torch.Tensor,
width: torch.Tensor,
detector: Detector,
subject: Subject,
):
super().__init__(intrinsics, extrinsics, height, width)
self.f = detector.sdd
multiplier = -1 if subject.orientation == "PA" else 1
self.sdd = multiplier * detector.sdd
self.delx = detector.delx
self.dely = detector.dely
self.x0 = detector.x0
Expand All @@ -94,9 +97,9 @@ def pose(self):

from kornia.geometry.calibration import solve_pnp_dlt

from .detector import make_intrinsic_matrix
from .drr import DRR
from .pose import RigidTransform
from .detector import make_intrinsic_matrix


def get_pinhole_camera(
Expand All @@ -107,14 +110,12 @@ def get_pinhole_camera(
pose = deepcopy(pose).to(device="cpu", dtype=dtype)

# Make the intrinsic matrix (in pixels)
fx = drr.detector.sdd / drr.detector.delx
fy = drr.detector.sdd / drr.detector.dely
multiplier = -1 if drr.subject.orientation == "PA" else 1
fx = multiplier * drr.detector.sdd / drr.detector.delx
fy = multiplier * drr.detector.sdd / drr.detector.dely
u0 = drr.detector.x0 / drr.detector.delx + drr.detector.width / 2
v0 = drr.detector.y0 / drr.detector.dely + drr.detector.height / 2
intrinsics = torch.eye(4)[None]
intrinsics[0, :3, :3] = make_intrinsic_matrix(drr.detector)

torch.tensor(
intrinsics = torch.tensor(
[
[
[fx, 0.0, u0, 0.0],
Expand Down Expand Up @@ -156,6 +157,7 @@ def get_pinhole_camera(
torch.tensor([drr.detector.height]),
torch.tensor([drr.detector.width]),
drr.detector,
drr.subject,
)

return camera
17 changes: 12 additions & 5 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@
" width if width is not None else self.detector.width,\n",
" delx if delx is not None else self.detector.delx,\n",
" dely if dely is not None else self.detector.dely,\n",
" x0 if x0 is not None else self.detector.x0,\n",
" y0 if y0 is not None else self.detector.y0,\n",
" x0 if x0 is not None else -self.detector.x0,\n",
" y0 if y0 is not None else -self.detector.y0,\n",
" self.subject.reorient,\n",
" n_subsample if n_subsample is not None else self.detector.n_subsample,\n",
" reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,\n",
Expand Down Expand Up @@ -399,14 +399,21 @@
" pts: torch.Tensor,\n",
"):\n",
" \"\"\"Project points in world coordinates (3D) onto the pixel plane (2D).\"\"\"\n",
" # Poses in DiffDRR are world2camera, but perspective transforms use camera2world, so invert\n",
" extrinsic = (self.detector.reorient.compose(pose)).inverse()\n",
" x = extrinsic(pts)\n",
"\n",
" # Project onto the detector plane\n",
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
" z = x[..., -1].unsqueeze(-1).clone()\n",
" x = x / z\n",
"\n",
" # Move origin to upper-left corner\n",
" x[..., 1] = self.detector.height - x[..., 1]\n",
" if self.detector.reverse_x_axis:\n",
" x[..., 1] = self.detector.width - x[..., 1]\n",
" return x[..., :2].flip(-1)"
" x[..., 0] = self.detector.width - x[..., 0]\n",
" \n",
" return x[..., :2]"
]
},
{
Expand All @@ -427,7 +434,7 @@
" pts: torch.Tensor,\n",
"):\n",
" \"\"\"Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D).\"\"\"\n",
" pts = pts.flip(-1)\n",
" # pts = pts.flip(-1)\n",
" if self.detector.reverse_x_axis:\n",
" pts[..., 1] = self.detector.width - pts[..., 1]\n",
" x = self.detector.sdd * torch.einsum(\n",
Expand Down
20 changes: 12 additions & 8 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@
" \"_calibration\",\n",
" torch.tensor(\n",
" [\n",
" [dely, 0, 0, -y0],\n",
" [0, delx, 0, -x0],\n",
" [delx, 0, 0, x0],\n",
" [0, dely, 0, y0],\n",
" [0, 0, sdd, 0],\n",
" [0, 0, 0, 1],\n",
" ]\n",
Expand All @@ -120,19 +120,19 @@
"\n",
" @property\n",
" def delx(self):\n",
" return self._calibration[1, 1].item()\n",
" return self._calibration[0, 0].item()\n",
"\n",
" @property\n",
" def dely(self):\n",
" return self._calibration[0, 0].item()\n",
" return self._calibration[1, 1].item()\n",
"\n",
" @property\n",
" def x0(self):\n",
" return -self._calibration[1, -1].item()\n",
" return -self._calibration[0, -1].item()\n",
"\n",
" @property\n",
" def y0(self):\n",
" return -self._calibration[0, -1].item()\n",
" return -self._calibration[1, -1].item()\n",
"\n",
" @property\n",
" def reorient(self):\n",
Expand Down Expand Up @@ -170,7 +170,7 @@
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device)\n",
"\n",
" # Use the standard basis for the detector plane\n",
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
" basis = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], device=device)\n",
"\n",
" # Construct the detector plane with different offsets for even or odd heights\n",
" # These ensure that the detector plane is centered around (0, 0, 1)\n",
Expand All @@ -180,8 +180,12 @@
" # Construct equally spaced points along the basis vectors\n",
" t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
" s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
" if self.reverse_x_axis:\n",
"\n",
" t = -t\n",
" s = -s\n",
" if not self.reverse_x_axis:\n",
" s = -s\n",
"\n",
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
" target = torch.einsum(\"cd,nc->nd\", basis, coefs)\n",
" target += center\n",
Expand Down
Loading
Loading