From 7d0c143eddc6cfe4b5711b51c42815389ed210ed Mon Sep 17 00:00:00 2001 From: hturki Date: Fri, 14 Apr 2023 06:29:15 -0400 Subject: [PATCH 1/2] return unscaled input in positional encoding --- nerfstudio/field_components/encodings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/field_components/encodings.py b/nerfstudio/field_components/encodings.py index 462080287a..8f283cfb77 100644 --- a/nerfstudio/field_components/encodings.py +++ b/nerfstudio/field_components/encodings.py @@ -141,9 +141,9 @@ def forward( Returns: Output values will be between -1 and 1 """ - in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi] + scaled_in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi] freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies).to(in_tensor.device) - scaled_inputs = in_tensor[..., None] * freqs # [..., "input_dim", "num_scales"] + scaled_inputs = scaled_in_tensor[..., None] * freqs # [..., "input_dim", "num_scales"] scaled_inputs = scaled_inputs.view(*scaled_inputs.shape[:-2], -1) # [..., "input_dim" * "num_scales"] if covs is None: From 80b149563cf4a7303018e65dc5c767ea5e942c59 Mon Sep 17 00:00:00 2001 From: hturki Date: Fri, 14 Apr 2023 22:03:06 -0400 Subject: [PATCH 2/2] Update encodings.py --- nerfstudio/field_components/encodings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/field_components/encodings.py b/nerfstudio/field_components/encodings.py index 8f283cfb77..ac7cf2aa17 100644 --- a/nerfstudio/field_components/encodings.py +++ b/nerfstudio/field_components/encodings.py @@ -201,8 +201,8 @@ def forward( Returns: Output values will be between -1 and 1 """ - in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi] - scaled_inputs = in_tensor @ self.b_matrix # [..., "num_frequencies"] + scaled_in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi] + scaled_inputs = scaled_in_tensor @ self.b_matrix # [..., "num_frequencies"] if covs is None: encoded_inputs = torch.sin(torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1))