Skip to content

Commit

Permalink
support gaussian zero-123
Browse files Browse the repository at this point in the history
  • Loading branch information
DSaurus committed Dec 27, 2023
1 parent 652740a commit 1ea7a86
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion threestudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__modules__ = {}
__version__ = "0.2.0"
__version__ = "0.2.1"


def register(name):
Expand Down
10 changes: 8 additions & 2 deletions threestudio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def setup(self, cfg, split):
[torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
dim=-1,
)
self.c2w4x4: Float[Tensor, "B 4 4"] = torch.cat(
[self.c2w, torch.zeros_like(self.c2w[:, :1])], dim=1
)
self.c2w4x4[:, 3, 3] = 1.0

self.camera_position = camera_position
self.light_position = light_position
Expand Down Expand Up @@ -258,8 +262,10 @@ def collate(self, batch) -> Dict[str, Any]:
"ref_depth": self.depth,
"ref_normal": self.normal,
"mask": self.mask,
"height": self.cfg.height,
"width": self.cfg.width,
"height": self.height,
"width": self.width,
"c2w": self.c2w4x4,
"fovy": self.fovy,
}
if self.cfg.use_random_camera:
batch["random_camera"] = self.random_pose_generator.collate(None)
Expand Down
18 changes: 11 additions & 7 deletions threestudio/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gc
import math
import os
import re

Expand Down Expand Up @@ -62,7 +63,7 @@ def load_module_weights(
return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]


def C(value: Any, epoch: int, global_step: int) -> float:
def C(value: Any, epoch: int, global_step: int, interpolation="linear") -> float:
if isinstance(value, int) or isinstance(value, float):
pass
else:
Expand All @@ -84,15 +85,18 @@ def C(value: Any, epoch: int, global_step: int) -> float:
value = [start_step, start_value, end_value, end_step]
assert len(value) == 4
start_step, start_value, end_value, end_step = value
if isinstance(end_step, int):
if isinstance(end_step, int) or isinstance(end_step, float):
current_step = global_step
value = start_value + (end_value - start_value) * max(
min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
)
elif isinstance(end_step, float):
current_step = epoch
value = start_value + (end_value - start_value) * max(
min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
t = max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
if interpolation == "linear":
value = start_value + (end_value - start_value) * t
elif interpolation == "exp":
value = math.exp(math.log(start_value) * (1 - t) + math.log(end_value) * t)
else:
raise ValueError(
f"Unknown interpolation method: {interpolation}, only support linear and exp"
)
return value

Expand Down

0 comments on commit 1ea7a86

Please sign in to comment.