diff --git a/configs/svd.yaml b/configs/svd.yaml
index 5e00de80..0bc9ccea 100644
--- a/configs/svd.yaml
+++ b/configs/svd.yaml
@@ -15,7 +15,7 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
   default_fovy_deg: 20.0
   requires_depth: False
   requires_normal: False
-  random_camera: # threestudio/data/svd.py -> SVDDataModuleConfig
+  random_camera: # threestudio/data/svd_uncond.py -> SVDDataModuleConfig
     height: [64, 128, 256]
     width: [64, 128, 256]
     batch_size: [21, 21, 21] # same as video model # of frames
@@ -42,7 +42,7 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
   n_val_views: 30
   n_test_views: 120

-system_type: "svd-system"
+system_type: "zero123-system"
 system:
   geometry_type: "implicit-volume"
   geometry:
@@ -101,16 +101,18 @@ system:
     stable_research_path: "/admin/home-vikram/ROBIN/stable-research"
     pretrained_model_name_or_path: "prediction_stable_jucifer_3D_OBJ_IMG"
     cond_aug: 0.00
-    vram_O: true
+    num_steps: 50
     height: ${max:${data.random_camera.height}}
+    vram_O: true
     cond_image_path: ${data.image_path}
     cond_elevation_deg: ${data.default_elevation_deg}
     cond_azimuth_deg: ${data.default_azimuth_deg}
     cond_camera_distance: ${data.default_camera_distance}
-    guidance_scale: 3.0
+    guidance_scale: 4.0
     min_step_percent: [50, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
     max_step_percent: [50, 0.98, 0.8, 200]
     noises_per_item: 1
+    guidance_eval_freq: 0

   ambient_ratio_min: [0, 0.1, 0.5, 400, 0.9, 401]

diff --git a/configs/svd_rgb.yaml b/configs/svd_512.yaml
similarity index 91%
rename from configs/svd_rgb.yaml
rename to configs/svd_512.yaml
index 5e00de80..7dc49a44 100644
--- a/configs/svd_rgb.yaml
+++ b/configs/svd_512.yaml
@@ -8,18 +8,18 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
   image_path: ??? # ./load/images/hamburger_rgba.png
   height: [128, 256, 512]
   width: [128, 256, 512]
-  resolution_milestones: [200, 300]
+  resolution_milestones: [50, 100]
   default_elevation_deg: 5.0
   default_azimuth_deg: 0.0
   default_camera_distance: 3.8
   default_fovy_deg: 20.0
   requires_depth: False
   requires_normal: False
-  random_camera: # threestudio/data/svd.py -> SVDDataModuleConfig
+  random_camera: # threestudio/data/svd_uncond.py -> SVDDataModuleConfig
     height: [64, 128, 256]
     width: [64, 128, 256]
     batch_size: [21, 21, 21] # same as video model # of frames
-    resolution_milestones: [200, 300]
+    resolution_milestones: [50, 100]
     eval_height: 512
     eval_width: 512
     eval_batch_size: 1
@@ -42,7 +42,7 @@ data: # threestudio/data/svd.py -> SVDImageDataModuleConfig
   n_val_views: 30
   n_test_views: 120

-system_type: "svd-system"
+system_type: "zero123-system"
 system:
   geometry_type: "implicit-volume"
   geometry:
@@ -101,16 +101,18 @@ system:
     stable_research_path: "/admin/home-vikram/ROBIN/stable-research"
     pretrained_model_name_or_path: "prediction_stable_jucifer_3D_OBJ_IMG"
     cond_aug: 0.00
-    vram_O: true
+    num_steps: 50
     height: ${max:${data.random_camera.height}}
+    vram_O: true
     cond_image_path: ${data.image_path}
     cond_elevation_deg: ${data.default_elevation_deg}
     cond_azimuth_deg: ${data.default_azimuth_deg}
     cond_camera_distance: ${data.default_camera_distance}
-    guidance_scale: 3.0
-    min_step_percent: [50, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
-    max_step_percent: [50, 0.98, 0.8, 200]
+    guidance_scale: 4.0
+    min_step_percent: [30, 0.7, 0.3, 200] # (start_iter, start_val, end_val, end_iter)
+    max_step_percent: [30, 0.98, 0.8, 200]
     noises_per_item: 1
+    guidance_eval_freq: 0

   ambient_ratio_min: [0, 0.1, 0.5, 400, 0.9, 401]

@@ -122,7 +124,7 @@ system:
   loss:
     lambda_sds: 0.1
-    lambda_rgb: [100, 500., 1000., 400]
+    lambda_rgb: 1000.
     lambda_mask: 50.
     lambda_depth: 0. # 0.05
     lambda_depth_rel: 0. # [0, 0, 0.05, 100]
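Note: several entries above (min_step_percent, max_step_percent, and lambda_rgb before this change) are 4-element schedules read as (start_iter, start_val, end_val, end_iter) by threestudio's C() helper. A rough sketch of that convention, assuming linear interpolation between the two milestones (threestudio/utils/misc.py is the authoritative implementation):

```python
# Illustrative only: how a [start_iter, start_val, end_val, end_iter] schedule such as
# min_step_percent: [50, 0.7, 0.3, 200] is typically resolved at a given training step.
def scheduled_value(spec, global_step: int) -> float:
    if isinstance(spec, (int, float)):
        return float(spec)  # plain scalars (e.g. lambda_rgb: 1000.) stay constant
    start_iter, start_val, end_val, end_iter = spec
    if global_step <= start_iter:
        return float(start_val)
    if global_step >= end_iter:
        return float(end_val)
    frac = (global_step - start_iter) / (end_iter - start_iter)
    return float(start_val + (end_val - start_val) * frac)

print(scheduled_value([50, 0.7, 0.3, 200], 0))    # 0.7 before the ramp starts
print(scheduled_value([50, 0.7, 0.3, 200], 125))  # 0.5 halfway through the ramp
print(scheduled_value([50, 0.7, 0.3, 200], 500))  # 0.3 once annealing is done
```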
diff --git a/threestudio/models/guidance/svd_guidance.py b/threestudio/models/guidance/svd_guidance.py
index 41abbaf1..152774a4 100644
--- a/threestudio/models/guidance/svd_guidance.py
+++ b/threestudio/models/guidance/svd_guidance.py
@@ -5,16 +5,15 @@
 import cv2
 import imageio
 import numpy as np
+import threestudio
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as TT
-from tqdm import tqdm
-
-import threestudio
 from threestudio.utils.base import BaseObject
 from threestudio.utils.misc import C, get_CPU_mem, get_GPU_mem
 from threestudio.utils.typing import *
+from tqdm import tqdm


 def get_resizing_factor(
@@ -61,16 +60,17 @@ class Config(BaseObject.Config):
         stable_research_path: str = "/admin/home-vikram/ROBIN/stable-research"
         pretrained_model_name_or_path: str = "prediction_stable_jucifer_3D_OBJ"
         cond_aug: float = 0.00
-        vram_O: bool = True
+        num_steps: int = None  # 50
         height: int = 576
+        guidance_scale: float = None  # 4.0
+
+        vram_O: bool = True
         cond_image_path: str = "load/images/hamburger_rgba.png"
         cond_elevation_deg: float = 0.0
         cond_azimuth_deg: float = 0.0
         cond_camera_distance: float = 1.2

-        guidance_scale: float = 5.0
-
         grad_clip: Optional[
             Any
         ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
@@ -84,6 +84,11 @@ class Config(BaseObject.Config):
         """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
         max_items_eval: int = 4

+        guidance_eval_freq: int = 0
+        guidance_eval_dir: str = os.path.realpath(
+            os.path.join(os.path.dirname(__file__), "../../../")
+        )
+
     cfg: Config

     def configure(self) -> None:
@@ -97,6 +102,8 @@ def configure(self) -> None:
         self.model = JucifierDenoiser(
             self.cfg.pretrained_model_name_or_path,
             self.cfg.cond_aug,
+            self.cfg.num_steps,
+            self.cfg.guidance_scale,
             self.cfg.height,
         )
         for p in self.model.parameters():
@@ -106,12 +113,18 @@ def configure(self) -> None:

         self.prepare_embeddings(self.cfg.cond_image_path)

+        self.T = self.model.T
+        self.num_steps = self.model.num_steps
+        self.set_min_max_steps()  # set to default values
+
+        self.count = 0
+
         threestudio.info(f"Loaded SVD!")

     @torch.cuda.amp.autocast(enabled=False)
     def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
-        self.min_step = int(self.num_train_timesteps * min_step_percent)
-        self.max_step = int(self.num_train_timesteps * max_step_percent)
+        self.min_step = int(self.num_steps * min_step_percent)
+        self.max_step = int(self.num_steps * max_step_percent)

     @torch.cuda.amp.autocast(enabled=False)
     def prepare_embeddings(
@@ -131,35 +144,55 @@ def encode_images(

     def __call__(
         self,
         rgb: Float[Tensor, "B H W C"],
-        elevation: Float[Tensor, "B"],
-        azimuth: Float[Tensor, "B"],
-        camera_distances: Float[Tensor, "B"],
-        rgb_as_latents=False,
-        guidance_eval=False,
         **kwargs,
     ):
         rgb_BCHW = rgb.permute(0, 3, 1, 2)
         latents = self.encode_images(rgb_BCHW)
         latents = torch.repeat_interleave(latents, self.noises_per_item, 0)
-        elevation = torch.repeat_interleave(elevation, self.noises_per_item)
-        azimuth = torch.repeat_interleave(azimuth, self.noises_per_item)
-        camera_distances = torch.repeat_interleave(
-            camera_distances, self.noises_per_item
-        )
         batch_size = latents.shape[0]

-        # # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
-        # t = torch.randint(
-        #     self.min_step,
-        #     self.max_step + 1,
-        #     [batch_size],
-        #     dtype=torch.long,
-        #     device=self.device,
-        # )
+        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
+        t = torch.randint(
+            self.min_step,
+            self.max_step + 1,
+            [batch_size // self.T],
+            dtype=torch.long,
+            device=self.device,
+        )
+
+        guidance_eval = (
+            self.cfg.guidance_eval_freq > 0
+            and self.count % self.cfg.guidance_eval_freq == 0
+        )

         with torch.no_grad():
-            rgb_pred = self.model(latents)
+            rgb_pred = self.model(latents, t, guidance_eval=guidance_eval)
+
+        if guidance_eval:
+            rgb_pred, rgb_i, rgb_d, rgb_eval = rgb_pred
+            imageio.mimsave(
+                os.path.join(
+                    self.cfg.guidance_eval_dir,
+                    f"guidance_eval_input_{self.count:05d}.mp4",
+                ),
+                rgb_i,
+            )
+            imageio.mimsave(
+                os.path.join(
+                    self.cfg.guidance_eval_dir,
+                    f"guidance_eval_denoise_{self.count:05d}.mp4",
+                ),
+                rgb_d,
+            )
+            imageio.mimsave(
+                os.path.join(
+                    self.cfg.guidance_eval_dir,
+                    f"guidance_eval_final_{self.count:05d}.mp4",
+                ),
+                rgb_eval,
+            )
+        self.count += 1

         # TODO CFG
         # TODO min_step, max_step
@@ -197,77 +230,6 @@ def __call__(

         return guidance_out

-    @torch.cuda.amp.autocast(enabled=False)
-    @torch.no_grad()
-    def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred):
-        # use only 50 timesteps, and find nearest of those to t
-        self.scheduler.set_timesteps(50)
-        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
-        bs = (
-            min(self.cfg.max_items_eval, latents_noisy.shape[0])
-            if self.cfg.max_items_eval > 0
-            else latents_noisy.shape[0]
-        )  # batch size
-        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
-            :bs
-        ].unsqueeze(
-            -1
-        )  # sized [bs,50] > [bs,1]
-        idxs = torch.min(large_enough_idxs, dim=1)[1]
-        t = self.scheduler.timesteps_gpu[idxs]
-
-        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
-        imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)
-
-        # get prev latent
-        latents_1step = []
-        pred_1orig = []
-        for b in range(bs):
-            step_output = self.scheduler.step(
-                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
-            )
-            latents_1step.append(step_output["prev_sample"])
-            pred_1orig.append(step_output["pred_original_sample"])
-        latents_1step = torch.cat(latents_1step)
-        pred_1orig = torch.cat(pred_1orig)
-        imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
-        imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)
-
-        latents_final = []
-        for b, i in enumerate(idxs):
-            latents = latents_1step[b : b + 1]
-            c = {
-                "c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]],
-                "c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]],
-            }
-            for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
-                # pred noise
-                x_in = torch.cat([latents] * 2)
-                t_in = torch.cat([t.reshape(1)] * 2).to(self.device)
-                noise_pred = self.model(x_in, t_in, c)
-                # perform guidance
-                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
-                    noise_pred_cond - noise_pred_uncond
-                )
-                # get prev latent
-                latents = self.scheduler.step(noise_pred, t, latents, eta=1)[
-                    "prev_sample"
-                ]
-            latents_final.append(latents)
-
-        latents_final = torch.cat(latents_final)
-        imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)
-
-        return {
-            "bs": bs,
-            "noise_levels": fracs,
-            "imgs_noisy": imgs_noisy,
-            "imgs_1step": imgs_1step,
-            "imgs_1orig": imgs_1orig,
-            "imgs_final": imgs_final,
-        }
-
     def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
         # clip grad for stable training as demonstrated in
         # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
@@ -275,10 +237,10 @@ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = Fals
         if self.cfg.grad_clip is not None:
             self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

-        # self.set_min_max_steps(
-        #     min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
-        #     max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
-        # )
+        self.set_min_max_steps(
+            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
+            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
+        )

         if self.cfg.noises_per_item is not None:
             self.noises_per_item = np.floor(
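Note: after this change the min/max percentages index into the denoiser's sampler steps (cfg.num_steps) rather than the full set of training timesteps, and one timestep is drawn per video clip of T frames. A minimal sketch of that interaction, assuming num_steps=50 and T=21 as in the configs above (the standalone names below are illustrative, not the module's API):

```python
import torch

num_steps = 50   # sampler steps exposed by JucifierDenoiser (cfg.num_steps)
T = 21           # frames per clip (model.T); matches batch_size: [21, 21, 21]
batch_size = 42  # latent batch after repeat_interleave, a multiple of T

def set_min_max_steps(min_step_percent=0.02, max_step_percent=0.98):
    # Percentages now scale num_steps (50), not num_train_timesteps (1000).
    return int(num_steps * min_step_percent), int(num_steps * max_step_percent)

min_step, max_step = set_min_max_steps(0.7, 0.98)  # -> (35, 49) early in training
# One timestep per clip (batch_size // T); each clip's T frames share it.
t = torch.randint(min_step, max_step + 1, [batch_size // T])
print(min_step, max_step, tuple(t.shape))  # 35 49 (2,)
```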
diff --git a/threestudio/systems/__init__.py b/threestudio/systems/__init__.py
index 197095ee..edbe7bf2 100644
--- a/threestudio/systems/__init__.py
+++ b/threestudio/systems/__init__.py
@@ -9,7 +9,6 @@
     magic123,
     prolificdreamer,
     sjc,
-    svd,
     textmesh,
     zero123,
     zero123_simple,
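Note: with threestudio/systems/svd.py removed below, guidance evaluation is no longer driven by the system's cfg.freq.guidance_eval; the guidance class gates it itself through the new guidance_eval_freq counter. A rough stand-in for that gating, under the assumption (as read from the diff) that the counter increments on every call and a frequency of 0 disables evaluation:

```python
class EvalGate:
    """Illustrative stand-in for the counter logic added to the SVD guidance __call__."""

    def __init__(self, guidance_eval_freq: int = 0):
        self.freq = guidance_eval_freq
        self.count = 0

    def step(self) -> bool:
        # guidance_eval_freq: 0 (the new config default) never evaluates.
        do_eval = self.freq > 0 and self.count % self.freq == 0
        self.count += 1
        return do_eval

gate = EvalGate(guidance_eval_freq=5)
print([gate.step() for _ in range(7)])  # [True, False, False, False, False, True, False]
```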
batch["rays_divisor"] = rays_divisor - batch["offset_x"] = offset_x_tensor - batch["offset_y"] = offset_y_tensor - - batch["bg_color"] = None - batch["ambient_ratio"] = ambient_ratio - - out = self(batch) - loss_prefix = f"loss_{guidance}_" - - loss_terms = {} - - def set_loss(name, value): - loss_terms[f"{loss_prefix}{name}"] = value - - guidance_eval = ( - guidance == "svd" - and getattr(self.cfg.freq, "guidance_eval", 0) > 0 - and self.true_global_step % getattr(self.cfg.freq, "guidance_eval", 0) == 0 - ) - - if guidance == "ref": - gt_mask = torch.cat( - [ - batch["mask"][:, xx::rays_divisor, yy::rays_divisor, :] - for (xx, yy) in zip(offset_x_tensor, offset_y_tensor) - ] - ) - gt_rgb = torch.cat( - [ - batch["rgb"][:, xx::rays_divisor, yy::rays_divisor, :] - for (xx, yy) in zip(offset_x_tensor, offset_y_tensor) - ] - ) - - # color loss - gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * ( - 1 - gt_mask.float() - ) - set_loss("rgb", F.mse_loss(gt_rgb, out["comp_rgb"])) - - # mask loss - set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"])) - - # depth loss - if self.C(self.cfg.loss.lambda_depth) > 0: - valid_gt_depth = torch.cat( - [ - batch["ref_depth"][:, xx::rays_divisor, yy::rays_divisor, :] - for (xx, yy) in zip(offset_x_tensor, offset_y_tensor) - ] - )[gt_mask.squeeze(-1)].unsqueeze(1) - valid_pred_depth = out["depth"][gt_mask].unsqueeze(1) - with torch.no_grad(): - A = torch.cat( - [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1 - ) # [B, 2] - X = torch.linalg.lstsq(A, valid_pred_depth).solution # [2, 1] - valid_gt_depth = A @ X # [B, 1] - set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth)) - - # relative depth loss - if self.C(self.cfg.loss.lambda_depth_rel) > 0: - valid_gt_depth = torch.cat( - [ - batch["ref_depth"][:, xx::rays_divisor, yy::rays_divisor, :] - for (xx, yy) in zip(offset_x_tensor, offset_y_tensor) - ] - )[gt_mask.squeeze(-1)] - valid_pred_depth = out["depth"][gt_mask] # [B,] - set_loss( - "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth) - ) - - # normal loss - if self.C(self.cfg.loss.lambda_normal) > 0: - valid_gt_normal = ( - 1 - - 2 - * torch.cat( - [ - batch["ref_normal"][ - :, xx::rays_divisor, yy::rays_divisor, : - ] - for (xx, yy) in zip(offset_x_tensor, offset_y_tensor) - ] - )[gt_mask.squeeze(-1)] - ) # [B, 3] - valid_pred_normal = ( - 2 * out["comp_normal"][gt_mask.squeeze(-1)] - 1 - ) # [B, 3] - set_loss( - "normal", - 1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(), - ) - elif guidance == "svd": - # sgm svd - guidance_out = self.guidance( - out["comp_rgb"], - **batch, - rgb_as_latents=False, - guidance_eval=guidance_eval, - ) - # claforte: TODO: rename the loss_terms keys - set_loss("sds", guidance_out["loss_sds"]) - - self.log("train/mem_cpu", guidance_out["cpu_mem"], prog_bar=True) - self.log("train/mem_gpu", guidance_out["gpu_mem"], prog_bar=True) - - if self.C(self.cfg.loss.lambda_normal_smooth) > 0: - if "comp_normal" not in out: - raise ValueError( - "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output." - ) - normal = out["comp_normal"] - set_loss( - "normal_smooth", - (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean() - + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(), - ) - - if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0: - if "normal" not in out: - raise ValueError( - "Normal is required for normal smooth loss, no normal is found in the output." 
-                )
-            if "normal_perturb" not in out:
-                raise ValueError(
-                    "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output."
-                )
-            normals = out["normal"]
-            normals_perturb = out["normal_perturb"]
-            set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean())
-
-        if not self.cfg.refinement:
-            if self.C(self.cfg.loss.lambda_orient) > 0:
-                if "normal" not in out:
-                    raise ValueError(
-                        "Normal is required for orientation loss, no normal is found in the output."
-                    )
-                set_loss(
-                    "orient",
-                    (
-                        out["weights"].detach()
-                        * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
-                    ).sum()
-                    / (out["opacity"] > 0).sum(),
-                )
-
-            if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0:
-                set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean())
-
-            if self.C(self.cfg.loss.lambda_opaque) > 0:
-                opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
-                set_loss(
-                    "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped)
-                )
-        else:
-            if self.C(self.cfg.loss.lambda_normal_consistency) > 0:
-                set_loss("normal_consistency", out["mesh"].normal_consistency())
-            if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0:
-                set_loss("laplacian_smoothness", out["mesh"].laplacian())
-
-        loss = 0.0
-        for name, value in loss_terms.items():
-            self.log(f"train/{name}", value)
-            if name.startswith(loss_prefix):
-                loss_weighted = value * self.C(
-                    self.cfg.loss[name.replace(loss_prefix, "lambda_")]
-                )
-                self.log(f"train/{name}_w", loss_weighted)
-                loss += loss_weighted
-
-        for name, value in self.cfg.loss.items():
-            self.log(f"train_params/{name}", self.C(value))
-
-        self.log(f"train/loss_{guidance}", loss)
-
-        if guidance_eval:
-            self.guidance_evaluation_save(
-                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
-                guidance_out["eval"],
-            )
-
-        return {"loss": loss}
-
-    def training_step(self, batch, batch_idx):
-        total_loss = 0.0
-
-        # SGM SVD
-        out = self.training_substep(batch, batch_idx, guidance="svd")
-        total_loss += out["loss"]
-
-        # REF
-        out = self.training_substep(batch, batch_idx, guidance="ref")
-        total_loss += out["loss"]
-
-        self.log("train/loss", total_loss, prog_bar=True)
-
-        return {"loss": total_loss}
-
-    def validation_step(self, batch, batch_idx):
-        out = self(batch)
-        self.save_image_grid(
-            f"it{self.true_global_step}-val/{batch['index'][0]}.png",
-            # (
-            #     [
-            #         {
-            #             "type": "rgb",
-            #             "img": batch["rgb"][0],
-            #             "kwargs": {"data_format": "HWC"},
-            #         }
-            #     ]
-            #     if "rgb" in batch
-            #     else []
-            # )
-            # +
-            [
-                {
-                    "type": "rgb",
-                    "img": out["comp_rgb"][0],
-                    "kwargs": {"data_format": "HWC"},
-                },
-            ]
-            + (
-                [
-                    {
-                        "type": "rgb",
-                        "img": out["comp_normal"][0],
-                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
-                    }
-                ]
-                if "comp_normal" in out
-                else []
-            )
-            + (
-                [
-                    {
-                        "type": "grayscale",
-                        "img": out["depth"][0],
-                        "kwargs": {},
-                    }
-                ]
-                if "depth" in out
-                else []
-            )
-            + [
-                {
-                    "type": "grayscale",
-                    "img": out["opacity"][0, :, :, 0],
-                    "kwargs": {"cmap": None, "data_range": (0, 1)},
-                },
-            ],
-            name=None,
-            step=self.true_global_step,
-        )
-
-    def on_validation_epoch_end(self):
-        filestem = f"it{self.true_global_step}-val"
-        self.save_img_sequence(
-            filestem,
-            filestem,
-            "(\d+)\.png",
-            save_format="mp4",
-            fps=30,
-            name="validation_epoch_end",
-            step=self.true_global_step,
-        )
-        shutil.rmtree(
-            os.path.join(self.get_save_dir(), f"it{self.true_global_step}-val")
-        )
-
-    def test_step(self, batch, batch_idx):
-        out = self(batch)
-        self.save_image_grid(
f"it{self.true_global_step}-test/{batch['index'][0]}.png", - [ - { - "type": "rgb", - "img": out["comp_rgb"][0], - "kwargs": {"data_format": "HWC"}, - }, - ] - + ( - [ - { - "type": "rgb", - "img": out["comp_normal"][0], - "kwargs": {"data_format": "HWC", "data_range": (0, 1)}, - } - ] - if "comp_normal" in out - else [] - ) - + ( - [ - { - "type": "grayscale", - "img": out["depth"][0], - "kwargs": {}, - } - ] - if "depth" in out - else [] - ) - + [ - { - "type": "grayscale", - "img": out["opacity"][0, :, :, 0], - "kwargs": {"cmap": None, "data_range": (0, 1)}, - }, - ], - name="test_step", - step=self.true_global_step, - ) - - def on_test_epoch_end(self): - self.save_img_sequence( - f"it{self.true_global_step}-test", - f"it{self.true_global_step}-test", - "(\d+)\.png", - save_format="mp4", - fps=30, - name="test", - step=self.true_global_step, - ) - shutil.rmtree( - os.path.join(self.get_save_dir(), f"it{self.true_global_step}-test") - )