[UPDATE] support scale-invariant mode; pass invariance flag through model_index.json
markkua committed May 17, 2024
1 parent dfc2e11 commit 5126211
Showing 4 changed files with 149 additions and 9 deletions.
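
As a rough end-to-end sketch of what this commit enables (import path and checkpoint name are assumed here, not taken from the diff), the invariance flags travel from model_index.json into the loaded pipeline:

from marigold import MarigoldPipeline  # import path assumed

# from_pretrained() reads model_index.json; scalar entries such as
# "scale_invariant" are forwarded to MarigoldPipeline.__init__, which
# re-registers them to the pipeline config.
pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0")  # checkpoint name used only as an example
print(pipe.prediction_type, pipe.scale_invariant, pipe.shift_invariant)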
7 changes: 5 additions & 2 deletions infer.py
@@ -1,4 +1,4 @@
# Last modified: 2024-04-15
# Last modified: 2024-05-17
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -213,7 +213,10 @@ def check_directory(directory):
logging.debug("run without xformers")

pipe = pipe.to(device)

logging.info(
f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }"
)

# -------------------- Inference and saving --------------------
with torch.no_grad():
for batch in tqdm(
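A side note on the logging call added above: the `=` inside the f-string is Python's self-documenting expression specifier (available since Python 3.8), which prints the expression together with its value:

flag = True
print(f"{flag = }")  # prints: flag = True
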
63 changes: 56 additions & 7 deletions marigold/marigold_pipeline.py
@@ -33,13 +33,13 @@
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths
from .util.ensemble import ensemble_depths, ensemble_depths_up2scale
from .util.image_util import (
chw2hwc,
colorize_depth_maps,
@@ -97,16 +97,44 @@ def __init__(
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
prediction_type: str = None,
scale_invariant: bool = None,
shift_invariant: bool = None,
):
super().__init__()

if prediction_type is None:
logging.warn(
"`prediction_type` is required but not given, filled with 'depth'"
)
prediction_type = "depth"
if scale_invariant is None:
logging.warn(
"`scale_invariant` is required but not given, filled with `True`"
)
scale_invariant = True
if shift_invariant is None:
logging.warn(
"`shift_invariant` is required but not given, filled with `True`"
)
shift_invariant = True

self.prediction_type = prediction_type
self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant

self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.register_to_config(
prediction_type=prediction_type,
scale_invariant=scale_invariant,
shift_invariant=shift_invariant,
)

self.empty_text_embed = None
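
Because the flags are passed to register_to_config, they become part of the pipeline config and are serialized alongside the component entries when the pipeline is saved. A minimal sketch (directory name arbitrary; behavior as expected from diffusers' config handling):

pipe.save_pretrained("marigold-ckpt")
# marigold-ckpt/model_index.json is then expected to contain, roughly:
#   "prediction_type": "depth",
#   "scale_invariant": true,
#   "shift_invariant": true
# Checkpoints that predate these keys fall back to the defaults above and
# trigger the logging.warn messages in __init__.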

@@ -152,6 +180,10 @@ def __call__(
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
Colormap used to colorize the depth map.
scale_invariant (`bool`, *optional*, defaults to `True`):
Flag for scale-invariant prediction; if True, the scale is adjusted from the raw prediction.
shift_invariant (`bool`, *optional*, defaults to `True`):
Flag for shift-invariant prediction; if True, the shift is adjusted from the raw prediction; if False, the near plane is fixed at 0m.
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
Arguments for detailed ensembling settings.
Returns:
@@ -236,17 +268,34 @@ def __call__(

# ----------------- Test-time ensembling -----------------
if ensemble_size > 1:
depth_pred, pred_uncert = ensemble_depths(
depth_preds, **(ensemble_kwargs or {})
)
if self.scale_invariant and self.shift_invariant:
depth_pred, pred_uncert = ensemble_depths(
depth_preds, **(ensemble_kwargs or {})
)
elif self.scale_invariant and (not self.shift_invariant):
depth_pred, pred_uncert = ensemble_depths_up2scale(
depth_preds, **(ensemble_kwargs or {})
)
else:
raise NotImplementedError("Metric depth is not supported.")
else:
depth_pred = depth_preds
pred_uncert = None

# ----------------- Post processing -----------------
# Scale prediction to [0, 1]
min_d = torch.min(depth_pred)
max_d = torch.max(depth_pred)
if self.shift_invariant:
min_d = torch.min(depth_pred)
else:
min_d = 0

if self.scale_invariant:
max_d = torch.max(depth_pred)
else:
raise NotImplementedError(
"Metric depth is not supported."
)

depth_pred = (depth_pred - min_d) / (max_d - min_d)

# Resize back to original resolution
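The post-processing branch above can be illustrated with a small standalone sketch (tensor values are made up; only the scale-invariant cases are covered, mirroring the NotImplementedError for metric depth):

import torch

depth_pred = torch.rand(480, 640) * 3.0 + 0.5    # fake ensembled prediction

scale_invariant, shift_invariant = True, False    # e.g. an up-to-scale checkpoint
min_d = torch.min(depth_pred) if shift_invariant else 0.0
max_d = torch.max(depth_pred)                     # scale_invariant=True case
depth_norm = (depth_pred - min_d) / (max_d - min_d)
# With shift_invariant=False the near plane stays at 0, so depth ratios are
# preserved; with shift_invariant=True the global offset is also removed
# before rescaling to [0, 1].
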
85 changes: 85 additions & 0 deletions marigold/util/ensemble.py
@@ -130,3 +130,88 @@ def closure(x):
uncertainty /= _max - _min

return aligned_images, uncertainty


def ensemble_depths_up2scale(
input_images: torch.Tensor,
regularizer_strength: float = 0.02,
max_iter: int = 2,
tol: float = 1e-3,
reduction: str = "median",
max_res: int = None,
):
"""
Ensemble multiple scale-invariant depth maps (near plane fixed at 0).
"""
device = input_images.device
dtype = input_images.dtype
np_dtype = np.float32

original_input = input_images.clone()
n_img = input_images.shape[0]
ori_shape = input_images.shape

if max_res is not None:
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
if scale_factor < 1:
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
input_images = downscaler(input_images)

# init guess
_min = 0
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
s_init = 1.0 / (_max - _min).reshape((-1))
x = s_init

input_images = input_images.to(device)

# objective function
def closure(x):
s = torch.from_numpy(x).to(dtype=dtype).to(device)

transformed_arrays = input_images * s.view((-1, 1, 1))
dists = inter_distances(transformed_arrays)
sqrt_dist = torch.sqrt(torch.mean(dists**2))

if "mean" == reduction:
pred = torch.mean(transformed_arrays, dim=0)
elif "median" == reduction:
pred = torch.median(transformed_arrays, dim=0).values
else:
raise ValueError

near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

err = sqrt_dist + (near_err + far_err) * regularizer_strength
err = err.detach().cpu().numpy().astype(np_dtype)
return err

res = minimize(
closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
)
s = res.x

# Prediction
s = torch.from_numpy(s).to(dtype=dtype).to(device)
transformed_arrays = original_input * s.view(-1, 1, 1)
if "mean" == reduction:
aligned_images = torch.mean(transformed_arrays, dim=0)
std = torch.std(transformed_arrays, dim=0)
uncertainty = std
elif "median" == reduction:
aligned_images = torch.median(transformed_arrays, dim=0).values
# MAD (median absolute deviation) as uncertainty indicator
abs_dev = torch.abs(transformed_arrays - aligned_images)
mad = torch.median(abs_dev, dim=0).values
uncertainty = mad
else:
raise ValueError(f"Unknown reduction method: {reduction}")

# Scale and shift to [0, 1]
_min = 0
_max = torch.max(aligned_images)
aligned_images = (aligned_images - _min) / (_max - _min)
uncertainty /= _max - _min

return aligned_images, uncertainty
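
A minimal usage sketch of the new helper, assuming synthetic predictions that share a zero near plane but differ in scale (values are made up):

import torch
# from marigold.util.ensemble import ensemble_depths_up2scale

torch.manual_seed(0)
base = torch.rand(1, 48, 64)                           # shared relative depth
scales = torch.tensor([0.7, 1.0, 1.6]).view(-1, 1, 1)  # per-prediction scales
preds = base * scales + 0.01 * torch.randn(3, 48, 64)  # three noisy predictions

depth, uncert = ensemble_depths_up2scale(preds, reduction="median")
print(depth.shape, uncert.shape)  # torch.Size([48, 64]) torch.Size([48, 64])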
3 changes: 3 additions & 0 deletions run.py
@@ -220,6 +220,9 @@
pass # run without xformers

pipe = pipe.to(device)
logging.info(
f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }"
)

# -------------------- Inference and saving --------------------
with torch.no_grad():
