diff --git a/k_diffusion/external.py b/k_diffusion/external.py
index 71c5b94..63cc624 100644
--- a/k_diffusion/external.py
+++ b/k_diffusion/external.py
@@ -59,7 +59,21 @@ def sigma_to_t(self, sigma, quantize=None):
         dists = torch.abs(sigma - self.sigmas[:, None])
         if quantize:
             return torch.argmin(dists, dim=0).view(sigma.shape)
-        low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
+        topk_indices = torch.topk(dists, dim=0, k=2, largest=False).indices
+        topk_indices_device = topk_indices.device
+
+        # TODO: revert this once MPS supports aten::sort.values_stable.
+        # We transfer the topk indices to the CPU, sort them there, then transfer the result (sort_values) back to the GPU.
+        # Sorting on-CPU is fine here, because this is only a tiny k=2 tensor of indices.
+        # PYTORCH_ENABLE_MPS_FALLBACK=1 would do the same thing, but we want to be able to run without it,
+        # so that we find out any time a fallback is required and can review whether it's consequential.
+        must_sort_on_cpu = topk_indices_device.type == 'mps'
+        topk_indices = topk_indices.cpu() if must_sort_on_cpu else topk_indices
+
+        sort_values = torch.sort(topk_indices, dim=0).values
+        sort_values = sort_values.to(topk_indices_device) if must_sort_on_cpu else sort_values
+
+        low_idx, high_idx = sort_values
         low, high = self.sigmas[low_idx], self.sigmas[high_idx]
         w = (low - sigma) / (low - high)
         w = w.clamp(0, 1)
@@ -68,7 +82,8 @@ def sigma_to_t(self, sigma, quantize=None):
 
     def t_to_sigma(self, t):
         t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        t_floor = t.floor()
+        low_idx, high_idx, w = t_floor.long(), t.ceil().long(), t - t_floor if t.device.type == 'mps' else t.frac()
         return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
 
 
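Both external.py hunks work around missing or broken MPS kernels while keeping every other op on-device: `sigma_to_t` round-trips the tiny topk-indices tensor through the CPU because MPS lacks `aten::sort.values_stable`, and `t_to_sigma` replaces `t.frac()` with `t - t.floor()`, which agrees with it for the non-negative `t` used here. A minimal sketch of the same round-trip pattern as a standalone helper; the helper name is hypothetical, not part of this diff:

```python
import torch
from torch import Tensor

def sort_values_with_cpu_fallback(t: Tensor, dim: int = 0) -> Tensor:
    """Sorts `t` along `dim`, detouring through the CPU on MPS devices
    (which lack aten::sort.values_stable) and moving the result back."""
    if t.device.type == 'mps':
        return torch.sort(t.cpu(), dim=dim).values.to(t.device)
    return torch.sort(t, dim=dim).values
```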
(2022).""" - ramp = torch.linspace(0, 1, n) + ramp = torch.linspace(0, 1, n if concat_zero else n+1) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return append_zero(sigmas).to(device) + sigmas = sigmas.to(device) + return append_zero(sigmas) if concat_zero else sigmas def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): @@ -34,9 +48,13 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): return append_zero(sigmas) -def to_d(x, sigma, denoised): +def to_d(x, sigma, denoised, clone_please=False): """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / utils.append_dims(sigma, x.ndim) + coeff = utils.append_dims(sigma, x.ndim) + # for some reason, cloning coeff fixes a problem where values were returned as ±inf + # there's probably a better place to do the cloning than here, but this fixes sample_heun on MPS + coeff = coeff.detach().clone() if coeff.device.type == 'mps' and clone_please else coeff + return (x - denoised) / coeff def get_ancestral_step(sigma_from, sigma_to): @@ -48,7 +66,7 @@ def get_ancestral_step(sigma_from, sigma_to): @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -56,6 +74,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -78,7 +97,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - d = to_d(x, sigmas[i], denoised) + d = to_d(x, sigmas[i], denoised, clone_please=True) # Euler method dt = sigma_down - sigmas[i] x = x + d * dt @@ -87,7 +106,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -95,6 +114,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 8d700c2..9cb8b07 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -7,8 +7,9 @@
 import warnings
 
 import torch
-from torch import optim
+from torch import optim, Tensor
 from torchvision.transforms import functional as TF
+from typing import Union
 
 
 def from_pil_image(x):
@@ -249,3 +250,7 @@ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.floa
     min_value = math.log(min_value)
     max_value = math.log(max_value)
     return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
+
+def quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor:
+    """Rounds `candidate` to the nearest element in `quanta`."""
+    return quanta[torch.argmin((quanta - candidate).abs(), dim=0)]
diff --git a/requirements.txt b/requirements.txt
index 54c5a95..bcdb6e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ torchdiffeq
 torchvision
 tqdm
 wandb
-git+https://github.com/openai/CLIP
+git+https://github.com/openai/CLIP#egg=clip
+typing-extensions
\ No newline at end of file
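`utils.quantize` is the primitive behind `make_quantizer` above. An illustrative check of the rounding behaviour, with arbitrarily chosen values:

```python
import torch
from k_diffusion import utils

quanta = torch.tensor([0.1, 0.5, 1.0, 2.0])
print(utils.quantize(quanta, 0.7))                  # tensor(0.5000)
print(utils.quantize(quanta, torch.tensor([1.4])))  # tensor(1.)
```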
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..35dca52
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,24 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='k-diffusion',
+    version='0.0.1',
+    description='Karras et al. (2022) diffusion models for PyTorch',
+    packages=find_packages(),
+    install_requires=[
+        'accelerate',
+        'clean-fid',
+        'einops',
+        'jsonmerge',
+        'kornia',
+        'Pillow',
+        'resize-right',
+        'scikit-image',
+        'scipy',
+        'torch',
+        'torchdiffeq',
+        'torchvision',
+        'tqdm',
+        'wandb',
+    ],
+)
\ No newline at end of file
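The new setup.py makes the repo pip-installable (for example with `pip install -e .`), so downstream projects can depend on it as a package instead of vendoring the source. One observation: `typing-extensions` is added to requirements.txt but not to `install_requires`, even though the `TypeAlias` fallback import in sampling.py needs it on Python < 3.10. A quick post-install sanity check, assuming the editable install succeeded:

```python
# Assumes `pip install -e .` has been run from the repo root.
from k_diffusion import sampling

schedule = sampling.get_sigmas_karras(10, 0.03, 14.6)
print(schedule.shape)  # torch.Size([11]) -- 10 sigmas plus the appended zero
```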