Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Python 3.11 dataclass field default_factory defaults #2746

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions nerfstudio/configs/experiment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@

import yaml

from nerfstudio.configs.base_config import (
InstantiateConfig,
LoggingConfig,
MachineConfig,
ViewerConfig,
)
from nerfstudio.configs.base_config import InstantiateConfig, LoggingConfig, MachineConfig, ViewerConfig
from nerfstudio.configs.config_utils import to_immutable_dict
from nerfstudio.engine.optimizers import OptimizerConfig
from nerfstudio.engine.schedulers import SchedulerConfig
Expand All @@ -51,13 +46,13 @@ class ExperimentConfig(InstantiateConfig):
"""Project name."""
timestamp: str = "{timestamp}"
"""Experiment timestamp."""
machine: MachineConfig = field(default_factory=lambda: MachineConfig())
machine: MachineConfig = field(default_factory=MachineConfig)
"""Machine configuration"""
logging: LoggingConfig = field(default_factory=lambda: LoggingConfig())
logging: LoggingConfig = field(default_factory=LoggingConfig)
"""Logging configuration"""
viewer: ViewerConfig = field(default_factory=lambda: ViewerConfig())
viewer: ViewerConfig = field(default_factory=ViewerConfig)
"""Viewer configuration"""
pipeline: VanillaPipelineConfig = field(default_factory=lambda: VanillaPipelineConfig())
pipeline: VanillaPipelineConfig = field(default_factory=VanillaPipelineConfig)
"""Pipeline configuration"""
optimizers: Dict[str, Any] = to_immutable_dict(
{
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/data/datamanagers/base_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ class VanillaDataManagerConfig(DataManagerConfig):

_target: Type = field(default_factory=lambda: VanillaDataManager)
"""Target class to instantiate."""
dataparser: AnnotatedDataParserUnion = field(default_factory=lambda: BlenderDataParserConfig())
dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
"""Specifies the dataparser used to unpack the data."""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
Expand Down Expand Up @@ -344,7 +344,7 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
camera_optimizer: Optional[CameraOptimizerConfig] = field(default=None)
"""Deprecated, has been moved to the model config."""
pixel_sampler: PixelSamplerConfig = field(default_factory=lambda: PixelSamplerConfig())
pixel_sampler: PixelSamplerConfig = field(default_factory=PixelSamplerConfig)
"""Specifies the pixel sampler used to sample pixels from images."""

def __post_init__(self):
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import random
from copy import deepcopy
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
Expand All @@ -30,7 +31,6 @@
import cv2
import numpy as np
import torch
from copy import deepcopy
from torch.nn import Parameter
from tqdm import tqdm

Expand All @@ -47,7 +47,7 @@
@dataclass
class FullImageDatamanagerConfig(DataManagerConfig):
_target: Type = field(default_factory=lambda: FullImageDatamanager)
dataparser: AnnotatedDataParserUnion = NerfstudioDataParserConfig()
dataparser: AnnotatedDataParserUnion = field(default_factory=NerfstudioDataParserConfig)
camera_res_scale_factor: float = 1.0
"""The scale factor for scaling spatial data such as images, mask, semantics
along with relevant information about camera intrinsics
Expand Down
16 changes: 3 additions & 13 deletions nerfstudio/models/base_surface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,9 @@
from nerfstudio.fields.nerfacto_field import NerfactoField
from nerfstudio.fields.sdf_field import SDFFieldConfig
from nerfstudio.fields.vanilla_nerf_field import NeRFField
from nerfstudio.model_components.losses import (
L1Loss,
MSELoss,
ScaleAndShiftInvariantLoss,
monosdf_normal_loss,
)
from nerfstudio.model_components.losses import L1Loss, MSELoss, ScaleAndShiftInvariantLoss, monosdf_normal_loss
from nerfstudio.model_components.ray_samplers import LinearDisparitySampler
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
DepthRenderer,
RGBRenderer,
SemanticRenderer,
)
from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, RGBRenderer, SemanticRenderer
from nerfstudio.model_components.scene_colliders import AABBBoxCollider, NearFarCollider
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import colormaps
Expand Down Expand Up @@ -79,7 +69,7 @@ class SurfaceModelConfig(ModelConfig):
"""Monocular normal consistency loss multiplier."""
mono_depth_loss_mult: float = 0.0
"""Monocular depth consistency loss multiplier."""
sdf_field: SDFFieldConfig = field(default_factory=lambda: SDFFieldConfig())
sdf_field: SDFFieldConfig = field(default_factory=SDFFieldConfig)
"""Config for SDF Field"""
background_model: Literal["grid", "mlp", "none"] = "mlp"
"""background models"""
Expand Down
30 changes: 14 additions & 16 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,33 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
from nerfstudio.data.scene_box import OrientedBox

import numpy as np
import torch
import torchvision.transforms.functional as TF
from gsplat._torch_impl import quat_to_rotmat
from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases
from pytorch_msssim import SSIM
from sklearn.neighbors import NearestNeighbors
from torch.nn import Parameter
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import torchvision.transforms.functional as TF

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from gsplat._torch_impl import quat_to_rotmat
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
from nerfstudio.engine.optimizers import Optimizers
from nerfstudio.models.base_model import Model, ModelConfig
import math
import numpy as np
from sklearn.neighbors import NearestNeighbors
from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig

from gsplat.rasterize import RasterizeGaussians
from gsplat.project_gaussians import ProjectGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases

from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from pytorch_msssim import SSIM

# need following import for background color override
from nerfstudio.model_components import renderers
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils.rich_utils import CONSOLE


Expand Down Expand Up @@ -149,7 +147,7 @@ class GaussianSplattingModelConfig(ModelConfig):
"""stop splitting at this step"""
sh_degree: int = 3
"""maximum degree of spherical harmonics to use"""
camera_optimizer: CameraOptimizerConfig = CameraOptimizerConfig(mode="off")
camera_optimizer: CameraOptimizerConfig = field(default_factory=CameraOptimizerConfig)
"""camera optimizer config"""
use_scale_regularization: bool = False
"""If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians."""
Expand Down
16 changes: 6 additions & 10 deletions nerfstudio/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,10 @@
from torch.nn import Parameter
from torch.nn.parallel import DistributedDataParallel as DDP

from nerfstudio.configs import base_config as cfg
from nerfstudio.data.datamanagers.base_datamanager import (
DataManager,
DataManagerConfig,
VanillaDataManager,
)
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, VanillaDataManager
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils import profiler
Expand Down Expand Up @@ -213,14 +209,14 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:


@dataclass
class VanillaPipelineConfig(cfg.InstantiateConfig):
class VanillaPipelineConfig(InstantiateConfig):
"""Configuration for pipeline instantiation"""

_target: Type = field(default_factory=lambda: VanillaPipeline)
"""target class to instantiate"""
datamanager: DataManagerConfig = field(default_factory=lambda: DataManagerConfig())
datamanager: DataManagerConfig = field(default_factory=DataManagerConfig)
"""specifies the datamanager config"""
model: ModelConfig = field(default_factory=lambda: ModelConfig())
model: ModelConfig = field(default_factory=ModelConfig)
"""specifies the model config"""


Expand Down
9 changes: 4 additions & 5 deletions tests/plugins/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
"""
import os
import sys
from dataclasses import dataclass
from dataclasses import dataclass, field

from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig
from nerfstudio.plugins import registry
from nerfstudio.plugins import registry, registry_dataparser
from nerfstudio.plugins.registry_dataparser import DataParserConfig, DataParserSpecification, discover_dataparsers
from nerfstudio.plugins.types import MethodSpecification
from nerfstudio.plugins import registry_dataparser
from nerfstudio.plugins.registry_dataparser import DataParserSpecification, discover_dataparsers, DataParserConfig

if sys.version_info < (3, 10):
import importlib_metadata
Expand Down Expand Up @@ -100,7 +99,7 @@ def test_discover_methods_from_environment_variable_instance():

@dataclass
class TestDataparserConfigClass(DataParserSpecification):
config: DataParserConfig = DataParserConfig()
config: DataParserConfig = field(default_factory=DataParserConfig)
description: str = "Test description"


Expand Down
Loading