Skip to content

Commit

Permalink
isort stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Vikram Voleti committed Nov 12, 2023
1 parent 8d1ecc7 commit b29c01d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 25 deletions.
9 changes: 5 additions & 4 deletions threestudio/models/guidance/zero123_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@

import cv2
import numpy as np
import threestudio
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler
from omegaconf import OmegaConf
from tqdm import tqdm

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, get_CPU_mem, get_GPU_mem
from threestudio.utils.typing import *
from tqdm import tqdm


def get_obj_from_str(string, reload=False):
Expand Down Expand Up @@ -160,7 +159,9 @@ def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.max_step = int(self.num_train_timesteps * max_step_percent)

@torch.cuda.amp.autocast(enabled=False)
def prepare_embeddings(self, image_path: str, background_color: Tuple[int, int, int] = (255, 255, 255)) -> None:
def prepare_embeddings(
self, image_path: str, background_color: Tuple[int, int, int] = (255, 255, 255)
) -> None:
# load cond image for zero123
assert os.path.exists(image_path)
rgba = cv2.cvtColor(
Expand Down
21 changes: 16 additions & 5 deletions threestudio/models/renderers/nerf_volume_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from functools import partial

import nerfacc
import threestudio
import torch
import torch.nn.functional as F

import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.estimators import ImportanceEstimator
from threestudio.models.geometry.base import BaseImplicitGeometry
Expand Down Expand Up @@ -127,9 +126,21 @@ def forward(
**kwargs
) -> Dict[str, Float[Tensor, "..."]]:
if rays_divisor > 1:
rays_o = torch.cat([rays_o[:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x, offset_y)])
rays_d = torch.cat([rays_d[:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x, offset_y)])
light_positions = light_positions.repeat(len(rays_o)//len(light_positions), 1)
rays_o = torch.cat(
[
rays_o[:, xx::rays_divisor, yy::rays_divisor, :]
for (xx, yy) in zip(offset_x, offset_y)
]
)
rays_d = torch.cat(
[
rays_d[:, xx::rays_divisor, yy::rays_divisor, :]
for (xx, yy) in zip(offset_x, offset_y)
]
)
light_positions = light_positions.repeat(
len(rays_o) // len(light_positions), 1
)
batch_size, height, width = rays_o.shape[:3]
rays_o_flatten: Float[Tensor, "Nr 3"] = rays_o.reshape(-1, 3)
rays_d_flatten: Float[Tensor, "Nr 3"] = rays_d.reshape(-1, 3)
Expand Down
4 changes: 1 addition & 3 deletions threestudio/systems/magic123.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import shutil
from dataclasses import dataclass, field

import threestudio
import torch
import torch.nn.functional as F

import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.misc import get_CPU_mem, get_GPU_mem
from threestudio.utils.ops import binary_cross_entropy, dot
Expand All @@ -24,7 +23,6 @@ class Config(BaseLift3DSystem.Config):
rays_divisor_power: int = 0
ref_batch_size: int = 1


cfg: Config

def configure(self):
Expand Down
50 changes: 39 additions & 11 deletions threestudio/systems/zero123.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from dataclasses import dataclass, field
from math import ceil

import threestudio
import torch
import torch.nn.functional as F
from torchmetrics import PearsonCorrCoef

import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *
from torchmetrics import PearsonCorrCoef


@threestudio.register("zero123-system")
Expand Down Expand Up @@ -65,9 +64,9 @@ def training_substep(self, batch, batch_idx, guidance: str):
ambient_ratio = 1.0
shading = "diffuse"
batch["shading"] = shading
rays_divisor = 2**ceil(self.C(self.cfg.rays_divisor_power))
offset_x_tensor = torch.randint(0,rays_divisor,(self.cfg.ref_batch_size,))
offset_y_tensor = torch.randint(0,rays_divisor,(self.cfg.ref_batch_size,))
rays_divisor = 2 ** ceil(self.C(self.cfg.rays_divisor_power))
offset_x_tensor = torch.randint(0, rays_divisor, (self.cfg.ref_batch_size,))
offset_y_tensor = torch.randint(0, rays_divisor, (self.cfg.ref_batch_size,))
elif guidance == "zero123":
batch = batch["random_camera"]
# ambient_ratio = (
Expand Down Expand Up @@ -101,8 +100,18 @@ def set_loss(name, value):
)

if guidance == "ref":
gt_mask = torch.cat([batch["mask"][:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)])
gt_rgb = torch.cat([batch["rgb"][:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)])
gt_mask = torch.cat(
[
batch["mask"][:, xx::rays_divisor, yy::rays_divisor, :]
for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)
]
)
gt_rgb = torch.cat(
[
batch["rgb"][:, xx::rays_divisor, yy::rays_divisor, :]
for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)
]
)

# color loss
gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * (
Expand All @@ -115,7 +124,12 @@ def set_loss(name, value):

# depth loss
if self.C(self.cfg.loss.lambda_depth) > 0:
valid_gt_depth = torch.cat([batch["ref_depth"][:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)])[gt_mask.squeeze(-1)].unsqueeze(1)
valid_gt_depth = torch.cat(
[
batch["ref_depth"][:, xx::rays_divisor, yy::rays_divisor, :]
for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)
]
)[gt_mask.squeeze(-1)].unsqueeze(1)
valid_pred_depth = out["depth"][gt_mask].unsqueeze(1)
with torch.no_grad():
A = torch.cat(
Expand All @@ -127,7 +141,12 @@ def set_loss(name, value):

# relative depth loss
if self.C(self.cfg.loss.lambda_depth_rel) > 0:
valid_gt_depth = torch.cat([batch["ref_depth"][:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)])[gt_mask.squeeze(-1)]
valid_gt_depth = torch.cat(
[
batch["ref_depth"][:, xx::rays_divisor, yy::rays_divisor, :]
for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)
]
)[gt_mask.squeeze(-1)]
valid_pred_depth = out["depth"][gt_mask] # [B,]
set_loss(
"depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth)
Expand All @@ -136,7 +155,16 @@ def set_loss(name, value):
# normal loss
if self.C(self.cfg.loss.lambda_normal) > 0:
valid_gt_normal = (
1 - 2 * torch.cat([batch["ref_normal"][:, xx::rays_divisor, yy::rays_divisor, :] for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)])[gt_mask.squeeze(-1)]
1
- 2
* torch.cat(
[
batch["ref_normal"][
:, xx::rays_divisor, yy::rays_divisor, :
]
for (xx, yy) in zip(offset_x_tensor, offset_y_tensor)
]
)[gt_mask.squeeze(-1)]
) # [B, 3]
valid_pred_normal = (
2 * out["comp_normal"][gt_mask.squeeze(-1)] - 1
Expand Down
3 changes: 1 addition & 2 deletions threestudio/systems/zero123_simple.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from dataclasses import dataclass, field

import torch

import threestudio
import torch
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *
Expand Down

0 comments on commit b29c01d

Please sign in to comment.