chore: backport TypeAlias for py3.9 compat #1

Closed · wants to merge 15 commits
19 changes: 17 additions & 2 deletions k_diffusion/external.py
@@ -59,7 +59,21 @@ def sigma_to_t(self, sigma, quantize=None):
dists = torch.abs(sigma - self.sigmas[:, None])
if quantize:
return torch.argmin(dists, dim=0).view(sigma.shape)
low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
topk_indices = torch.topk(dists, dim=0, k=2, largest=False).indices
topk_indices_device = topk_indices.device

# TODO: revert this once MPS supports aten::sort.values_stable.
# we're transferring the topk indices to CPU, sorting them, then transferring the result (sort_values) back to GPU.
# it's fine to sort on-CPU, because it's a wee little 2x2 matrix.
# PYTORCH_ENABLE_MPS_FALLBACK=1 would do the same thing, but I want us to be able to run without that,
# so that we find out any time a fallback is required, and can review whether it's consequential.
must_sort_on_cpu = topk_indices_device.type == 'mps'
topk_indices = topk_indices.cpu() if must_sort_on_cpu else topk_indices

sort_values = torch.sort(topk_indices, dim=0).values
sort_values = sort_values.to(topk_indices_device) if must_sort_on_cpu else sort_values

low_idx, high_idx = sort_values
low, high = self.sigmas[low_idx], self.sigmas[high_idx]
w = (low - sigma) / (low - high)
w = w.clamp(0, 1)
@@ -68,7 +82,8 @@ def sigma_to_t(self, sigma, quantize=None):

def t_to_sigma(self, t):
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
t_floor = t.floor()
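# MPS workaround: compute the fractional part as t - floor(t). (Assumption: t.frac() was unsupported or unreliable on MPS at the time this was written.)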
low_idx, high_idx, w = t_floor.long(), t.ceil().long(), (t - t_floor) if t.device.type == 'mps' else t.frac()
return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]


43 changes: 32 additions & 11 deletions k_diffusion/sampling.py
@@ -1,24 +1,38 @@
import math
import numpy as np

from functools import partial
from scipy import integrate
import torch
from torch import Tensor
from torchdiffeq import odeint
from tqdm.auto import trange, tqdm
from typing import Optional, Callable
try:
from typing import TypeAlias
except ImportError:
from typing_extensions import TypeAlias

from . import utils

TensorOperator: TypeAlias = Callable[[Tensor], Tensor]

def make_quantizer(quanta: Tensor) -> TensorOperator:
"""Returns an monotype operator which accepts a single-element 1-dimensional Tensor, and rounds its element to the nearest element in `quanta`"""
return partial(utils.quantize, quanta)
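# Usage sketch (illustrative only; `model_wrap` stands for a DiscreteSchedule-style wrapper with a .sigmas buffer, not part of this diff):
#   quantizer = make_quantizer(model_wrap.sigmas)
#   x = sample_heun(model_wrap, x, sigmas, decorate_sigma_hat=quantizer)
# i.e. each sigma_hat is snapped back onto the model's discrete sigma grid before denoising.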

def append_zero(x):
return torch.cat([x, x.new_zeros([1])])


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', concat_zero=True):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n)
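# note: with concat_zero=False we generate n+1 sigmas directly, so the returned schedule has the same length as the default path (n sigmas + an appended zero).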
ramp = torch.linspace(0, 1, n if concat_zero else n+1)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
sigmas = sigmas.to(device)
return append_zero(sigmas) if concat_zero else sigmas


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
@@ -34,9 +48,13 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas)


def to_d(x, sigma, denoised):
def to_d(x, sigma, denoised, clone_please=False):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
coeff = utils.append_dims(sigma, x.ndim)
# for some reason, cloning coeff fixes a problem where values were returned as ±inf
# there's probably a better place to do the cloning than here, but this fixes sample_heun on MPS
coeff = coeff.detach().clone() if coeff.device.type == 'mps' and clone_please else coeff
return (x - denoised) / coeff


def get_ancestral_step(sigma_from, sigma_to):
@@ -48,14 +66,15 @@ def get_ancestral_step(sigma_from, sigma_to):


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
@@ -78,7 +97,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
d = to_d(x, sigmas[i], denoised, clone_please=True)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
@@ -87,14 +106,15 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
@@ -109,21 +129,22 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2, clone_please=True)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
@@ -184,7 +205,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
d = to_d(x, sigmas[i], denoised, clone_please=True)
ds.append(d)
if len(ds) > order:
ds.pop(0)
7 changes: 6 additions & 1 deletion k_diffusion/utils.py
@@ -7,8 +7,9 @@
import warnings

import torch
from torch import optim
from torch import optim, Tensor
from torchvision.transforms import functional as TF
from typing import Union


def from_pil_image(x):
@@ -249,3 +250,7 @@ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.floa
min_value = math.log(min_value)
max_value = math.log(max_value)
return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()

def quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor:
"""Rounds `candidate` to the nearest element in `quanta`"""
return quanta[torch.argmin((quanta-candidate).abs(), dim=0)]
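# e.g. (sketch): quantize(torch.tensor([0.1, 0.5, 1.0]), 0.4) -> tensor(0.5000)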
3 changes: 2 additions & 1 deletion requirements.txt
@@ -12,4 +12,5 @@ torchdiffeq
torchvision
tqdm
wandb
git+https://github.com/openai/CLIP
git+https://github.com/openai/CLIP#egg=clip
typing-extensions
24 changes: 24 additions & 0 deletions setup.py
@@ -0,0 +1,24 @@
from setuptools import setup, find_packages

setup(
name='k-diffusion',
version='0.0.1',
description='Karras et al. (2022) diffusion models for PyTorch',
packages=find_packages(),
install_requires=[
'accelerate',
'clean-fid',
'einops',
'jsonmerge',
'kornia',
'Pillow',
'resize-right',
'scikit-image',
'scipy',
'torch',
'torchdiffeq',
'torchvision',
'tqdm',
'typing-extensions',  # TypeAlias backport for Python < 3.10 (see k_diffusion/sampling.py)
'wandb',
],
)