Skip to content

Commit

Permalink
Train and eval data splits (Nerfbusters) (#2207)
Browse files Browse the repository at this point in the history
* nerfbusters dataparser

* nerfbuster dataparser and visualization of eval views

* resolved some errors from  pyright

* enable ns-process-data to take two videos as input and use one for eval and the other for train

* resolve some pyright issues

* instead of data being a tuple, I added eval_data as another argument

* implemented train eval split for all cases

* implemented train eval split for all cases

* reverting equi changes

* reverting reality capture changes

* raise valueerror

* revert comment

* added split override code back to dataparser

* reverted one change

* using image border now

* adding docs

---------

Co-authored-by: Ethan Weber <[email protected]>
  • Loading branch information
FrederikWarburg and ethanweber authored Aug 4, 2023
1 parent 388fcb0 commit 9b03299
Show file tree
Hide file tree
Showing 14 changed files with 288 additions and 83 deletions.
4 changes: 4 additions & 0 deletions docs/quickstart/custom_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ ns-process-data {images, video} --data {DATA_PATH} --output-dir {PROCESSED_DATA_
ns-train nerfacto --data {PROCESSED_DATA_DIR}
```

### Training and evaluation on separate data

For `ns-process-data {images, video}`, you can optionally supply separate image directories or videos for training and evaluation, as suggested in [Nerfbusters](https://ethanweber.me/nerfbusters/). To do this, run `ns-process-data {images, video} --data {DATA_PATH} --eval-data {EVAL_DATA_PATH} --output-dir {PROCESSED_DATA_DIR}`. Then, when training nerfacto, run `ns-train nerfacto --data {PROCESSED_DATA_DIR} nerfstudio-data --eval-mode filename`.

### Installing COLMAP

There are many ways to install COLMAP; unfortunately, it can sometimes be a bit finicky. If the following commands do not work, please refer to the [COLMAP installation guide](https://colmap.github.io/install.html) for additional installation methods. COLMAP install issues are common! Feel free to ask for help on our [Discord](https://discord.gg/uMbNqcraFc).
Expand Down
59 changes: 42 additions & 17 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, Type
Expand All @@ -26,12 +25,14 @@

from nerfstudio.cameras import camera_utils
from nerfstudio.cameras.cameras import CAMERA_MODEL_TO_TYPE, Cameras, CameraType
from nerfstudio.data.dataparsers.base_dataparser import (
DataParser,
DataParserConfig,
DataparserOutputs,
)
from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.data.utils.dataparsers_utils import (
get_train_eval_split_filename,
get_train_eval_split_fraction,
get_train_eval_split_interval,
get_train_eval_split_all,
)
from nerfstudio.utils.io import load_from_json
from nerfstudio.utils.rich_utils import CONSOLE

Expand All @@ -58,8 +59,18 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""The method to use to center the poses."""
auto_scale_poses: bool = True
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
eval_mode: Literal["fraction", "filename", "interval", "all"] = "fraction"
"""
The method to use for splitting the dataset into train and eval.
Fraction splits based on a percentage for train and the remaining for eval.
Filename splits based on filenames containing train/eval.
Interval uses every nth frame for eval.
All uses all the images for any split.
"""
train_split_fraction: float = 0.9
"""The fraction of images to use for training. The remaining images are for eval."""
"""The percentage of the dataset to use for training. Only used when eval_mode is train-split-fraction."""
eval_interval: int = 8
"""The interval between frames to use for eval. Only used when eval_mode is eval-interval."""
depth_unit_scale_factor: float = 1e-3
"""Scales the depth values to meters. Default value is 0.001 for a millimeter to meter conversion."""

Expand Down Expand Up @@ -105,9 +116,18 @@ def _generate_dataparser_outputs(self, split="train"):
width = []
distort = []

# sort the frames by fname
fnames = []
for frame in meta["frames"]:
filepath = Path(frame["file_path"])
fname = self._get_fname(filepath, data_dir)
fnames.append(fname)
inds = np.argsort(fnames)
frames = [meta["frames"][ind] for ind in inds]

for frame in frames:
filepath = Path(frame["file_path"])
fname = self._get_fname(filepath, data_dir)

if not fx_fixed:
assert "fl_x" in frame, "fx not specified in frame"
Expand Down Expand Up @@ -182,16 +202,21 @@ def _generate_dataparser_outputs(self, split="train"):
elif has_split_files_spec:
raise RuntimeError(f"The dataset's list of filenames for split {split} is missing.")
else:
# filter image_filenames and poses based on train/eval split percentage
num_images = len(image_filenames)
num_train_images = math.ceil(num_images * self.config.train_split_fraction)
num_eval_images = num_images - num_train_images
i_all = np.arange(num_images)
i_train = np.linspace(
0, num_images - 1, num_train_images, dtype=int
) # equally spaced training images starting and ending at 0 and num_images-1
i_eval = np.setdiff1d(i_all, i_train) # eval images are the remaining images
assert len(i_eval) == num_eval_images
# find train and eval indices based on the eval_mode specified
if self.config.eval_mode == "fraction":
i_train, i_eval = get_train_eval_split_fraction(image_filenames, self.config.train_split_fraction)
elif self.config.eval_mode == "filename":
i_train, i_eval = get_train_eval_split_filename(image_filenames)
elif self.config.eval_mode == "interval":
i_train, i_eval = get_train_eval_split_interval(image_filenames, self.config.eval_interval)
elif self.config.eval_mode == "all":
CONSOLE.log(
"[yellow] Be careful with '--eval-mode=all'. If using camera optimization, the cameras may diverge in the current implementation, giving unpredictable results."
)
i_train, i_eval = get_train_eval_split_all(image_filenames)
else:
raise ValueError(f"Unknown eval mode {self.config.eval_mode}")

if split == "train":
indices = i_train
elif split in ["val", "test"]:
Expand Down
101 changes: 101 additions & 0 deletions nerfstudio/data/utils/dataparsers_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Data parser utils for nerfstudio datasets. """

import math
import os
from typing import List, Tuple

import numpy as np


def get_train_eval_split_fraction(image_filenames: List, train_split_fraction: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split image indices into train and eval sets according to a training fraction.

    The training indices are equally spaced between index 0 and the last index;
    every index that was not selected for training is used for eval.

    Args:
        image_filenames: list of image filenames (only its length is used)
        train_split_fraction: fraction of images to use for training

    Returns:
        A tuple (i_train, i_eval) of index arrays.
    """
    total = len(image_filenames)
    n_train = math.ceil(total * train_split_fraction)

    # equally spaced picks over [0, total - 1], truncated to integer indices
    i_train = np.linspace(0, total - 1, n_train, dtype=int)
    # whatever was not picked for training is left for eval
    i_eval = np.setdiff1d(np.arange(total), i_train)
    assert len(i_eval) == total - n_train

    return i_train, i_eval


def get_train_eval_split_filename(image_filenames: List) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get the train/eval split based on the filename of the images.

    An image goes to the train split when its basename contains "train", otherwise to
    the eval split when its basename contains "eval". Only the basename is inspected,
    so directory names do not influence the split.

    Args:
        image_filenames: list of image filenames

    Returns:
        A tuple (i_train, i_eval) of integer index arrays. A split with no matching
        images is returned as an empty integer array.

    Raises:
        ValueError: if a basename contains neither "train" nor "eval".
    """
    i_train = []
    i_eval = []
    for idx, image_filename in enumerate(image_filenames):
        basename = os.path.basename(image_filename)
        # "train" is checked first, so a name containing both substrings counts as train
        if "train" in basename:
            i_train.append(idx)
        elif "eval" in basename:
            i_eval.append(idx)
        else:
            raise ValueError(
                f"Image '{basename}' should contain train/eval in its name to use the 'filename' eval mode"
            )

    # An explicit integer dtype keeps an empty split usable as an index array;
    # np.array([]) would otherwise default to float64.
    return np.array(i_train, dtype=int), np.array(i_eval, dtype=int)


def get_train_eval_split_interval(image_filenames: List, eval_interval: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get the train/eval split based on the interval of the images.

    Every index that is a multiple of ``eval_interval`` (including index 0) is used
    for eval; all remaining indices are used for training.

    Args:
        image_filenames: list of image filenames (only its length is used)
        eval_interval: interval of images to use for eval

    Returns:
        A tuple (i_train, i_eval) of index arrays.

    Raises:
        ValueError: if ``eval_interval`` is zero.
    """
    # numpy's modulo by zero only warns and yields zeros, which would silently put
    # every image into eval and none into train, so reject it explicitly.
    if eval_interval == 0:
        raise ValueError("eval_interval must be non-zero")

    all_indices = np.arange(len(image_filenames))
    is_eval = all_indices % eval_interval == 0

    return all_indices[~is_eval], all_indices[is_eval]


def get_train_eval_split_all(image_filenames: List) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get a train/eval split in which every image index belongs to both splits.

    Args:
        image_filenames: list of image filenames (only its length is used)

    Returns:
        A tuple (i_train, i_eval); both entries are the same array of all indices.
    """
    every_index = np.arange(len(image_filenames))
    # the very same array object serves as both the train and the eval split
    return every_index, every_index
3 changes: 2 additions & 1 deletion nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,9 @@ def _init_viewer_state(self) -> None:
"""Initializes viewer scene with given train dataset"""
assert self.viewer_state and self.pipeline.datamanager.train_dataset
self.viewer_state.init_scene(
dataset=self.pipeline.datamanager.train_dataset,
train_dataset=self.pipeline.datamanager.train_dataset,
train_state="training",
eval_dataset=self.pipeline.datamanager.eval_dataset,
)

@check_viewer_enabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class BaseConverterToNerfstudioDataset(ABC):
"""Base class to process images or video into a nerfstudio dataset"""
"""Base class to process images or video into a nerfstudio dataset."""

data: Path
"""Path the data, either a video file or a directory of images."""
output_dir: Path
"""Path to the output directory."""
eval_data: Optional[Path] = None
"""Path the eval data, either a video file or a directory of images. If set to None, the first will be used both for training and eval"""
verbose: bool = False
"""If True, print extra logging."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from typing import Dict, List, Literal, Optional, Tuple

from nerfstudio.process_data import colmap_utils, hloc_utils, process_data_utils
from nerfstudio.process_data.base_converter_to_nerfstudio_dataset import (
BaseConverterToNerfstudioDataset,
)
from nerfstudio.process_data.base_converter_to_nerfstudio_dataset import BaseConverterToNerfstudioDataset
from nerfstudio.process_data.process_data_utils import CAMERA_MODELS
from nerfstudio.utils import install_checks
from nerfstudio.utils.rich_utils import CONSOLE
Expand Down
23 changes: 20 additions & 3 deletions nerfstudio/process_data/images_to_nerfstudio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
from typing import Optional

from nerfstudio.process_data import equirect_utils, process_data_utils
from nerfstudio.process_data.colmap_converter_to_nerfstudio_dataset import (
ColmapConverterToNerfstudioDataset,
)
from nerfstudio.process_data.colmap_converter_to_nerfstudio_dataset import ColmapConverterToNerfstudioDataset
from nerfstudio.utils.rich_utils import CONSOLE


Expand All @@ -47,11 +45,15 @@ def main(self) -> None:

# Generate planar projections if equirectangular
if self.camera_type == "equirectangular":
if self.eval_data is not None:
raise ValueError("Cannot use eval_data with camera_type equirectangular.")

pers_size = equirect_utils.compute_resolution_from_equirect(self.data, self.images_per_equirect)
CONSOLE.log(f"Generating {self.images_per_equirect} {pers_size} sized images per equirectangular image")
self.data = equirect_utils.generate_planar_projections_from_equirectangular(
self.data, pers_size, self.images_per_equirect, crop_factor=self.crop_factor
)

self.camera_type = "perspective"

summary_log = []
Expand All @@ -63,10 +65,25 @@ def main(self) -> None:
self.data,
image_dir=self.image_dir,
crop_factor=self.crop_factor,
image_prefix="frame_train_" if self.eval_data is not None else "frame_",
verbose=self.verbose,
num_downscales=self.num_downscales,
same_dimensions=self.same_dimensions,
keep_image_dir=False,
)
if self.eval_data is not None:
eval_image_rename_map_paths = process_data_utils.copy_images(
self.eval_data,
image_dir=self.image_dir,
crop_factor=self.crop_factor,
image_prefix="frame_eval_",
verbose=self.verbose,
num_downscales=self.num_downscales,
same_dimensions=self.same_dimensions,
keep_image_dir=True,
)
image_rename_map_paths.update(eval_image_rename_map_paths)

image_rename_map = dict((a.name, b.name) for a, b in image_rename_map_paths.items())
num_frames = len(image_rename_map)
summary_log.append(f"Starting with {num_frames} images")
Expand Down
Loading

0 comments on commit 9b03299

Please sign in to comment.