diff --git a/nerfstudio/data/dataparsers/colmap_dataparser.py b/nerfstudio/data/dataparsers/colmap_dataparser.py
index a8a993d305..046f6d5d35 100644
--- a/nerfstudio/data/dataparsers/colmap_dataparser.py
+++ b/nerfstudio/data/dataparsers/colmap_dataparser.py
@@ -71,6 +71,10 @@ class ColmapDataParserConfig(DataParserConfig):
     """Path to depth maps directory. If not set, depths are not loaded."""
     colmap_path: Path = Path("sparse/0")
     """Path to the colmap reconstruction directory relative to the data path."""
+    load_3D_points: bool = True
+    """Whether to load the 3D points from the colmap reconstruction."""
+    max_2D_matches_per_3D_point: int = -1
+    """Maximum number of 2D matches per 3D point. If set to -1, all 2D matches are loaded. If set to 0, no 2D matches are loaded."""
 
 
 class ColmapDataParser(DataParser):
@@ -202,7 +206,7 @@ def _get_image_indices(self, image_filenames, split):
             raise ValueError(f"Unknown dataparser split {split}")
         return indices
 
-    def _generate_dataparser_outputs(self, split: str = "train"):
+    def _generate_dataparser_outputs(self, split: str = "train", **kwargs):
         assert self.config.data.exists(), f"Data directory {self.config.data} does not exist."
         colmap_path = self.config.data / self.config.colmap_path
         assert colmap_path.exists(), f"Colmap path {colmap_path} does not exist."
@@ -328,6 +332,11 @@ def _generate_dataparser_outputs(self, split: str = "train"):
             applied_scale = float(meta["applied_scale"])
             scale_factor *= applied_scale
 
+        metadata = {}
+        if self.config.load_3D_points:
+            # Load 3D points
+            metadata.update(self._load_3D_points(colmap_path, transform_matrix, scale_factor))
+
         dataparser_outputs = DataparserOutputs(
             image_filenames=image_filenames,
             cameras=cameras,
@@ -338,10 +347,77 @@ def _generate_dataparser_outputs(self, split: str = "train"):
             metadata={
                 "depth_filenames": depth_filenames if len(depth_filenames) > 0 else None,
                 "depth_unit_scale_factor": self.config.depth_unit_scale_factor,
+                **metadata,
             },
         )
         return dataparser_outputs
 
+    def _load_3D_points(self, colmap_path: Path, transform_matrix: torch.Tensor, scale_factor: float):
+        if (colmap_path / "points3D.bin").exists():
+            colmap_points = colmap_utils.read_points3D_binary(colmap_path / "points3D.bin")
+        elif (colmap_path / "points3D.txt").exists():
+            colmap_points = colmap_utils.read_points3D_text(colmap_path / "points3D.txt")
+        else:
+            raise ValueError(f"Could not find points3D.txt or points3D.bin in {colmap_path}")
+        points3D = torch.from_numpy(np.array([p.xyz for p in colmap_points.values()], dtype=np.float32))
+        points3D = (
+            torch.cat(
+                (
+                    points3D,
+                    torch.ones_like(points3D[..., :1]),
+                ),
+                -1,
+            )
+            @ transform_matrix.T
+        )
+        points3D *= scale_factor
+
+        # Load point colours
+        points3D_rgb = torch.from_numpy(np.array([p.rgb for p in colmap_points.values()], dtype=np.uint8))
+        points3D_num_points = torch.tensor([len(p.image_ids) for p in colmap_points.values()], dtype=torch.int64)
+        out = {
+            "points3D_xyz": points3D,
+            "points3D_rgb": points3D_rgb,
+            "points3D_error": torch.from_numpy(np.array([p.error for p in colmap_points.values()], dtype=np.float32)),
+            "points3D_num_points": points3D_num_points,
+        }
+        if self.config.max_2D_matches_per_3D_point != 0:
+            if (colmap_path / "images.txt").exists():
+                im_id_to_image = colmap_utils.read_images_text(colmap_path / "images.txt")
+            elif (colmap_path / "images.bin").exists():
+                im_id_to_image = colmap_utils.read_images_binary(colmap_path / "images.bin")
+            else:
+                raise ValueError(f"Could not find images.txt or images.bin in {colmap_path}")
+
+            downscale_factor = self._downscale_factor
+            max_num_points = int(torch.max(points3D_num_points).item())
+            if self.config.max_2D_matches_per_3D_point > 0:
+                max_num_points = min(max_num_points, self.config.max_2D_matches_per_3D_point)
+            points3D_image_ids = []
+            points3D_image_xy = []
+            for p in colmap_points.values():
+                nids = np.array(p.image_ids, dtype=np.int64)
+                nxy_ids = np.array(p.point2D_idxs, dtype=np.int32)
+                if self.config.max_2D_matches_per_3D_point != -1:
+                    # Keep the 2D matches with the lowest reprojection error
+                    idxs = np.argsort(p.error)[: self.config.max_2D_matches_per_3D_point]
+                    nids = nids[idxs]
+                    nxy_ids = nxy_ids[idxs]
+                nxy = [im_id_to_image[im_id].xys[pt_idx] for im_id, pt_idx in zip(nids, nxy_ids)]
+                nxy = torch.from_numpy(np.stack(nxy).astype(np.float32))
+                nids = torch.from_numpy(nids)
+                assert len(nids.shape) == 1
+                assert len(nxy.shape) == 2
+                points3D_image_ids.append(
+                    torch.cat((nids, torch.full((max_num_points - len(nids),), -1, dtype=torch.int64)))
+                )
+                points3D_image_xy.append(
+                    torch.cat((nxy, torch.full((max_num_points - len(nxy), nxy.shape[-1]), 0, dtype=torch.float32)))
+                    / downscale_factor
+                )
+            out["points3D_image_ids"] = torch.stack(points3D_image_ids, dim=0)
+            out["points3D_image_xy"] = torch.stack(points3D_image_xy, dim=0)
+        return out
+
     def _setup_downscale_factor(
         self, image_filenames: List[Path], mask_filenames: List[Path], depth_filenames: List[Path]
     ):
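
Note (not part of the patch): a minimal usage sketch for the two new config options and the metadata keys they produce. The dataset path here is hypothetical; config.setup() and get_dataparser_outputs() are nerfstudio's standard config/dataparser entry points.

from pathlib import Path

from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig

# Hypothetical scene, assumed to be processed by COLMAP into data/my_scene/sparse/0.
config = ColmapDataParserConfig(
    data=Path("data/my_scene"),
    load_3D_points=True,  # attach the COLMAP point cloud to the dataparser outputs
    max_2D_matches_per_3D_point=16,  # keep at most 16 observations per point; -1 = all, 0 = none
)
parser = config.setup()
outputs = parser.get_dataparser_outputs(split="train")

points_xyz = outputs.metadata["points3D_xyz"]  # (N, 3) float32, transformed and scaled
points_rgb = outputs.metadata["points3D_rgb"]  # (N, 3) uint8
image_ids = outputs.metadata["points3D_image_ids"]  # (N, max_matches) int64, padded with -1
image_xy = outputs.metadata["points3D_image_xy"]  # (N, max_matches, 2) float32, zero-padded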
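
Because each 3D point is observed in a variable number of images, _load_3D_points pads points3D_image_ids with -1 and points3D_image_xy with zeros up to the longest (possibly capped) track. A consumer can mask the padding back out; continuing the sketch above:

valid = image_ids[0] != -1       # padded entries carry image id -1
valid_ids = image_ids[0][valid]  # ids of the images observing point 0
valid_xy = image_xy[0][valid]    # matching pixel coordinates (already divided by the downscale factor)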