diff --git a/nerfstudio/cameras/camera_optimizers.py b/nerfstudio/cameras/camera_optimizers.py index b192ad8b93..e7e49f734e 100644 --- a/nerfstudio/cameras/camera_optimizers.py +++ b/nerfstudio/cameras/camera_optimizers.py @@ -23,6 +23,7 @@ from typing import Literal, Optional, Type, Union import torch +import tyro from jaxtyping import Float, Int from torch import Tensor, nn from typing_extensions import assert_never @@ -51,10 +52,11 @@ class CameraOptimizerConfig(InstantiateConfig): rot_l2_penalty: float = 1e-3 """L2 penalty on rotation parameters.""" - optimizer: Optional[OptimizerConfig] = field(default=None) + # tyro.conf.Suppress prevents us from creating CLI arguments for these fields. + optimizer: tyro.conf.Suppress[Optional[OptimizerConfig]] = field(default=None) """Deprecated, now specified inside the optimizers dict""" - scheduler: Optional[SchedulerConfig] = field(default=None) + scheduler: tyro.conf.Suppress[Optional[SchedulerConfig]] = field(default=None) """Deprecated, now specified inside the optimizers dict""" def __post_init__(self): diff --git a/nerfstudio/cameras/cameras.py b/nerfstudio/cameras/cameras.py index f81323b45b..4202c8c273 100644 --- a/nerfstudio/cameras/cameras.py +++ b/nerfstudio/cameras/cameras.py @@ -23,7 +23,6 @@ import cv2 import torch -import torchvision from jaxtyping import Float, Int, Shaped from torch import Tensor from torch.nn import Parameter @@ -959,7 +958,11 @@ def to_json( image_uint8 = (image * 255).detach().type(torch.uint8) if max_size is not None: image_uint8 = image_uint8.permute(2, 0, 1) - image_uint8 = torchvision.transforms.functional.resize(image_uint8, max_size, antialias=None) # type: ignore + + # torchvision can be slow to import, so we do it lazily. + import torchvision.transforms.functional as TF + + image_uint8 = TF.resize(image_uint8, max_size, antialias=None) # type: ignore image_uint8 = image_uint8.permute(1, 2, 0) image_uint8 = image_uint8.cpu().numpy() data = cv2.imencode(".jpg", image_uint8)[1].tobytes() # type: ignore diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 2131d5260b..210ce5757d 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -41,6 +41,7 @@ ) import torch +import tyro from torch import nn from torch.nn import Parameter from torch.utils.data.distributed import DistributedSampler @@ -334,7 +335,9 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - camera_optimizer: Optional[CameraOptimizerConfig] = field(default=None) + + # tyro.conf.Suppress prevents us from creating CLI arguments for this field. + camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) """Deprecated, has been moved to the model config.""" pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig) """Specifies the pixel sampler used to sample pixels from images.""" diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index aaa88c7691..554e88dac0 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -20,7 +20,6 @@ from typing import Literal, Optional, Type import numpy as np -import open3d as o3d import torch from PIL import Image @@ -337,6 +336,8 @@ def _generate_dataparser_outputs(self, split="train"): return dataparser_outputs def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float): + import open3d as o3d # Importing open3d is slow, so we only do it if we need it. + pcd = o3d.io.read_point_cloud(str(ply_file_path)) points3D = torch.from_numpy(np.asarray(pcd.points, dtype=np.float32)) diff --git a/nerfstudio/data/dataparsers/nuscenes_dataparser.py b/nerfstudio/data/dataparsers/nuscenes_dataparser.py index 19d215b140..f1fecc8889 100644 --- a/nerfstudio/data/dataparsers/nuscenes_dataparser.py +++ b/nerfstudio/data/dataparsers/nuscenes_dataparser.py @@ -22,7 +22,6 @@ import numpy as np import pyquaternion import torch -from nuscenes.nuscenes import NuScenes as NuScenesDatabase from nerfstudio.cameras.cameras import Cameras, CameraType from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs @@ -81,6 +80,9 @@ class NuScenes(DataParser): config: NuScenesDataParserConfig def _generate_dataparser_outputs(self, split="train"): + # nuscenes is slow to import, so we only do it if we need it. + from nuscenes.nuscenes import NuScenes as NuScenesDatabase + nusc = NuScenesDatabase( version=self.config.version, dataroot=str(self.config.data_dir.absolute()), diff --git a/nerfstudio/exporter/exporter_utils.py b/nerfstudio/exporter/exporter_utils.py index 45f6ef6866..40a12212bd 100644 --- a/nerfstudio/exporter/exporter_utils.py +++ b/nerfstudio/exporter/exporter_utils.py @@ -21,10 +21,9 @@ import sys from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np -import open3d as o3d import pymeshlab import torch from jaxtyping import Float @@ -38,6 +37,11 @@ from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline from nerfstudio.utils.rich_utils import CONSOLE, ItersPerSecColumn +if TYPE_CHECKING: + # Importing open3d can take ~1 second, so only do it below if we actually + # need it. + import open3d as o3d + @dataclass class Mesh: @@ -193,6 +197,8 @@ def generate_point_cloud( rgbs = torch.cat(rgbs, dim=0) view_directions = torch.cat(view_directions, dim=0).cpu() + import open3d as o3d + pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points.double().cpu().numpy()) pcd.colors = o3d.utility.Vector3dVector(rgbs.double().cpu().numpy()) diff --git a/nerfstudio/generative/deepfloyd.py b/nerfstudio/generative/deepfloyd.py index 03c02b5cdd..a2faa9b75b 100644 --- a/nerfstudio/generative/deepfloyd.py +++ b/nerfstudio/generative/deepfloyd.py @@ -13,6 +13,7 @@ # limitations under the License. import gc +import sys from pathlib import Path from typing import List, Optional, Union @@ -24,15 +25,7 @@ from torch import Generator, Tensor, nn from torch.cuda.amp.grad_scaler import GradScaler -from nerfstudio.generative.utils import CatchMissingPackages - -try: - from diffusers import DiffusionPipeline, IFPipeline, IFPipeline as IFOrig - from diffusers.pipelines.deepfloyd_if import IFPipelineOutput, IFPipelineOutput as IFOutputOrig - from transformers import T5EncoderModel - -except ImportError: - IFPipeline = IFPipelineOutput = T5EncoderModel = CatchMissingPackages() +from nerfstudio.utils.rich_utils import CONSOLE IMG_DIM = 64 @@ -47,6 +40,16 @@ def __init__(self, device: Union[torch.device, str]): super().__init__() self.device = device + try: + from diffusers import DiffusionPipeline, IFPipeline + from transformers import T5EncoderModel + + except ImportError: + CONSOLE.print("[bold red]Missing Stable Diffusion packages.") + CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]") + CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.") + sys.exit(1) + self.text_encoder = T5EncoderModel.from_pretrained( "DeepFloyd/IF-I-L-v1.0", subfolder="text_encoder", @@ -90,6 +93,8 @@ def delete_text_encoder(self): gc.collect() torch.cuda.empty_cache() + from diffusers import DiffusionPipeline, IFPipeline + self.pipe = IFPipeline.from_pretrained( "DeepFloyd/IF-I-L-v1.0", text_encoder=None, @@ -126,6 +131,8 @@ def get_text_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + from diffusers import DiffusionPipeline + assert isinstance(self.pipe, DiffusionPipeline) with torch.no_grad(): prompt_embeds, negative_embeds = self.pipe.encode_prompt(prompt, negative_prompt=negative_prompt) @@ -200,6 +207,9 @@ def prompt_to_image( The generated image. """ + from diffusers import DiffusionPipeline, IFPipeline as IFOrig + from diffusers.pipelines.deepfloyd_if import IFPipelineOutput as IFOutputOrig + prompts = [prompts] if isinstance(prompts, str) else prompts negative_prompts = [negative_prompts] if isinstance(negative_prompts, str) else negative_prompts assert isinstance(self.pipe, DiffusionPipeline) diff --git a/nerfstudio/generative/stable_diffusion.py b/nerfstudio/generative/stable_diffusion.py index 7bf0fb96d8..66a899e9c5 100644 --- a/nerfstudio/generative/stable_diffusion.py +++ b/nerfstudio/generative/stable_diffusion.py @@ -16,10 +16,10 @@ # Modified from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/sd.py +import sys from pathlib import Path from typing import List, Optional, Union -import mediapy import numpy as np import torch import torch.nn.functional as F @@ -28,16 +28,8 @@ from torch import Tensor, nn from torch.cuda.amp.grad_scaler import GradScaler -from nerfstudio.generative.utils import CatchMissingPackages from nerfstudio.utils.rich_utils import CONSOLE -try: - from diffusers import DiffusionPipeline, PNDMScheduler, StableDiffusionPipeline - -except ImportError: - PNDMScheduler = StableDiffusionPipeline = CatchMissingPackages() - - IMG_DIM = 512 CONST_SCALE = 0.18215 SD_IDENTIFIERS = { @@ -57,6 +49,15 @@ class StableDiffusion(nn.Module): def __init__(self, device: Union[torch.device, str], num_train_timesteps: int = 1000, version="1-5") -> None: super().__init__() + try: + from diffusers import DiffusionPipeline, PNDMScheduler, StableDiffusionPipeline + + except ImportError: + CONSOLE.print("[bold red]Missing Stable Diffusion packages.") + CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]") + CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.") + sys.exit(1) + self.device = device self.num_train_timesteps = num_train_timesteps @@ -319,6 +320,9 @@ def generate_image( with torch.no_grad(): sd = StableDiffusion(cuda_device) imgs = sd.prompt_to_img(prompt, negative, steps) + + import mediapy # Slow to import, so we do it lazily. + mediapy.write_image(str(save_path), imgs[0]) diff --git a/nerfstudio/generative/utils.py b/nerfstudio/generative/utils.py deleted file mode 100644 index 5e8ed02e58..0000000000 --- a/nerfstudio/generative/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility helper functions for diffusion models""" - -import sys - -from nerfstudio.utils.rich_utils import CONSOLE - - -class CatchMissingPackages: - """Class to catch missing environment packages related to diffusion models.""" - - def __init__(self): - pass - - def __call__(self, *args, **kwargs): - CONSOLE.print("[bold red]Missing Stable Diffusion packages.") - CONSOLE.print(r"Install using [yellow]pip install nerfstudio\[gen][/yellow]") - CONSOLE.print(r"or [yellow]pip install -e .\[gen][/yellow] if installing from source.") - sys.exit(1) - - def __getattr__(self, attr): - return self.__call__ diff --git a/nerfstudio/models/base_surface_model.py b/nerfstudio/models/base_surface_model.py index 885e08e29f..9929683212 100644 --- a/nerfstudio/models/base_surface_model.py +++ b/nerfstudio/models/base_surface_model.py @@ -25,9 +25,6 @@ import torch import torch.nn.functional as F from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.rays import RayBundle from nerfstudio.field_components.encodings import NeRFEncoding @@ -156,6 +153,10 @@ def populate_modules(self): self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1) # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity() diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py index 60563a96ae..ab17c6a37c 100644 --- a/nerfstudio/models/gaussian_splatting.py +++ b/nerfstudio/models/gaussian_splatting.py @@ -25,17 +25,13 @@ import numpy as np import torch -import torchvision.transforms.functional as TF from gsplat._torch_impl import quat_to_rotmat from gsplat.compute_cumulative_intersects import compute_cumulative_intersects from gsplat.project_gaussians import ProjectGaussians from gsplat.rasterize import RasterizeGaussians from gsplat.sh import SphericalHarmonics, num_sh_bases from pytorch_msssim import SSIM -from sklearn.neighbors import NearestNeighbors from torch.nn import Parameter -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.cameras import Cameras @@ -205,6 +201,9 @@ def populate_modules(self): self.opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(self.num_points, 1))) # metrics + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = SSIM(data_range=1.0, size_average=True, channel=3) self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) @@ -248,7 +247,7 @@ def load_state_dict(self, dict, **kwargs): # type: ignore def k_nearest_sklearn(self, x: torch.Tensor, k: int): """ - Find k-nearest neighbors using sklearn's NearestNeighbors. + Find k-nearest neighbors using sklearn's NearestNeighbors. x: The data tensor of shape [num_samples, num_features] k: The number of neighbors to retrieve """ @@ -256,6 +255,8 @@ def k_nearest_sklearn(self, x: torch.Tensor, k: int): x_np = x.cpu().numpy() # Build the nearest neighbors model + from sklearn.neighbors import NearestNeighbors + nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np) # Find the k-nearest neighbors @@ -733,6 +734,10 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: d = self._get_downscale_factor() if d > 1: newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d] + + # torchvision can be slow to import, so we do it lazily. + import torchvision.transforms.functional as TF + gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0) else: gt_img = batch["image"] @@ -756,6 +761,10 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te d = self._get_downscale_factor() if d > 1: newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d] + + # torchvision can be slow to import, so we do it lazily. + import torchvision.transforms.functional as TF + gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0) else: gt_img = batch["image"] @@ -807,6 +816,9 @@ def get_image_metrics_and_images( """ d = self._get_downscale_factor() if d > 1: + # torchvision can be slow to import, so we do it lazily. + import torchvision.transforms.functional as TF + newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d] gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0) predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0) diff --git a/nerfstudio/models/instant_ngp.py b/nerfstudio/models/instant_ngp.py index 7dbcc30ccf..1b0304d5bd 100644 --- a/nerfstudio/models/instant_ngp.py +++ b/nerfstudio/models/instant_ngp.py @@ -24,9 +24,6 @@ import nerfacc import torch from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.rays import RayBundle from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation @@ -138,6 +135,10 @@ def populate_modules(self): self.rgb_loss = MSELoss() # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) diff --git a/nerfstudio/models/mipnerf.py b/nerfstudio/models/mipnerf.py index b7e792100f..e033681f73 100644 --- a/nerfstudio/models/mipnerf.py +++ b/nerfstudio/models/mipnerf.py @@ -21,9 +21,6 @@ import torch from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.rays import RayBundle from nerfstudio.field_components.encodings import NeRFEncoding @@ -85,6 +82,10 @@ def populate_modules(self): self.rgb_loss = MSELoss() # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py index 667d23eb81..b1898f3d66 100644 --- a/nerfstudio/models/nerfacto.py +++ b/nerfstudio/models/nerfacto.py @@ -24,9 +24,6 @@ import numpy as np import torch from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.rays import RayBundle, RaySamples @@ -237,6 +234,10 @@ def update_schedule(step): self.rgb_loss = MSELoss() self.step = 0 # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) diff --git a/nerfstudio/models/semantic_nerfw.py b/nerfstudio/models/semantic_nerfw.py index 8d5be08497..f3dff767be 100644 --- a/nerfstudio/models/semantic_nerfw.py +++ b/nerfstudio/models/semantic_nerfw.py @@ -24,9 +24,6 @@ import numpy as np import torch from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.rays import RayBundle from nerfstudio.data.dataparsers.base_dataparser import Semantics @@ -135,6 +132,10 @@ def populate_modules(self): self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction="mean") # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) diff --git a/nerfstudio/models/tensorf.py b/nerfstudio/models/tensorf.py index 8827f896f3..d19b0e2fb3 100644 --- a/nerfstudio/models/tensorf.py +++ b/nerfstudio/models/tensorf.py @@ -24,9 +24,6 @@ import numpy as np import torch from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.rays import RayBundle @@ -236,6 +233,10 @@ def populate_modules(self): self.rgb_loss = MSELoss() # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) diff --git a/nerfstudio/models/vanilla_nerf.py b/nerfstudio/models/vanilla_nerf.py index 795bbb23c8..dd8565104f 100644 --- a/nerfstudio/models/vanilla_nerf.py +++ b/nerfstudio/models/vanilla_nerf.py @@ -23,9 +23,6 @@ import torch from torch.nn import Parameter -from torchmetrics.functional import structural_similarity_index_measure -from torchmetrics.image import PeakSignalNoiseRatio -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from nerfstudio.cameras.rays import RayBundle from nerfstudio.configs.config_utils import to_immutable_dict @@ -118,6 +115,10 @@ def populate_modules(self): self.rgb_loss = MSELoss() # metrics + from torchmetrics.functional import structural_similarity_index_measure + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + self.psnr = PeakSignalNoiseRatio(data_range=1.0) self.ssim = structural_similarity_index_measure self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) diff --git a/nerfstudio/utils/writer.py b/nerfstudio/utils/writer.py index 2bc5bcfc60..d460e9411b 100644 --- a/nerfstudio/utils/writer.py +++ b/nerfstudio/utils/writer.py @@ -24,9 +24,7 @@ from time import time from typing import Any, Dict, List, Optional, Union -import comet_ml import torch -import wandb from jaxtyping import Float from torch import Tensor from torch.utils.tensorboard import SummaryWriter @@ -307,6 +305,8 @@ class WandbWriter(Writer): """WandDB Writer Class""" def __init__(self, log_dir: Path, experiment_name: str, project_name: str = "nerfstudio-project"): + import wandb # wandb is slow to import, so we only import it if we need it. + wandb.init( project=os.environ.get("WANDB_PROJECT", project_name), dir=os.environ.get("WANDB_DIR", str(log_dir)), @@ -315,10 +315,14 @@ def __init__(self, log_dir: Path, experiment_name: str, project_name: str = "ner ) def write_image(self, name: str, image: Float[Tensor, "H W C"], step: int) -> None: + import wandb # wandb is slow to import, so we only import it if we need it. + image = torch.permute(image, (2, 0, 1)) wandb.log({name: wandb.Image(image)}, step=step) def write_scalar(self, name: str, scalar: Union[float, torch.Tensor], step: int) -> None: + import wandb # wandb is slow to import, so we only import it if we need it. + wandb.log({name: scalar}, step=step) def write_config(self, name: str, config_dict: Dict[str, Any], step: int): @@ -327,6 +331,8 @@ def write_config(self, name: str, config_dict: Dict[str, Any], step: int): Args: config: config dictionary to write out """ + import wandb # wandb is slow to import, so we only import it if we need it. + wandb.config.update(config_dict, allow_val_change=True) @@ -358,6 +364,9 @@ class CometWriter(Writer): """Comet_ML Writer Class""" def __init__(self, log_dir: Path, experiment_name: str, project_name: str = "nerfstudio-project"): + # comet_ml is slow to import, so we only do it if we need it. + import comet_ml + self.experiment = comet_ml.Experiment(project_name=project_name) if experiment_name != "unnamed": self.experiment.set_name(experiment_name) diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index 8cff2b9763..c7a881b80d 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -22,7 +22,6 @@ import numpy as np import torch -import torchvision import viser import viser.theme import viser.transforms as vtf @@ -409,6 +408,10 @@ def init_scene( camera = train_dataset.cameras[idx] image_uint8 = (image * 255).detach().type(torch.uint8) image_uint8 = image_uint8.permute(2, 0, 1) + + # torchvision can be slow to import, so we do it lazily. + import torchvision + image_uint8 = torchvision.transforms.functional.resize(image_uint8, 100, antialias=None) # type: ignore image_uint8 = image_uint8.permute(1, 2, 0) image_uint8 = image_uint8.cpu().numpy() diff --git a/pyproject.toml b/pyproject.toml index 7f5f602592..b0e93fe316 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "av>=9.2.0", "comet_ml>=3.33.8", "cryptography>=38", - "tyro>=0.5.10", + "tyro>=0.6.4", "gdown>=4.6.0", "ninja>=1.10", "h5py>=2.9.0", @@ -56,7 +56,7 @@ dependencies = [ "torchvision>=0.14.1", "torchmetrics[image]>=1.0.1", "typing_extensions>=4.4.0", - "viser==0.1.17", + "viser==0.1.20", "nuscenes-devkit>=1.1.1", "wandb>=0.13.3", "xatlas",