Skip to content

Commit

Permalink
Implement gaussian splatting ply file saving (#427)
Browse files Browse the repository at this point in the history
* implement ply saving

* fix colors

* Add own flag for saveing ply files

* fix appearance embeddings

* remove open3d

* align order with INRIA ply

* filter Nan and infs

* add flag to save ply and move save_ply to utils

---------

Co-authored-by: maturk <[email protected]>
  • Loading branch information
MrNeRF and maturk authored Jan 10, 2025
1 parent 2df0a95 commit 1a1e0cc
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
26 changes: 26 additions & 0 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from gsplat.optimizers import SelectiveAdam
from gsplat.utils import save_ply


@dataclass
Expand Down Expand Up @@ -85,6 +86,10 @@ class Config:
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = False
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

# Initialization strategy
init_type: str = "sfm"
Expand Down Expand Up @@ -167,6 +172,7 @@ class Config:
def adjust_steps(self, factor: float):
self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.ply_steps = [int(i * factor) for i in self.ply_steps]
self.max_steps = int(self.max_steps * factor)
self.sh_degree_interval = int(self.sh_degree_interval * factor)

Expand Down Expand Up @@ -294,6 +300,8 @@ def __init__(
os.makedirs(self.stats_dir, exist_ok=True)
self.render_dir = f"{cfg.result_dir}/renders"
os.makedirs(self.render_dir, exist_ok=True)
self.ply_dir = f"{cfg.result_dir}/ply"
os.makedirs(self.ply_dir, exist_ok=True)

# Tensorboard
self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
Expand Down Expand Up @@ -735,6 +743,24 @@ def train(self):
torch.save(
data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
)
if (
step in [i - 1 for i in cfg.ply_steps]
or step == max_steps - 1
and cfg.save_ply
):
rgb = None
if self.cfg.app_opt:
# eval at origin to bake the appeareance into the colors
rgb = self.app_module(
features=self.splats["features"],
embed_ids=None,
dirs=torch.zeros_like(self.splats["means"][None, :, :]),
sh_degree=sh_degree_to_use,
)
rgb = rgb + self.splats["colors"]
rgb = torch.sigmoid(rgb).squeeze(0)

save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply", rgb)

# Turn Gradients into Sparse Tensor before running optimizer
if cfg.sparse_grad:
Expand Down
93 changes: 93 additions & 0 deletions gsplat/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,101 @@
import math
import struct

import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np


def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None):
# Convert all tensors to numpy arrays in one go
print(f"Saving ply to {dir}")
numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()}

means = numpy_data["means"]
scales = numpy_data["scales"]
quats = numpy_data["quats"]
opacities = numpy_data["opacities"]

sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1)
shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1)

# Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays
invalid_mask = (
np.isnan(means).any(axis=1)
| np.isinf(means).any(axis=1)
| np.isnan(scales).any(axis=1)
| np.isinf(scales).any(axis=1)
| np.isnan(quats).any(axis=1)
| np.isinf(quats).any(axis=1)
| np.isnan(opacities).any(axis=0)
| np.isinf(opacities).any(axis=0)
| np.isnan(sh0).any(axis=1)
| np.isinf(sh0).any(axis=1)
| np.isnan(shN).any(axis=1)
| np.isinf(shN).any(axis=1)
)

# Filter out rows with NaNs or Infs from all data arrays
means = means[~invalid_mask]
scales = scales[~invalid_mask]
quats = quats[~invalid_mask]
opacities = opacities[~invalid_mask]
sh0 = sh0[~invalid_mask]
shN = shN[~invalid_mask]

num_points = means.shape[0]

with open(dir, "wb") as f:
# Write PLY header
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(f"element vertex {num_points}\n".encode())
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
f.write(b"property float nx\n")
f.write(b"property float ny\n")
f.write(b"property float nz\n")

if colors is not None:
for j in range(colors.shape[1]):
f.write(f"property float f_dc_{j}\n".encode())
else:
for i, data in enumerate([sh0, shN]):
prefix = "f_dc" if i == 0 else "f_rest"
for j in range(data.shape[1]):
f.write(f"property float {prefix}_{j}\n".encode())

f.write(b"property float opacity\n")

for i in range(scales.shape[1]):
f.write(f"property float scale_{i}\n".encode())
for i in range(quats.shape[1]):
f.write(f"property float rot_{i}\n".encode())

f.write(b"end_header\n")

# Write vertex data
for i in range(num_points):
f.write(struct.pack("<fff", *means[i])) # x, y, z
f.write(struct.pack("<fff", 0, 0, 0)) # nx, ny, nz (zeros)

if colors is not None:
color = colors.detach().cpu().numpy()
for j in range(color.shape[1]):
f_dc = (color[i, j] - 0.5) / 0.2820947917738781
f.write(struct.pack("<f", f_dc))
else:
for data in [sh0, shN]:
for j in range(data.shape[1]):
f.write(struct.pack("<f", data[i, j]))

f.write(struct.pack("<f", opacities[i])) # opacity

for data in [scales, quats]:
for j in range(data.shape[1]):
f.write(struct.pack("<f", data[i, j]))


def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
Expand Down

0 comments on commit 1a1e0cc

Please sign in to comment.