diff --git a/docs/nerfology/methods/splat.md b/docs/nerfology/methods/splat.md
index fc1bb32a98..37190173e6 100644
--- a/docs/nerfology/methods/splat.md
+++ b/docs/nerfology/methods/splat.md
@@ -39,7 +39,7 @@ We provide a few additional variants:
 | `splatfacto-big` | More Gaussians, Higher Quality | ~12GB | Slower |
 
-A full evalaution of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/tests/eval.html).
+A full evaluation of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/main/tests/eval.html).
 
 #### Quality and Regularization
 The default settings provided maintain a balance between speed, quality, and splat file size, but if you care more about quality than training speed or size, you can decrease the alpha cull threshold
diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py
index 2349e6ab05..a67d492992 100644
--- a/nerfstudio/data/datamanagers/full_images_datamanager.py
+++ b/nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -63,7 +63,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     new images. If -1, never pick new images."""
     eval_image_indices: Optional[Tuple[int, ...]] = (0,)
     """Specifies the image indices to use during eval; if None, uses all."""
-    cache_images: Literal["cpu", "gpu"] = "cpu"
+    cache_images: Literal["cpu", "gpu"] = "gpu"
     """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device."""
     cache_images_type: Literal["uint8", "float32"] = "float32"
     """The image type returned from manager, caching images in uint8 saves memory"""
@@ -247,11 +247,15 @@ def undistort_idx(idx: int) -> Dict[str, torch.Tensor]:
                 cache["image"] = cache["image"].to(self.device)
                 if "mask" in cache:
                     cache["mask"] = cache["mask"].to(self.device)
+                if "depth" in cache:
+                    cache["depth"] = cache["depth"].to(self.device)
+            self.train_cameras = self.train_dataset.cameras.to(self.device)
         elif cache_images_device == "cpu":
             for cache in undistorted_images:
                 cache["image"] = cache["image"].pin_memory()
                 if "mask" in cache:
                     cache["mask"] = cache["mask"].pin_memory()
+            self.train_cameras = self.train_dataset.cameras
         else:
             assert_never(cache_images_device)
 
@@ -340,11 +344,11 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
         if len(self.train_unseen_cameras) == 0:
             self.train_unseen_cameras = self.sample_train_cameras()
 
-        data = deepcopy(self.cached_train[image_idx])
+        data = self.cached_train[image_idx]
         data["image"] = data["image"].to(self.device)
 
-        assert len(self.train_dataset.cameras.shape) == 1, "Assumes single batch dimension"
-        camera = self.train_dataset.cameras[image_idx : image_idx + 1].to(self.device)
+        assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
+        camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
         if camera.metadata is None:
             camera.metadata = {}
         camera.metadata["cam_idx"] = image_idx
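Note on the datamanager hunks above: flipping `cache_images` to `"gpu"` and reusing `self.train_cameras` removes a per-step `deepcopy` and host-to-device copy, at the cost of holding every training image in VRAM. A minimal sketch of opting back into the old behavior via the config fields shown above (the field names and values come from this diff; how the config is wired into a full pipeline is left out):

```python
# Sketch: fall back to CPU caching when the dataset does not fit in VRAM.
# FullImageDatamanagerConfig and both fields appear in the hunk above.
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

datamanager_config = FullImageDatamanagerConfig(
    cache_images="cpu",         # new default is "gpu"; "cpu" pins memory and copies per step
    cache_images_type="uint8",  # cache as uint8 to save memory vs the float32 default
)
```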
diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py
index 9b29eca629..a88a306ced 100644
--- a/nerfstudio/models/splatfacto.py
+++ b/nerfstudio/models/splatfacto.py
@@ -25,10 +25,13 @@
 
 import numpy as np
 import torch
-from gsplat._torch_impl import quat_to_rotmat
-from gsplat.project_gaussians import project_gaussians
-from gsplat.rasterize import rasterize_gaussians
-from gsplat.sh import num_sh_bases, spherical_harmonics
+from gsplat.cuda_legacy._torch_impl import quat_to_rotmat
+
+try:
+    from gsplat.rendering import rasterization
+except ImportError:
+    print("Please install gsplat>=1.0.0")
+from gsplat.cuda_legacy._wrapper import num_sh_bases
 from pytorch_msssim import SSIM
 from torch.nn import Parameter
 from typing_extensions import Literal
@@ -96,6 +99,25 @@ def resize_image(image: torch.Tensor, d: int):
     return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
 
 
+@torch.compile()
+def get_viewmat(optimized_camera_to_world):
+    """
+    Converts c2w to the gsplat world-to-camera matrix, compiled for some speed.
+    """
+    R = optimized_camera_to_world[:, :3, :3]  # 3 x 3
+    T = optimized_camera_to_world[:, :3, 3:4]  # 3 x 1
+    # flip the z and y axes to align with gsplat conventions
+    R = R * torch.tensor([[[1, -1, -1]]], device=R.device, dtype=R.dtype)
+    # analytic matrix inverse to get world2camera matrix
+    R_inv = R.transpose(1, 2)
+    T_inv = -torch.bmm(R_inv, T)
+    viewmat = torch.zeros(R.shape[0], 4, 4, device=R.device, dtype=R.dtype)
+    viewmat[:, 3, 3] = 1.0  # homogeneous
+    viewmat[:, :3, :3] = R_inv
+    viewmat[:, :3, 3:4] = T_inv
+    return viewmat
+
+
 @dataclass
 class SplatfactoModelConfig(ModelConfig):
     """Splatfacto Model Config, nerfstudio's implementation of Gaussian Splatting"""
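The batched `get_viewmat` above replaces the per-camera matrix math deleted from `get_outputs` further down in this diff. A sanity-check sketch (not part of the PR; it assumes the function is importable from `nerfstudio.models.splatfacto` once this lands): the analytic transpose-based inverse should match a brute-force `torch.inverse` of the axis-flipped pose.

```python
# Verify get_viewmat against torch.inverse (illustrative test, not in the PR).
import torch
from nerfstudio.models.splatfacto import get_viewmat

c2w = torch.eye(4)[None]                       # [1, 4, 4] camera-to-world pose
c2w[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])  # arbitrary camera position

flip = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0]))  # negate y/z columns
expected = torch.inverse(c2w @ flip)           # brute-force world-to-camera

assert torch.allclose(get_viewmat(c2w), expected, atol=1e-6)
```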
@@ -127,8 +149,6 @@ class SplatfactoModelConfig(ModelConfig):
     """number of samples to split gaussians into"""
     sh_degree_interval: int = 1000
     """every n intervals turn on another sh degree"""
-    cull_screen_size: float = 0.15
-    """if a gaussian is more than this percent of screen space, cull it"""
     split_screen_size: float = 0.05
     """if a gaussian is more than this percent of screen space, split it"""
     stop_screen_size_at: int = 4000
@@ -191,7 +211,6 @@ def populate_modules(self):
         else:
             means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
         self.xys_grad_norm = None
-        self.max_2Dsize = None
         distances, _ = self.k_nearest_sklearn(means.data, 3)
         distances = torch.from_numpy(distances)
         # find the average of the three nearest neighbors for each point and use that as the scale
@@ -395,25 +414,14 @@ def after_train(self, step: int):
         with torch.no_grad():
             # keep track of a moving average of grad norms
             visible_mask = (self.radii > 0).flatten()
-            assert self.xys.absgrad is not None  # type: ignore
-            grads = self.xys.absgrad.detach().norm(dim=-1)  # type: ignore
+            grads = self.xys.absgrad[0][visible_mask].norm(dim=-1)  # type: ignore
            # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}")
             if self.xys_grad_norm is None:
-                self.xys_grad_norm = grads
-                self.vis_counts = torch.ones_like(self.xys_grad_norm)
-            else:
-                assert self.vis_counts is not None
-                self.vis_counts[visible_mask] = self.vis_counts[visible_mask] + 1
-                self.xys_grad_norm[visible_mask] = grads[visible_mask] + self.xys_grad_norm[visible_mask]
-
-            # update the max screen size, as a ratio of number of pixels
-            if self.max_2Dsize is None:
-                self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
-            newradii = self.radii.detach()[visible_mask]
-            self.max_2Dsize[visible_mask] = torch.maximum(
-                self.max_2Dsize[visible_mask],
-                newradii / float(max(self.last_size[0], self.last_size[1])),
-            )
+                self.xys_grad_norm = torch.zeros(self.num_points, device=self.device, dtype=torch.float32)
+                self.vis_counts = torch.ones(self.num_points, device=self.device, dtype=torch.float32)
+            assert self.vis_counts is not None
+            self.vis_counts[visible_mask] += 1
+            self.xys_grad_norm[visible_mask] += grads
 
     def set_crop(self, crop_box: Optional[OrientedBox]):
         self.crop_box = crop_box
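For context on the `after_train` rewrite above: `xys_grad_norm` is now a running sum of screen-space gradient norms and `vis_counts` a per-gaussian visibility counter; the average is only formed later in `refinement_after`. A toy illustration of that accumulate-then-divide contract (made-up numbers; only the pattern mirrors the code above):

```python
# Toy illustration of the densification signal kept by after_train.
import torch

xys_grad_norm = torch.zeros(3)  # running sum of absgrad norms
vis_counts = torch.ones(3)      # starts at ones, as in after_train

steps = [
    (torch.tensor([0, 2]), torch.tensor([0.2, 0.4])),  # step 1: points 0, 2 visible
    (torch.tensor([2]), torch.tensor([0.6])),          # step 2: only point 2 visible
]
for visible_mask, grads in steps:
    vis_counts[visible_mask] += 1
    xys_grad_norm[visible_mask] += grads

# refinement_after thresholds this mean (scaled by 0.5 * max(H, W)) to pick splits.
print(xys_grad_norm / vis_counts)  # tensor([0.1000, 0.0000, 0.3333])
```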
@@ -438,12 +446,10 @@ def refinement_after(self, optimizers: Optimizers, step):
             )
             if do_densification:
                 # then we densify
-                assert self.xys_grad_norm is not None and self.vis_counts is not None and self.max_2Dsize is not None
+                assert self.xys_grad_norm is not None and self.vis_counts is not None
                 avg_grad_norm = (self.xys_grad_norm / self.vis_counts) * 0.5 * max(self.last_size[0], self.last_size[1])
                 high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze()
                 splits = (self.scales.exp().max(dim=-1).values > self.config.densify_size_thresh).squeeze()
-                if self.step < self.config.stop_screen_size_at:
-                    splits |= (self.max_2Dsize > self.config.split_screen_size).squeeze()
                 splits &= high_grads
                 nsamps = self.config.n_split_samples
                 split_params = self.split_gaussians(splits, nsamps)
@@ -456,16 +462,6 @@
                     torch.cat([param.detach(), split_params[name], dup_params[name]], dim=0)
                 )
 
-                # append zeros to the max_2Dsize tensor
-                self.max_2Dsize = torch.cat(
-                    [
-                        self.max_2Dsize,
-                        torch.zeros_like(split_params["scales"][:, 0]),
-                        torch.zeros_like(dup_params["scales"][:, 0]),
-                    ],
-                    dim=0,
-                )
-
                 split_idcs = torch.where(splits)[0]
                 self.dup_in_all_optim(optimizers, split_idcs, nsamps)
@@ -510,7 +506,6 @@
 
             self.xys_grad_norm = None
             self.vis_counts = None
-            self.max_2Dsize = None
 
     def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None):
         """
@@ -527,10 +522,6 @@
         if self.step > self.config.refine_every * self.config.reset_alpha_every:
             # cull huge ones
             toobigs = (torch.exp(self.scales).max(dim=-1).values > self.config.cull_scale_thresh).squeeze()
-            if self.step < self.config.stop_screen_size_at:
-                # cull big screen space
-                assert self.max_2Dsize is not None
-                toobigs = toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze()
             culls = culls | toobigs
             toobigs_count = torch.sum(toobigs).item()
         for name, param in self.gauss_params.items():
@@ -670,12 +661,14 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         if not isinstance(camera, Cameras):
             print("Called get_outputs with not a camera")
             return {}
-        assert camera.shape[0] == 1, "Only one camera at a time"
-        optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]
 
         # get the background color
         if self.training:
+            assert camera.shape[0] == 1, "Only one camera at a time"
+            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
+
             if self.config.background_color == "random":
                 background = torch.rand(3, device=self.device)
             elif self.config.background_color == "white":
@@ -685,6 +678,8 @@
             else:
                 background = self.background_color.to(self.device)
         else:
+            optimized_camera_to_world = camera.camera_to_worlds
+
             if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
                 background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
             else:
@@ -696,25 +691,9 @@
                 return self.get_empty_outputs(int(camera.width.item()), int(camera.height.item()), background)
         else:
             crop_ids = None
-        camera_downscale = self._get_downscale_factor()
-        camera.rescale_output_resolution(1 / camera_downscale)
-        # shift the camera to center of scene looking at center
-        R = optimized_camera_to_world[:3, :3]  # 3 x 3
-        T = optimized_camera_to_world[:3, 3:4]  # 3 x 1
-
-        # flip the z and y axes to align with gsplat conventions
-        R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
-        R = R @ R_edit
-        # analytic matrix inverse to get world2camera matrix
-        R_inv = R.T
-        T_inv = -R_inv @ T
-        viewmat = torch.eye(4, device=R.device, dtype=R.dtype)
-        viewmat[:3, :3] = R_inv
-        viewmat[:3, 3:4] = T_inv
-        # calculate the FOV of the camera given fx and fy, width and height
-        cx = camera.cx.item()
-        cy = camera.cy.item()
-        W, H = int(camera.width.item()), int(camera.height.item())
+        camera_scale_fac = 1.0 / self._get_downscale_factor()
+        viewmat = get_viewmat(optimized_camera_to_world)
+        W, H = int(camera.width[0] * camera_scale_fac), int(camera.height[0] * camera_scale_fac)
         self.last_size = (H, W)
 
         if crop_ids is not None:
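A note on the resolution handling above: the mutate-then-restore `rescale_output_resolution()` round-trip is gone; downscaled rendering is instead expressed by shrinking the output size here and scaling the intrinsics handed to gsplat in the next hunk (`K[:, :2, :] *= camera_scale_fac`). A small sketch of why that is equivalent (made-up numbers):

```python
# Sketch: scaling the first two rows of K == rendering the same frustum smaller.
import torch

K = torch.tensor([[[1000.0, 0.0, 640.0],
                   [0.0, 1000.0, 360.0],
                   [0.0, 0.0, 1.0]]])   # [1, 3, 3] intrinsics for a 1280x720 image

camera_scale_fac = 0.5                  # i.e. _get_downscale_factor() == 2
K_scaled = K.clone()
K_scaled[:, :2, :] *= camera_scale_fac  # fx, fy, cx, cy all halve in one step

W, H = int(1280 * camera_scale_fac), int(720 * camera_scale_fac)
# (K_scaled, W, H) maps every 3D point to exactly half the old pixel coordinates,
# so the camera object itself never needs to be rescaled and restored.
```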
@@ -734,79 +713,58 @@
             colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
 
         BLOCK_WIDTH = 16  # this controls the tile size of rasterization, 16 is a good default
-        self.xys, depths, self.radii, conics, comp, num_tiles_hit, cov3d = project_gaussians(  # type: ignore
-            means_crop,
-            torch.exp(scales_crop),
-            1,
-            quats_crop / quats_crop.norm(dim=-1, keepdim=True),
-            viewmat.squeeze()[:3, :],
-            camera.fx.item(),
-            camera.fy.item(),
-            cx,
-            cy,
-            H,
-            W,
-            BLOCK_WIDTH,
-        )  # type: ignore
-
-        # rescale the camera back to original dimensions before returning
-        camera.rescale_output_resolution(camera_downscale)
-
-        if (self.radii).sum() == 0:
-            return self.get_empty_outputs(W, H, background)
-
-        if self.config.sh_degree > 0:
-            viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3]  # (N, 3)
-            n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
-            rgbs = spherical_harmonics(n, viewdirs, colors_crop)  # input unnormalized viewdirs
-            rgbs = torch.clamp(rgbs + 0.5, min=0.0)  # type: ignore
-        else:
-            rgbs = torch.sigmoid(colors_crop[:, 0, :])
-
-        assert (num_tiles_hit > 0).any()  # type: ignore
-
+        K = camera.get_intrinsics_matrices().cuda()
+        K[:, :2, :] *= camera_scale_fac
         # apply the compensation of screen space blurring to gaussians
-        if self.config.rasterize_mode == "antialiased":
-            opacities = torch.sigmoid(opacities_crop) * comp[:, None]
-        elif self.config.rasterize_mode == "classic":
-            opacities = torch.sigmoid(opacities_crop)
-        else:
+        if self.config.rasterize_mode not in ["antialiased", "classic"]:
             raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode)
 
-        rgb, alpha = rasterize_gaussians(  # type: ignore
-            self.xys,
-            depths,
-            self.radii,
-            conics,
-            num_tiles_hit,  # type: ignore
-            rgbs,
-            opacities,
-            H,
-            W,
-            BLOCK_WIDTH,
-            background=background,
-            return_alpha=True,
-        )  # type: ignore
-        alpha = alpha[..., None]
-        rgb = torch.clamp(rgb, max=1.0)  # type: ignore
-        depth_im = None
         if self.config.output_depth_during_training or not self.training:
-            depth_im = rasterize_gaussians(  # type: ignore
-                self.xys,
-                depths,
-                self.radii,
-                conics,
-                num_tiles_hit,  # type: ignore
-                depths[:, None].repeat(1, 3),
-                opacities,
-                H,
-                W,
-                BLOCK_WIDTH,
-                background=torch.zeros(3, device=self.device),
-            )[..., 0:1]  # type: ignore
-            depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())
-
-        return {"rgb": rgb, "depth": depth_im, "accumulation": alpha, "background": background}  # type: ignore
+            render_mode = "RGB+ED"
+        else:
+            render_mode = "RGB"
+
+        if self.config.sh_degree > 0:
+            sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
+        else:
+            sh_degree_to_use = None
+
+        render, alpha, info = rasterization(
+            means=means_crop,
+            quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True),
+            scales=torch.exp(scales_crop),
+            opacities=torch.sigmoid(opacities_crop).squeeze(-1),
+            colors=colors_crop,
+            viewmats=viewmat,  # [1, 4, 4]
+            Ks=K,  # [1, 3, 3]
+            width=W,
+            height=H,
+            tile_size=BLOCK_WIDTH,
+            packed=False,
+            near_plane=0.01,
+            far_plane=1e10,
+            render_mode=render_mode,
+            sh_degree=sh_degree_to_use,
+            sparse_grad=False,
+            absgrad=True,
+            rasterize_mode=self.config.rasterize_mode,
+            # set some threshold to disregard small gaussians for faster rendering.
+            # radius_clip=3.0,
+        )
+        if self.training and info["means2d"].requires_grad:
+            info["means2d"].retain_grad()
+        self.xys = info["means2d"]  # [1, N, 2]
+        self.radii = info["radii"][0]  # [N]
+        alpha = alpha[:, ...]
+        rgb = render[:, ..., :3] + (1 - alpha) * background
+        rgb = torch.clamp(rgb, 0.0, 1.0)
+        if render_mode == "RGB+ED":
+            depth_im = render[:, ..., 3:4]
+            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
+        else:
+            depth_im = None
+        return {"rgb": rgb.squeeze(0), "depth": depth_im, "accumulation": alpha.squeeze(0), "background": background}  # type: ignore
 
     def get_gt_img(self, image: torch.Tensor):
         """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
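The hunk above collapses the old `project_gaussians`/`rasterize_gaussians` pair into the single `rasterization` entry point from gsplat 1.0, which also returns the `info` dict that feeds the `absgrad` densification statistics. A minimal standalone sketch of that call, detached from the model (random gaussians and an identity pose; only the argument names that also appear above come from this diff, the values are illustrative, and gsplat needs a CUDA device):

```python
# Minimal gsplat>=1.0 rasterization call, detached from the splatfacto model.
import torch
from gsplat.rendering import rasterization

device = torch.device("cuda")
N = 1000
means = torch.randn(N, 3, device=device)
quats = torch.nn.functional.normalize(torch.randn(N, 4, device=device), dim=-1)
scales = torch.rand(N, 3, device=device) * 0.1
opacities = torch.rand(N, device=device)
colors = torch.rand(N, 3, device=device)    # per-gaussian RGB (sh_degree=None)

viewmat = torch.eye(4, device=device)[None]  # [1, 4, 4] world-to-camera
K = torch.tensor([[[300.0, 0.0, 150.0],
                   [0.0, 300.0, 100.0],
                   [0.0, 0.0, 1.0]]], device=device)  # [1, 3, 3]

render, alpha, info = rasterization(
    means=means, quats=quats, scales=scales, opacities=opacities,
    colors=colors, viewmats=viewmat, Ks=K, width=300, height=200,
    render_mode="RGB+ED",  # last channel carries expected depth, as used above
)
print(render.shape)  # torch.Size([1, 200, 300, 4]): RGB plus expected depth
```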
diff --git a/nerfstudio/scripts/process_data.py b/nerfstudio/scripts/process_data.py
index b2c2fa13fc..7a0968af52 100644
--- a/nerfstudio/scripts/process_data.py
+++ b/nerfstudio/scripts/process_data.py
@@ -152,6 +152,9 @@ def main(self) -> None:
             zip_ref.extractall(self.output_dir)
             extracted_folder = zip_ref.namelist()[0].split("/")[0]
             self.data = self.output_dir / extracted_folder
+        if not (self.data / "keyframes").exists():
+            # new versions of polycam data have a different structure, strip the last dir off
+            self.data = self.output_dir
 
         if (self.data / "keyframes" / "corrected_images").exists() and not self.use_uncorrected_images:
             polycam_image_dir = self.data / "keyframes" / "corrected_images"
diff --git a/pixi.lock b/pixi.lock
index 72631ecaa0..d4036678c0 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -369,7 +369,7 @@ environments:
       - pypi: https://files.pythonhosted.org/packages/fd/5b/8f0c4a5bb9fd491c277c21eff7ccae71b47d43c4446c9d0c6cff2fe8c2c4/gitdb-4.0.11-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/e9/bd/cc3a402a6439c15c3d4294333e13042b915bbeab54edc457c723931fed3f/GitPython-3.1.43-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/47/82/5f51b0ac0e670aa6551f351c6c8a479149a36c413dd76db4b98d26dddbea/grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-      - pypi: https://files.pythonhosted.org/packages/b4/b2/0c3fe3a11a2e8cdf9216ba92e97172d08f769082181f6f10807517db9295/gsplat-0.1.11-py3-none-any.whl
+      - pypi: https://files.pythonhosted.org/packages/53/71/d9bf12b11f608f0ad078fa962a9ab61a2cf28fa9739293a1e842656bc419/gsplat-1.0.0-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl
       - pypi: https://files.pythonhosted.org/packages/94/00/94bf8573e7487b7c37f2b613fc381880d48ec2311f2e859b8a5817deb4df/h5py-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
       - pypi: https://files.pythonhosted.org/packages/78/d4/e5d7e4f2174f8a4d63c8897d79eb8fe2503f7ecc03282fee1fa2719c2704/httpcore-1.0.5-py3-none-any.whl
@@ -2710,9 +2710,9 @@ packages:
   requires_python: '>=3.8'
 - kind: pypi
   name: gsplat
-  version: 0.1.11
-  url: https://files.pythonhosted.org/packages/b4/b2/0c3fe3a11a2e8cdf9216ba92e97172d08f769082181f6f10807517db9295/gsplat-0.1.11-py3-none-any.whl
-  sha256: 2d47c5d4c245b46d85b7eeae8ca4a96df3f1e2354d8daf254871516cb251b75c
+  version: 1.0.0
+  url: https://files.pythonhosted.org/packages/53/71/d9bf12b11f608f0ad078fa962a9ab61a2cf28fa9739293a1e842656bc419/gsplat-1.0.0-py3-none-any.whl
+  sha256: a21eead19150e80a0531dd24e5d717c67892cb381657c8411ec8b318b293a032
   requires_dist:
   - jaxtyping
   - rich >=12
@@ -6085,7 +6085,7 @@ packages:
   - xatlas
   - trimesh >=3.20.2
   - timm ==0.6.7
-  - gsplat >=0.1.11
+  - gsplat ==1.0.0
   - pytorch-msssim
   - pathos
   - packaging
"timm==0.6.7", - "gsplat>=0.1.11,<1.0.0", + "gsplat==1.0.0", "pytorch-msssim", "pathos", "packaging",