Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide simpler access to camera intrinsics for optimization #240

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
'diffdrr.detector.Detector.__init__': ('api/detector.html#detector.__init__', 'diffdrr/detector.py'),
'diffdrr.detector.Detector._initialize_carm': ( 'api/detector.html#detector._initialize_carm',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.calibration': ( 'api/detector.html#detector.calibration',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.forward': ('api/detector.html#detector.forward', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.reorient': ('api/detector.html#detector.reorient', 'diffdrr/detector.py')},
Expand Down
2 changes: 1 addition & 1 deletion diffdrr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def read(
mask = torch.any(
torch.stack([subject.mask.data.squeeze() == idx for idx in labels]), dim=0
)
subject.density = subject.density * mask
subject.density.data = subject.density.data * mask

return subject

Expand Down
61 changes: 42 additions & 19 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def __init__(
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
):
super().__init__()
self.sdd = sdd
# self.sdd = sdd
self.height = height
self.width = width
self.delx = delx
self.dely = dely
self.x0 = x0
self.y0 = y0
# self.delx = delx
# self.dely = dely
# self.x0 = x0
# self.y0 = y0
self.n_subsample = n_subsample
if self.n_subsample is not None:
self.subsamples = []
Expand All @@ -52,20 +52,39 @@ def __init__(
# Create a pose to reorient the scanner
self.register_buffer("_reorient", reorient)

# Create a calibration matrix that holds the detector's intrinsic parameters
self.register_buffer(
"_calibration",
torch.tensor(
[
[delx, 0, 0, x0],
[0, dely, 0, y0],
[0, 0, sdd, 0],
[0, 0, 0, 1],
]
),
)

@property
def reorient(self):
return RigidTransform(self._reorient)

@property
def calibration(self):
"""A 4x4 matrix that rescales the detector plane to world coordinates."""
return RigidTransform(self._calibration)

@property
def intrinsic(self):
"""The 3x3 intrinsic matrix."""
return make_intrinsic_matrix(
self.sdd,
self.delx,
self.dely,
self._calibration[2, 2].item(),
self._calibration[0, 0].item(),
self._calibration[1, 1].item(),
self.height,
self.width,
self.x0,
self.y0,
self._calibration[0, -1].item(),
self._calibration[1, -1].item(),
).to(self.source)

# %% ../notebooks/api/02_detector.ipynb 6
Expand All @@ -79,7 +98,7 @@ def _initialize_carm(self: Detector):

# Initialize the source at the origin and the center of the detector plane on the positive z-axis
source = torch.tensor([[0.0, 0.0, 0.0]], device=device)
center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd
center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd

# Use the standard basis for the detector plane
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)
Expand All @@ -91,10 +110,10 @@ def _initialize_carm(self: Detector):
# Construct equally spaced points along the basis vectors
t = (
torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
) * self.delx
) # * self.delx
s = (
torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
) * self.dely
) # * self.dely
if self.reverse_x_axis:
s = -s
coefs = torch.cartesian_prod(t, s).reshape(-1, 2)
Expand All @@ -105,9 +124,9 @@ def _initialize_carm(self: Detector):
source = source.unsqueeze(0)
target = target.unsqueeze(0)

# Apply principal point offset
target[..., 1] -= self.x0
target[..., 0] -= self.y0
# # Apply principal point offset
# target[..., 1] -= self.x0
# target[..., 0] -= self.y0

if self.n_subsample is not None:
sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]
Expand All @@ -120,9 +139,13 @@ def _initialize_carm(self: Detector):


@patch
def forward(self: Detector, pose: RigidTransform):
def forward(self: Detector, extrinsic: RigidTransform, calibration: RigidTransform):
"""Create source and target points for X-rays to trace through the volume."""
pose = self.reorient.compose(pose)
if calibration is None:
target = self.calibration(self.target)
else:
target = calibration(self.target)
pose = self.reorient.compose(extrinsic)
source = pose(self.source)
target = pose(self.target)
target = pose(target)
return source, target
3 changes: 2 additions & 1 deletion diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def forward(
*args, # Some batched representation of SE(3)
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
calibration: RigidTransform = None, # Optional calibration matrix with the detector's intrinsic parameters
mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels
**kwargs, # Passed to the renderer
):
Expand All @@ -122,7 +123,7 @@ def forward(
pose = args[0]
else:
pose = convert(*args, parameterization=parameterization, convention=convention)
source, target = self.detector(pose)
source, target = self.detector(pose, calibration)

# Render the DRR
kwargs["mask"] = self.mask if mask_to_channels else None
Expand Down
3 changes: 2 additions & 1 deletion notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
" *args, # Some batched representation of SE(3)\n",
" parameterization: str = None, # Specifies the representation of the rotation\n",
" convention: str = None, # If parameterization is Euler angles, specify convention\n",
" calibration: RigidTransform = None, # Optional calibration matrix with the detector's intrinsic parameters\n",
" mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels\n",
" **kwargs, # Passed to the renderer\n",
"):\n",
Expand All @@ -239,7 +240,7 @@
" pose = args[0]\n",
" else:\n",
" pose = convert(*args, parameterization=parameterization, convention=convention)\n",
" source, target = self.detector(pose)\n",
" source, target = self.detector(pose, calibration)\n",
"\n",
" # Render the DRR\n",
" kwargs[\"mask\"] = self.mask if mask_to_channels else None\n",
Expand Down
61 changes: 42 additions & 19 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@
" reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
" ):\n",
" super().__init__()\n",
" self.sdd = sdd\n",
" # self.sdd = sdd\n",
" self.height = height\n",
" self.width = width\n",
" self.delx = delx\n",
" self.dely = dely\n",
" self.x0 = x0\n",
" self.y0 = y0\n",
" # self.delx = delx\n",
" # self.dely = dely\n",
" # self.x0 = x0\n",
" # self.y0 = y0\n",
" self.n_subsample = n_subsample\n",
" if self.n_subsample is not None:\n",
" self.subsamples = []\n",
Expand All @@ -107,20 +107,39 @@
" # Create a pose to reorient the scanner\n",
" self.register_buffer(\"_reorient\", reorient)\n",
"\n",
" # Create a calibration matrix that holds the detector's intrinsic parameters\n",
" self.register_buffer(\n",
" \"_calibration\",\n",
" torch.tensor(\n",
" [\n",
" [delx, 0, 0, x0],\n",
" [0, dely, 0, y0],\n",
" [0, 0, sdd, 0],\n",
" [0, 0, 0, 1],\n",
" ]\n",
" ),\n",
" )\n",
"\n",
" @property\n",
" def reorient(self):\n",
" return RigidTransform(self._reorient)\n",
"\n",
" @property\n",
" def calibration(self):\n",
" \"\"\"A 4x4 matrix that rescales the detector plane to world coordinates.\"\"\"\n",
" return RigidTransform(self._calibration)\n",
"\n",
" @property\n",
" def intrinsic(self):\n",
" \"\"\"The 3x3 intrinsic matrix.\"\"\"\n",
" return make_intrinsic_matrix(\n",
" self.sdd,\n",
" self.delx,\n",
" self.dely,\n",
" self._calibration[2, 2].item(),\n",
" self._calibration[0, 0].item(),\n",
" self._calibration[1, 1].item(),\n",
" self.height,\n",
" self.width,\n",
" self.x0,\n",
" self.y0,\n",
" self._calibration[0, -1].item(),\n",
" self._calibration[1, -1].item(),\n",
" ).to(self.source)"
]
},
Expand All @@ -142,7 +161,7 @@
"\n",
" # Initialize the source at the origin and the center of the detector plane on the positive z-axis\n",
" source = torch.tensor([[0.0, 0.0, 0.0]], device=device)\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd\n",
"\n",
" # Use the standard basis for the detector plane\n",
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
Expand All @@ -154,10 +173,10 @@
" # Construct equally spaced points along the basis vectors\n",
" t = (\n",
" torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
" ) * self.delx\n",
" ) # * self.delx\n",
" s = (\n",
" torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
" ) * self.dely\n",
" ) # * self.dely\n",
" if self.reverse_x_axis:\n",
" s = -s\n",
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
Expand All @@ -168,9 +187,9 @@
" source = source.unsqueeze(0)\n",
" target = target.unsqueeze(0)\n",
"\n",
" # Apply principal point offset\n",
" target[..., 1] -= self.x0\n",
" target[..., 0] -= self.y0\n",
" # # Apply principal point offset\n",
" # target[..., 1] -= self.x0\n",
" # target[..., 0] -= self.y0\n",
"\n",
" if self.n_subsample is not None:\n",
" sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]\n",
Expand All @@ -191,11 +210,15 @@
"\n",
"\n",
"@patch\n",
"def forward(self: Detector, pose: RigidTransform):\n",
"def forward(self: Detector, extrinsic: RigidTransform, calibration: RigidTransform):\n",
" \"\"\"Create source and target points for X-rays to trace through the volume.\"\"\"\n",
" pose = self.reorient.compose(pose)\n",
" if calibration is None:\n",
" target = self.calibration(self.target)\n",
" else:\n",
" target = calibration(self.target)\n",
" pose = self.reorient.compose(extrinsic)\n",
" source = pose(self.source)\n",
" target = pose(self.target)\n",
" target = pose(target)\n",
" return source, target"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/api/03_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
" mask = torch.any(\n",
" torch.stack([subject.mask.data.squeeze() == idx for idx in labels]), dim=0\n",
" )\n",
" subject.density = subject.density * mask\n",
" subject.density.data = subject.density.data * mask\n",
"\n",
" return subject"
]
Expand Down
22 changes: 11 additions & 11 deletions notebooks/tutorials/introduction.ipynb

Large diffs are not rendered by default.

Loading