[WORKS] Corrects noise addition to input, adds guidance_eval
Vikram Voleti committed Nov 28, 2023
1 parent 71b167f commit a7c43e8
Showing 5 changed files with 81 additions and 526 deletions.
10 changes: 6 additions & 4 deletions configs/svd.yaml
@@ -15,7 +15,7 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
default_fovy_deg: 20.0
requires_depth: False
requires_normal: False
random_camera: # threestudio/data/svd.py -> SVDDataModuleConfig
random_camera: # threestudio/data/svd_uncond.py -> SVDDataModuleConfig
height: [64, 128, 256]
width: [64, 128, 256]
batch_size: [21, 21, 21] # same as video model # of frames
@@ -42,7 +42,7 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
n_val_views: 30
n_test_views: 120

system_type: "svd-system"
system_type: "zero123-system"
system:
geometry_type: "implicit-volume"
geometry:
@@ -101,16 +101,18 @@ system:
stable_research_path: "/admin/home-vikram/ROBIN/stable-research"
pretrained_model_name_or_path: "prediction_stable_jucifer_3D_OBJ_IMG"
cond_aug: 0.00
vram_O: true
num_steps: 50
height: ${max:${data.random_camera.height}}
vram_O: true
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 3.0
guidance_scale: 4.0
min_step_percent: [50, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
max_step_percent: [50, 0.98, 0.8, 200]
noises_per_item: 1
guidance_eval_freq: 0

ambient_ratio_min: [0, 0.1, 0.5, 400, 0.9, 401]

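Both configs drive min_step_percent and max_step_percent (and lambda-style loss weights) with four-element schedules. A minimal sketch of how a (start_iter, start_val, end_val, end_iter) list can be resolved per training step, assuming linear interpolation in the spirit of threestudio.utils.misc.C; the real resolver may differ:

def resolve_schedule(value, global_step):
    # Scalars pass through unchanged; 4-element lists are read as
    # (start_iter, start_val, end_val, end_iter), per the config comments,
    # and interpolated linearly with clamping outside the window.
    if isinstance(value, (int, float)):
        return value
    start_iter, start_val, end_val, end_iter = value
    frac = (global_step - start_iter) / (end_iter - start_iter)
    frac = min(max(frac, 0.0), 1.0)
    return start_val + (end_val - start_val) * frac

# max_step_percent [50, 0.98, 0.8, 200] holds 0.98 until iter 50,
# then anneals linearly to 0.8 by iter 200:
print(resolve_schedule([50, 0.98, 0.8, 200], 125))  # 0.89
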
20 changes: 11 additions & 9 deletions configs/svd_rgb.yaml → configs/svd_512.yaml
@@ -8,18 +8,18 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
image_path: ??? # ./load/images/hamburger_rgba.png
height: [128, 256, 512]
width: [128, 256, 512]
resolution_milestones: [200, 300]
resolution_milestones: [50, 100]
default_elevation_deg: 5.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: False
requires_normal: False
random_camera: # threestudio/data/svd.py -> SVDDataModuleConfig
random_camera: # threestudio/data/svd_uncond.py -> SVDDataModuleConfig
height: [64, 128, 256]
width: [64, 128, 256]
batch_size: [21, 21, 21] # same as video model # of frames
resolution_milestones: [200, 300]
resolution_milestones: [50, 100]
eval_height: 512
eval_width: 512
eval_batch_size: 1
@@ -42,7 +42,7 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
n_val_views: 30
n_test_views: 120

system_type: "svd-system"
system_type: "zero123-system"
system:
geometry_type: "implicit-volume"
geometry:
@@ -101,16 +101,18 @@ system:
stable_research_path: "/admin/home-vikram/ROBIN/stable-research"
pretrained_model_name_or_path: "prediction_stable_jucifer_3D_OBJ_IMG"
cond_aug: 0.00
vram_O: true
num_steps: 50
height: ${max:${data.random_camera.height}}
vram_O: true
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 3.0
min_step_percent: [50, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
max_step_percent: [50, 0.98, 0.8, 200]
guidance_scale: 4.0
min_step_percent: [30, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
max_step_percent: [30, 0.98, 0.8, 200]
noises_per_item: 1
guidance_eval_freq: 0

ambient_ratio_min: [0, 0.1, 0.5, 400, 0.9, 401]

@@ -122,7 +124,7 @@ system:

loss:
lambda_sds: 0.1
lambda_rgb: [100, 500., 1000., 400]
lambda_rgb: 1000.
lambda_mask: 50.
lambda_depth: 0. # 0.05
lambda_depth_rel: 0. # [0, 0, 0.05, 100]
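Both files also set the guidance height to ${max:${data.random_camera.height}}. A small sketch of how such an OmegaConf interpolation can resolve to the largest training resolution; the "max" resolver registration below is an assumption standing in for whatever threestudio actually registers:

from omegaconf import OmegaConf

# Hypothetical resolver registration; threestudio is assumed to provide an
# equivalent "max" resolver for its configs.
OmegaConf.register_new_resolver("max", lambda lst: max(lst))

cfg = OmegaConf.create(
    {
        "data": {"random_camera": {"height": [64, 128, 256]}},
        "system": {"guidance": {"height": "${max:${data.random_camera.height}}"}},
    }
)
print(cfg.system.guidance.height)  # 256 -- guidance sees the final resolution
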
166 changes: 64 additions & 102 deletions threestudio/models/guidance/svd_guidance.py
@@ -5,16 +5,15 @@
import cv2
import imageio
import numpy as np
import threestudio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as TT
from tqdm import tqdm

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, get_CPU_mem, get_GPU_mem
from threestudio.utils.typing import *
from tqdm import tqdm


def get_resizing_factor(
@@ -61,16 +60,17 @@ class Config(BaseObject.Config):
stable_research_path: str = "/admin/home-vikram/ROBIN/stable-research"
pretrained_model_name_or_path: str = "prediction_stable_jucifer_3D_OBJ"
cond_aug: float = 0.00
vram_O: bool = True
num_steps: int = None # 50
height: int = 576
guidance_scale: float = None # 4.0

vram_O: bool = True

cond_image_path: str = "load/images/hamburger_rgba.png"
cond_elevation_deg: float = 0.0
cond_azimuth_deg: float = 0.0
cond_camera_distance: float = 1.2

guidance_scale: float = 5.0

grad_clip: Optional[
Any
] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
@@ -84,6 +84,11 @@ class Config(BaseObject.Config):
"""Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
max_items_eval: int = 4

guidance_eval_freq: int = 0
guidance_eval_dir: str = os.path.realpath(
os.path.join(os.path.dirname(__file__), "../../../")
)

cfg: Config

def configure(self) -> None:
@@ -97,6 +102,8 @@ def configure(self) -> None:
self.model = JucifierDenoiser(
self.cfg.pretrained_model_name_or_path,
self.cfg.cond_aug,
self.cfg.num_steps,
self.cfg.guidance_scale,
self.cfg.height,
)
for p in self.model.parameters():
@@ -106,12 +113,18 @@ def configure(self) -> None:

self.prepare_embeddings(self.cfg.cond_image_path)

self.T = self.model.T
self.num_steps = self.model.num_steps
self.set_min_max_steps() # set to default values

self.count = 0

threestudio.info(f"Loaded SVD!")

@torch.cuda.amp.autocast(enabled=False)
def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.min_step = int(self.num_train_timesteps * min_step_percent)
self.max_step = int(self.num_train_timesteps * max_step_percent)
self.min_step = int(self.num_steps * min_step_percent)
self.max_step = int(self.num_steps * max_step_percent)

@torch.cuda.amp.autocast(enabled=False)
def prepare_embeddings(
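set_min_max_steps now scales the denoiser's num_steps (50 in these configs) instead of a 1000-step num_train_timesteps, so the annealed percentages become small integer index ranges. A quick check of the endpoints from the configs:

# Sketch: num_steps = 50 as in the configs; values below are the index
# ranges at the start and end of the min/max_step_percent anneal.
num_steps = 50
for min_pct, max_pct in [(0.7, 0.98), (0.3, 0.8)]:
    print(int(num_steps * min_pct), int(num_steps * max_pct))
# 35 49  (at the start of the anneal)
# 15 40  (after the anneal finishes)
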
@@ -131,35 +144,55 @@ def encode_images(
def __call__(
self,
rgb: Float[Tensor, "B H W C"],
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
rgb_as_latents=False,
guidance_eval=False,
**kwargs,
):
rgb_BCHW = rgb.permute(0, 3, 1, 2)
latents = self.encode_images(rgb_BCHW)

latents = torch.repeat_interleave(latents, self.noises_per_item, 0)
elevation = torch.repeat_interleave(elevation, self.noises_per_item)
azimuth = torch.repeat_interleave(azimuth, self.noises_per_item)
camera_distances = torch.repeat_interleave(
camera_distances, self.noises_per_item
)
batch_size = latents.shape[0]

# # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
# t = torch.randint(
# self.min_step,
# self.max_step + 1,
# [batch_size],
# dtype=torch.long,
# device=self.device,
# )
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
t = torch.randint(
self.min_step,
self.max_step + 1,
[batch_size // self.T],
dtype=torch.long,
device=self.device,
)

guidance_eval = (
self.cfg.guidance_eval_freq > 0
and self.count % self.cfg.guidance_eval_freq == 0
)

with torch.no_grad():
rgb_pred = self.model(latents)
rgb_pred = self.model(latents, t, guidance_eval=guidance_eval)

if guidance_eval:
rgb_pred, rgb_i, rgb_d, rgb_eval = rgb_pred
imageio.mimsave(
os.path.join(
self.cfg.guidance_eval_dir,
f"guidance_eval_input_{self.count:05d}.mp4",
),
rgb_i,
)
imageio.mimsave(
os.path.join(
self.cfg.guidance_eval_dir,
f"guidance_eval_denoise_{self.count:05d}.mp4",
),
rgb_d,
)
imageio.mimsave(
os.path.join(
self.cfg.guidance_eval_dir,
f"guidance_eval_final_{self.count:05d}.mp4",
),
rgb_eval,
)
self.count += 1

# TODO CFG
# TODO min_step, max_step
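The new sampling draws one timestep per video rather than per frame: with T frames per clip, latents hold batch_size = B * T items and t has shape [B]. A sketch of that, together with the noises_per_item replication earlier in __call__; broadcasting t back to frames is an assumption about how the denoiser consumes it:

import torch

B, T = 2, 21                 # clips per batch, frames per clip (21 as in the configs)
batch_size = B * T
min_step, max_step = 15, 40  # illustrative values from set_min_max_steps

# One timestep per clip, so all T frames share a noise level.
t = torch.randint(min_step, max_step + 1, [batch_size // T])  # shape [B]

# Hypothetical per-frame broadcast, mirroring the repeat_interleave used
# for latents/elevation/azimuth above.
t_per_frame = torch.repeat_interleave(t, T)                   # shape [B*T]
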
@@ -197,88 +230,17 @@ def __call__(

return guidance_out

@torch.cuda.amp.autocast(enabled=False)
@torch.no_grad()
def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred):
# use only 50 timesteps, and find nearest of those to t
self.scheduler.set_timesteps(50)
self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
bs = (
min(self.cfg.max_items_eval, latents_noisy.shape[0])
if self.cfg.max_items_eval > 0
else latents_noisy.shape[0]
) # batch size
large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
:bs
].unsqueeze(
-1
) # sized [bs,50] > [bs,1]
idxs = torch.min(large_enough_idxs, dim=1)[1]
t = self.scheduler.timesteps_gpu[idxs]

fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)

# get prev latent
latents_1step = []
pred_1orig = []
for b in range(bs):
step_output = self.scheduler.step(
noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
)
latents_1step.append(step_output["prev_sample"])
pred_1orig.append(step_output["pred_original_sample"])
latents_1step = torch.cat(latents_1step)
pred_1orig = torch.cat(pred_1orig)
imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)

latents_final = []
for b, i in enumerate(idxs):
latents = latents_1step[b : b + 1]
c = {
"c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]],
"c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]],
}
for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
# pred noise
x_in = torch.cat([latents] * 2)
t_in = torch.cat([t.reshape(1)] * 2).to(self.device)
noise_pred = self.model(x_in, t_in, c)
# perform guidance
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
# get prev latent
latents = self.scheduler.step(noise_pred, t, latents, eta=1)[
"prev_sample"
]
latents_final.append(latents)

latents_final = torch.cat(latents_final)
imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)

return {
"bs": bs,
"noise_levels": fracs,
"imgs_noisy": imgs_noisy,
"imgs_1step": imgs_1step,
"imgs_1orig": imgs_1orig,
"imgs_final": imgs_final,
}

def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

# self.set_min_max_steps(
# min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
# max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
# )
self.set_min_max_steps(
min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
)

if self.cfg.noises_per_item is not None:
self.noises_per_item = np.floor(
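The scheduler-based guidance_eval method above is deleted outright; evaluation now happens inline in __call__, where the denoiser returns input, one-step-denoised, and final clips that are written straight to disk. A minimal sketch of that output path with dummy frames; the real frames are whatever the model returns (assumed here to be HxWx3 uint8 arrays, and .mp4 writing needs the imageio-ffmpeg backend):

import os

import imageio
import numpy as np

guidance_eval_dir = "."  # the new Config defaults this to the repo root
count = 0

# Dummy clip: 21 black frames standing in for the model's rgb_i/rgb_d/rgb_eval.
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(21)]
imageio.mimsave(
    os.path.join(guidance_eval_dir, f"guidance_eval_input_{count:05d}.mp4"),
    frames,
)
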
1 change: 0 additions & 1 deletion threestudio/systems/__init__.py
@@ -9,7 +9,6 @@
magic123,
prolificdreamer,
sjc,
svd,
textmesh,
zero123,
zero123_simple,
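Removing svd from the systems package matches the configs' switch to system_type: "zero123-system": threestudio resolves that string through its name registry, so the old svd system simply stops being registered. A sketch of the lookup, assuming threestudio's usual register/find helpers:

import threestudio

# Assumed registry API: systems are registered under a string name
# (e.g. @threestudio.register("zero123-system") on the class) and
# looked up by that name at launch time.
system_cls = threestudio.find("zero123-system")
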
