Skip to content

Commit

Permalink
Allow colmap parser to load 3D points (#2408)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkulhanek authored and maturk committed Sep 11, 2023
1 parent f486bd9 commit 3e9313b
Showing 1 changed file with 77 additions and 1 deletion.
78 changes: 77 additions & 1 deletion nerfstudio/data/dataparsers/colmap_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class ColmapDataParserConfig(DataParserConfig):
"""Path to depth maps directory. If not set, depths are not loaded."""
colmap_path: Path = Path("sparse/0")
"""Path to the colmap reconstruction directory relative to the data path."""
load_3D_points: bool = True
"""Whether to load the 3D points from the colmap reconstruction."""
max_2D_matches_per_3D_point: int = -1
"""Maximum number of 2D matches per 3D point. If set to -1, all 2D matches are loaded. If set to 0, no 2D matches are loaded."""


class ColmapDataParser(DataParser):
Expand Down Expand Up @@ -202,7 +206,7 @@ def _get_image_indices(self, image_filenames, split):
raise ValueError(f"Unknown dataparser split {split}")
return indices

def _generate_dataparser_outputs(self, split: str = "train"):
def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
assert self.config.data.exists(), f"Data directory {self.config.data} does not exist."
colmap_path = self.config.data / self.config.colmap_path
assert colmap_path.exists(), f"Colmap path {colmap_path} does not exist."
Expand Down Expand Up @@ -328,6 +332,11 @@ def _generate_dataparser_outputs(self, split: str = "train"):
applied_scale = float(meta["applied_scale"])
scale_factor *= applied_scale

metadata = {}
if self.config.load_3D_points:
# Load 3D points
metadata.update(self._load_3D_points(colmap_path, transform_matrix, scale_factor))

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames,
cameras=cameras,
Expand All @@ -338,10 +347,77 @@ def _generate_dataparser_outputs(self, split: str = "train"):
metadata={
"depth_filenames": depth_filenames if len(depth_filenames) > 0 else None,
"depth_unit_scale_factor": self.config.depth_unit_scale_factor,
**metadata,
},
)
return dataparser_outputs

def _load_3D_points(self, colmap_path: Path, transform_matrix: torch.Tensor, scale_factor: float):
    """Load the COLMAP sparse point cloud and (optionally) its 2D track observations.

    Args:
        colmap_path: Directory containing the COLMAP reconstruction
            (``points3D.bin``/``points3D.txt`` and, when 2D matches are requested,
            ``images.bin``/``images.txt``).
        transform_matrix: Transform applied to the homogenized points so they live
            in the same frame as the cameras produced by this dataparser.
        scale_factor: Uniform scale applied after the transform.

    Returns:
        Dict with ``points3D_xyz``, ``points3D_rgb``, ``points3D_error`` and
        ``points3D_num_points`` (one entry per 3D point), plus — when
        ``config.max_2D_matches_per_3D_point != 0`` — ``points3D_image_ids`` and
        ``points3D_image_xy``, each padded to the longest kept track
        (ids padded with -1, coordinates with 0).

    Raises:
        ValueError: If neither the binary nor the text COLMAP files are present.
    """
    if (colmap_path / "points3D.bin").exists():
        colmap_points = colmap_utils.read_points3D_binary(colmap_path / "points3D.bin")
    elif (colmap_path / "points3D.txt").exists():
        colmap_points = colmap_utils.read_points3D_text(colmap_path / "points3D.txt")
    else:
        raise ValueError(f"Could not find points3D.txt or points3D.bin in {colmap_path}")

    # Materialize once so every derived tensor below iterates the points in the
    # same fixed order instead of re-walking the dict four times.
    points = list(colmap_points.values())

    points3D = torch.from_numpy(np.array([p.xyz for p in points], dtype=np.float32))
    # Homogenize, apply the dataparser transform, then the uniform scale.
    points3D = (
        torch.cat(
            (
                points3D,
                torch.ones_like(points3D[..., :1]),
            ),
            -1,
        )
        @ transform_matrix.T
    )
    points3D *= scale_factor

    # Per-point colours, reprojection errors, and track lengths.
    points3D_rgb = torch.from_numpy(np.array([p.rgb for p in points], dtype=np.uint8))
    points3D_num_points = torch.tensor([len(p.image_ids) for p in points], dtype=torch.int64)
    out = {
        "points3D_xyz": points3D,
        "points3D_rgb": points3D_rgb,
        "points3D_error": torch.from_numpy(np.array([p.error for p in points], dtype=np.float32)),
        "points3D_num_points": points3D_num_points,
    }
    if self.config.max_2D_matches_per_3D_point != 0:
        if (colmap_path / "images.txt").exists():
            im_id_to_image = colmap_utils.read_images_text(colmap_path / "images.txt")
        elif (colmap_path / "images.bin").exists():
            im_id_to_image = colmap_utils.read_images_binary(colmap_path / "images.bin")
        else:
            raise ValueError(f"Could not find images.txt or images.bin in {colmap_path}")
        downscale_factor = self._downscale_factor
        # Pad every track to the length of the longest kept track so they stack.
        max_num_points = int(torch.max(points3D_num_points).item())
        if self.config.max_2D_matches_per_3D_point > 0:
            max_num_points = min(max_num_points, self.config.max_2D_matches_per_3D_point)
        points3D_image_ids = []
        points3D_image_xy = []
        for p in points:
            nids = np.array(p.image_ids, dtype=np.int64)
            nxy_ids = np.array(p.point2D_idxs, dtype=np.int32)
            if self.config.max_2D_matches_per_3D_point != -1:
                # Keep the lowest-error matches first (argsort is ascending) —
                # this is a deterministic selection, not a random sample.
                # NOTE(review): COLMAP's Point3D typically stores a single
                # reprojection error per 3D point, not one per observation; if
                # `p.error` is scalar here this argsort degenerates to keeping
                # only the first match — confirm against colmap_utils.
                idxs = np.argsort(p.error)[: self.config.max_2D_matches_per_3D_point]
                nids = nids[idxs]
                nxy_ids = nxy_ids[idxs]
            nxy = [im_id_to_image[im_id].xys[pt_idx] for im_id, pt_idx in zip(nids, nxy_ids)]
            nxy = torch.from_numpy(np.stack(nxy).astype(np.float32))
            nids = torch.from_numpy(nids)
            assert len(nids.shape) == 1
            assert len(nxy.shape) == 2
            # Pad ids with -1 ("no image") and xy with 0; xy is rescaled into
            # the (possibly downscaled) image resolution used for training.
            points3D_image_ids.append(
                torch.cat((nids, torch.full((max_num_points - len(nids),), -1, dtype=torch.int64)))
            )
            points3D_image_xy.append(
                torch.cat((nxy, torch.full((max_num_points - len(nxy), nxy.shape[-1]), 0, dtype=torch.float32)))
                / downscale_factor
            )
        out["points3D_image_ids"] = torch.stack(points3D_image_ids, dim=0)
        out["points3D_image_xy"] = torch.stack(points3D_image_xy, dim=0)
    return out

def _setup_downscale_factor(
self, image_filenames: List[Path], mask_filenames: List[Path], depth_filenames: List[Path]
):
Expand Down

0 comments on commit 3e9313b

Please sign in to comment.