Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support splatfacto training for metashape exports #3122

Merged
merged 13 commits into from
May 2, 2024
8 changes: 4 additions & 4 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,8 @@ def _downscale_if_required(self, image):
return image

@staticmethod
def get_empty_outputs(camera, background):
rgb = background.repeat(int(camera.height.item()), int(camera.width.item()), 1)
def get_empty_outputs(width: int, height: int, background: torch.Tensor) -> Dict[str, Union[torch.Tensor, List]]:
rgb = background.repeat(height, width, 1)
depth = background.new_ones(*rgb.shape[:2], 1) * 10
accumulation = background.new_zeros(*rgb.shape[:2], 1)
return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}
Expand Down Expand Up @@ -694,7 +694,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
if self.crop_box is not None and not self.training:
crop_ids = self.crop_box.within(self.means).squeeze()
if crop_ids.sum() == 0:
return self.get_empty_outputs(camera, background)
return self.get_empty_outputs(int(camera.width.item()), int(camera.height.item()), background)
else:
crop_ids = None
camera_downscale = self._get_downscale_factor()
Expand Down Expand Up @@ -754,7 +754,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
camera.rescale_output_resolution(camera_downscale)

if (self.radii).sum() == 0:
return self.get_empty_outputs(camera, background)
return self.get_empty_outputs(W, H, background)

if self.config.sh_degree > 0:
viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3)
Expand Down
24 changes: 22 additions & 2 deletions nerfstudio/process_data/metashape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import json
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

import numpy as np
import open3d as o3d

from nerfstudio.process_data.process_data_utils import CAMERA_MODELS
from nerfstudio.utils.rich_utils import CONSOLE
Expand All @@ -36,6 +37,7 @@ def metashape_to_json(
image_filename_map: Dict[str, Path],
xml_filename: Path,
output_dir: Path,
ply_filename: Optional[Path] = None, # type: ignore
verbose: bool = False,
) -> List[str]:
"""Convert Metashape data into a nerfstudio dataset.
Expand All @@ -44,6 +46,7 @@ def metashape_to_json(
image_filename_map: Mapping of original image filenames to their saved locations.
xml_filename: Path to the metashape cameras xml file.
output_dir: Path to the output directory.
ply_filename: Path to the exported ply file.
verbose: Whether to print verbose output.

Returns:
Expand Down Expand Up @@ -181,17 +184,34 @@ def metashape_to_json(
if component_id in component_dict:
transform = component_dict[component_id] @ transform

# Metashape camera is looking towards -Z, +X is to the right and +Y is to the top/up of the first cam
# Rotate the scene according to nerfstudio convention
transform = transform[[2, 0, 1, 3], :]
# Convert from Metashape's camera coordinate system (OpenCV) to ours (OpenGL)
transform[:, 1:3] *= -1
frame["transform_matrix"] = transform.tolist()
frames.append(frame)

data["frames"] = frames
applied_transform = np.eye(4)[:3, :]
applied_transform = applied_transform[np.array([2, 0, 1]), :]
data["applied_transform"] = applied_transform.tolist()

summary = []

if ply_filename is not None:
assert ply_filename.exists()
pc = o3d.io.read_point_cloud(str(ply_filename))
points3D = np.asarray(pc.points)
points3D = np.einsum("ij,bj->bi", applied_transform[:3, :3], points3D) + applied_transform[:3, 3]
pc.points = o3d.utility.Vector3dVector(points3D)
o3d.io.write_point_cloud(str(output_dir / "sparse_pc.ply"), pc)
data["ply_file_path"] = "sparse_pc.ply"
summary.append(f"Imported {ply_filename} as starting points")

with open(output_dir / "transforms.json", "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)

summary = []
if num_skipped == 1:
summary.append(f"{num_skipped} image skipped because it was missing its camera pose.")
if num_skipped > 1:
Expand Down
28 changes: 20 additions & 8 deletions nerfstudio/scripts/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
#!/usr/bin/env python
"""Processes a video or image sequence to a nerfstudio compatible dataset."""


import sys
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Union
from typing import Optional, Union

import numpy as np
import tyro
Expand Down Expand Up @@ -230,12 +229,18 @@ class ProcessMetashape(BaseConverterToNerfstudioDataset, _NoDefaultProcessMetash
This script assumes that cameras have been aligned using Metashape. After alignment, it is necessary to export the
camera poses as a `.xml` file. This option can be found under `File > Export > Export Cameras`.

Additionally, the points can be exported as a point cloud under `File > Export > Export Point Cloud`. Make sure to
export the data in non-binary format and exclude the normals.

This script does the following:

1. Scales images to a specified size.
2. Converts Metashape poses into the nerfstudio format.
"""

ply: Optional[Path] = None
"""Path to the Metashape point export ply file."""

num_downscales: int = 3
"""Number of times to downscale the images. Downscales by 2 each time. For example a value of 3
will downscale the images by 2x, 4x, and 8x."""
Expand All @@ -248,11 +253,17 @@ def main(self) -> None:

if self.xml.suffix != ".xml":
raise ValueError(f"XML file {self.xml} must have a .xml extension")
if not self.xml.exists:
if not self.xml.exists():
raise ValueError(f"XML file {self.xml} doesn't exist")
if self.eval_data is not None:
raise ValueError("Cannot use eval_data since cameras were already aligned with Metashape.")

if self.ply is not None:
if self.ply.suffix != ".ply":
raise ValueError(f"PLY file {self.ply} must have a .ply extension")
if not self.ply.exists():
raise ValueError(f"PLY file {self.ply} doesn't exist")

self.output_dir.mkdir(parents=True, exist_ok=True)
image_dir = self.output_dir / "images"
image_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -291,6 +302,7 @@ def main(self) -> None:
image_filename_map=image_filename_map,
xml_filename=self.xml,
output_dir=self.output_dir,
ply_filename=self.ply,
verbose=self.verbose,
)
)
Expand Down Expand Up @@ -335,7 +347,7 @@ def main(self) -> None:

if self.csv.suffix != ".csv":
raise ValueError(f"CSV file {self.csv} must have a .csv extension")
if not self.csv.exists:
if not self.csv.exists():
raise ValueError(f"CSV file {self.csv} doesn't exist")
if self.eval_data is not None:
raise ValueError("Cannot use eval_data since cameras were already aligned with RealityCapture.")
Expand Down Expand Up @@ -413,12 +425,12 @@ def main(self) -> None:
shots_file = self.data / "odm_report" / "shots.geojson"
reconstruction_file = self.data / "opensfm" / "reconstruction.json"

if not shots_file.exists:
if not shots_file.exists():
raise ValueError(f"shots file {shots_file} doesn't exist")
if not shots_file.exists:
if not shots_file.exists():
raise ValueError(f"cameras file {cameras_file} doesn't exist")

if not orig_images_dir.exists:
if not orig_images_dir.exists():
raise ValueError(f"Images dir {orig_images_dir} doesn't exist")

if self.eval_data is not None:
Expand Down Expand Up @@ -522,7 +534,7 @@ def entrypoint():
tyro.extras.set_accent_color("bright_yellow")
try:
tyro.cli(Commands).main()
except RuntimeError as e:
except (RuntimeError, ValueError) as e:
CONSOLE.log("[bold red]" + e.args[0])


Expand Down
Loading