Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fourier Feature encodings and polyhedron encodings #2463

Merged
merged 18 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 78 additions & 20 deletions nerfstudio/field_components/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@
from torch import Tensor, nn

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
from nerfstudio.utils.math import (
components_from_spherical_harmonics,
expected_sin,
generate_polyhedron_basis,
)
from nerfstudio.utils.printing import print_tcnn_speed_warning
from nerfstudio.utils.external import tcnn, TCNN_EXISTS


class Encoding(FieldComponent):
Expand Down Expand Up @@ -153,7 +157,7 @@ def pytorch_fwd(
Output values will be between -1 and 1
"""
scaled_in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies).to(in_tensor.device)
freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
scaled_inputs = scaled_in_tensor[..., None] * freqs # [..., "input_dim", "num_scales"]
scaled_inputs = scaled_inputs.view(*scaled_inputs.shape[:-2], -1) # [..., "input_dim" * "num_scales"]

Expand All @@ -178,34 +182,40 @@ def forward(
return self.pytorch_fwd(in_tensor, covs)


class RFFEncoding(Encoding):
"""Random Fourier Feature encoding. Supports integrated encodings.
class FFEncoding(Encoding):
"""Fourier Feature encoding. Supports integrated encodings.

Args:
in_dim: Input dimension of tensor
num_frequencies: Number of encoding frequencies
scale: Std of Gaussian to sample frequencies. Must be greater than zero
basis: Basis matrix from which to construct the Fourier features.
num_frequencies: Number of encoded frequencies per axis
min_freq_exp: Minimum frequency exponent
max_freq_exp: Maximum frequency exponent
jkulhanek marked this conversation as resolved.
Show resolved Hide resolved
include_input: Append the input coordinate to the encoding
"""

def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
def __init__(
self,
in_dim: int,
basis: Float[Tensor, "M N"],
num_frequencies: int,
min_freq_exp: float,
max_freq_exp: float,
include_input: bool = False,
) -> None:
super().__init__(in_dim)

self.num_frequencies = num_frequencies
if not scale > 0:
raise ValueError("RFF encoding scale should be greater than zero")
self.scale = scale
if self.in_dim is None:
raise ValueError("Input dimension has not been set")
b_matrix = torch.normal(mean=0, std=self.scale, size=(self.in_dim, self.num_frequencies))
self.register_buffer(name="b_matrix", tensor=b_matrix)
self.min_freq = min_freq_exp
self.max_freq = max_freq_exp
self.register_buffer(name="b_matrix", tensor=basis)
self.include_input = include_input

def get_out_dim(self) -> int:
out_dim = self.num_frequencies * 2
if self.in_dim is None:
raise ValueError("Input dimension has not been set")
assert isinstance(self.b_matrix, Tensor)
out_dim = self.b_matrix.shape[1] * self.num_frequencies * 2
if self.include_input:
if self.in_dim is None:
raise ValueError("Input dimension has not been set")
out_dim += self.in_dim
return out_dim

Expand All @@ -214,7 +224,7 @@ def forward(
in_tensor: Float[Tensor, "*bs input_dim"],
covs: Optional[Float[Tensor, "*bs input_dim input_dim"]] = None,
) -> Float[Tensor, "*bs output_dim"]:
"""Calculates RFF encoding. If covariances are provided the encodings will be integrated as proposed
"""Calculates FF encoding. If covariances are provided the encodings will be integrated as proposed
in mip-NeRF.

Args:
Expand All @@ -226,11 +236,16 @@ def forward(
"""
scaled_in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
scaled_inputs = scaled_in_tensor @ self.b_matrix # [..., "num_frequencies"]
freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
scaled_inputs = scaled_inputs[..., None] * freqs # [..., "input_dim", "num_scales"]
scaled_inputs = scaled_inputs.view(*scaled_inputs.shape[:-2], -1) # [..., "input_dim" * "num_scales"]

if covs is None:
encoded_inputs = torch.sin(torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1))
else:
input_var = torch.sum((covs @ self.b_matrix) * self.b_matrix, -2)
input_var = input_var[..., :, None] * freqs[None, :] ** 2
input_var = input_var.reshape((*input_var.shape[:-2], -1))
encoded_inputs = expected_sin(
torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1), torch.cat(2 * [input_var], dim=-1)
)
Expand All @@ -241,6 +256,49 @@ def forward(
return encoded_inputs


class RFFEncoding(FFEncoding):
"""Random Fourier Feature encoding. Supports integrated encodings.

Args:
in_dim: Input dimension of tensor
num_frequencies: Number of encoding frequencies
scale: Std of Gaussian to sample frequencies. Must be greater than zero
include_input: Append the input coordinate to the encoding
"""

def __init__(self, in_dim: int, num_frequencies: int, scale: float, include_input: bool = False) -> None:
if not scale > 0:
raise ValueError("RFF encoding scale should be greater than zero")

b_matrix = torch.normal(mean=0, std=scale, size=(in_dim, num_frequencies))
super().__init__(in_dim, b_matrix, 1, 0.0, 0.0, include_input)


class PolyhedronFFEncoding(FFEncoding):
"""Fourier Feature encoding using polyhedron basis as proposed by mip-NeRF360. Supports integrated encodings.

Args:
num_frequencies: Number of encoded frequencies per axis
min_freq_exp: Minimum frequency exponent
max_freq_exp: Maximum frequency exponent
basis_shape: Shape of polyhedron basis. Either "octahedron" or "icosahedron"
basis_subdivisions: Number of times to tesselate the polyhedron.
include_input: Append the input coordinate to the encoding
"""

def __init__(
self,
num_frequencies: int,
min_freq_exp: float,
max_freq_exp: float,
basis_shape: Literal["octahedron", "icosahedron"] = "octahedron",
basis_subdivisions: int = 1,
include_input: bool = False,
) -> None:
basis_t = generate_polyhedron_basis(basis_shape, basis_subdivisions).T
super().__init__(3, basis_t, num_frequencies, min_freq_exp, max_freq_exp, include_input)


class HashEncoding(Encoding):
"""Hash encoding

Expand Down
161 changes: 159 additions & 2 deletions nerfstudio/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

""" Math Helper Functions """

import itertools
import math
from dataclasses import dataclass
from typing import Literal, Tuple

Expand Down Expand Up @@ -195,7 +197,6 @@ def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The expected value of sin.
"""

return torch.exp(-0.5 * x_vars) * torch.sin(x_means)


Expand Down Expand Up @@ -360,4 +361,160 @@ def normalized_depth_scale_and_shift(
shift[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

return scale, shift
return scale, shift


def columnwise_squared_l2_distance(
x: Float[Tensor, "*M N"],
y: Float[Tensor, "*M N"],
) -> Float[Tensor, "N N"]:
"""Compute the squared Euclidean distance between all pairs of columns.
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py

Args:
x: tensor of floats, with shape [M, N].
y: tensor of floats, with shape [M, N].
Returns:
sq_dist: tensor of floats, with shape [N, N].
"""
# Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.
sq_norm_x = torch.sum(x**2, 0)
sq_norm_y = torch.sum(y**2, 0)
sq_dist = sq_norm_x[:, None] + sq_norm_y[None, :] - 2 * x.T @ y
return sq_dist


def _compute_tesselation_weights(v: int) -> Tensor:
"""Tesselate the vertices of a triangle by a factor of `v`.
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py

Args:
v: int, the factor of the tesselation (v==1 is a no-op to the triangle).

Returns:
weights: tesselated weights.
"""
if v < 1:
raise ValueError(f"v {v} must be >= 1")
int_weights = []
for i in range(v + 1):
for j in range(v + 1 - i):
int_weights.append((i, j, v - (i + j)))
int_weights = torch.FloatTensor(int_weights)
weights = int_weights / v # Barycentric weights.
return weights


def _tesselate_geodesic(
vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4
) -> Tensor:
"""Tesselate the vertices of a geodesic polyhedron.

Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py

Args:
vertices: tensor of floats, the vertex coordinates of the geodesic.
faces: tensor of ints, the indices of the vertices of base_verts that
constitute eachface of the polyhedra.
v: int, the factor of the tesselation (v==1 is a no-op).
eps: float, a small value used to determine if two vertices are the same.

Returns:
verts: a tensor of floats, the coordinates of the tesselated vertices.
"""
tri_weights = _compute_tesselation_weights(v)

verts = []
for face in faces:
new_verts = torch.matmul(tri_weights, vertices[face, :])
new_verts /= torch.sqrt(torch.sum(new_verts**2, 1, keepdim=True))
verts.append(new_verts)
verts = torch.concatenate(verts, 0)

sq_dist = columnwise_squared_l2_distance(verts.T, verts.T)
assignment = torch.tensor([torch.min(torch.argwhere(d <= eps)) for d in sq_dist])
unique = torch.unique(assignment)
verts = verts[unique, :]
return verts


def generate_polyhedron_basis(
basis_shape: Literal["icosahedron", "octahedron"],
angular_tesselation: int,
remove_symmetries: bool = True,
eps: float = 1e-4,
) -> Tensor:
"""Generates a 3D basis by tesselating a geometric polyhedron.
Basis is used to construct Fourier features for positional encoding.
See Mip-Nerf360 paper: https://arxiv.org/abs/2111.12077
Adapted from https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/geopoly.py

Args:
base_shape: string, the name of the starting polyhedron, must be either
'icosahedron' or 'octahedron'.
angular_tesselation: int, the number of times to tesselate the polyhedron,
must be >= 1 (a value of 1 is a no-op to the polyhedron).
remove_symmetries: bool, if True then remove the symmetric basis columns,
which is usually a good idea because otherwise projections onto the basis
will have redundant negative copies of each other.
eps: float, a small number used to determine symmetries.

Returns:
basis: a matrix with shape [3, n].
"""
if basis_shape == "icosahedron":
a = (math.sqrt(5) + 1) / 2
verts = torch.FloatTensor(
[
(-1, 0, a),
(1, 0, a),
(-1, 0, -a),
(1, 0, -a),
(0, a, 1),
(0, a, -1),
(0, -a, 1),
(0, -a, -1),
(a, 1, 0),
(-a, 1, 0),
(a, -1, 0),
(-a, -1, 0),
]
) / math.sqrt(a + 2)
faces = torch.tensor(
[
(0, 4, 1),
(0, 9, 4),
(9, 5, 4),
(4, 5, 8),
(4, 8, 1),
(8, 10, 1),
(8, 3, 10),
(5, 3, 8),
(5, 2, 3),
(2, 7, 3),
(7, 10, 3),
(7, 6, 10),
(7, 11, 6),
(11, 0, 6),
(0, 1, 6),
(6, 1, 10),
(9, 0, 11),
(9, 11, 2),
(9, 2, 5),
(7, 2, 11),
]
)
verts = _tesselate_geodesic(verts, faces, angular_tesselation)
elif basis_shape == "octahedron":
verts = torch.FloatTensor([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)])
corners = torch.FloatTensor(list(itertools.product([-1, 1], repeat=3)))
pairs = torch.argwhere(columnwise_squared_l2_distance(corners.T, verts.T) == 2)
faces, _ = torch.sort(torch.reshape(pairs[:, 1], [3, -1]).T, 1)
verts = _tesselate_geodesic(verts, faces, angular_tesselation)

if remove_symmetries:
# Remove elements of `verts` that are reflections of each other.
match = columnwise_squared_l2_distance(verts.T, -verts.T) < eps
verts = verts[torch.any(torch.triu(match), 1), :]

basis = verts.flip(-1)
return basis