forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* k-diffusion-euler * make style make quality * make fix-copies * fix tests for euler a * Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py Co-authored-by: Anton Lozhkov <[email protected]> * Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py Co-authored-by: Anton Lozhkov <[email protected]> * Update src/diffusers/schedulers/scheduling_euler_discrete.py Co-authored-by: Anton Lozhkov <[email protected]> * Update src/diffusers/schedulers/scheduling_euler_discrete.py Co-authored-by: Anton Lozhkov <[email protected]> * remove unused arg and method * update doc * quality * make flake happy * use logger instead of warn * raise error instead of deprication * don't require scipy * pass generator in step * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Update tests/test_scheduler.py Co-authored-by: Patrick von Platen <[email protected]> * remove unused generator * pass generator as extra_step_kwargs * update tests * pass generator as kwarg * pass generator as kwarg * quality * fix test for lms * fix tests Co-authored-by: patil-suraj <[email protected]> Co-authored-by: Anton Lozhkov <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
- Loading branch information
1 parent
bf7b0bc
commit a1ea8c0
Showing
11 changed files
with
858 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
261 changes: 261 additions & 0 deletions
261
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from dataclasses import dataclass | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ..configuration_utils import ConfigMixin, register_to_config | ||
from ..utils import BaseOutput, deprecate, logging | ||
from .scheduling_utils import SchedulerMixin | ||
|
||
|
||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
|
||
|
||
@dataclass | ||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete | ||
class EulerAncestralDiscreteSchedulerOutput(BaseOutput): | ||
""" | ||
Output class for the scheduler's step function output. | ||
Args: | ||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): | ||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the | ||
denoising loop. | ||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): | ||
The predicted denoised sample (x_{0}) based on the model output from the current timestep. | ||
`pred_original_sample` can be used to preview progress or for guidance. | ||
""" | ||
|
||
prev_sample: torch.FloatTensor | ||
pred_original_sample: Optional[torch.FloatTensor] = None | ||
|
||
|
||
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | ||
""" | ||
Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: | ||
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 | ||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
[`~ConfigMixin.from_config`] functions. | ||
Args: | ||
num_train_timesteps (`int`): number of diffusion steps used to train the model. | ||
beta_start (`float`): the starting `beta` value of inference. | ||
beta_end (`float`): the final `beta` value. | ||
beta_schedule (`str`): | ||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from | ||
`linear` or `scaled_linear`. | ||
trained_betas (`np.ndarray`, optional): | ||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. | ||
""" | ||
|
||
@register_to_config | ||
def __init__( | ||
self, | ||
num_train_timesteps: int = 1000, | ||
beta_start: float = 0.0001, | ||
beta_end: float = 0.02, | ||
beta_schedule: str = "linear", | ||
trained_betas: Optional[np.ndarray] = None, | ||
): | ||
if trained_betas is not None: | ||
self.betas = torch.from_numpy(trained_betas) | ||
elif beta_schedule == "linear": | ||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
elif beta_schedule == "scaled_linear": | ||
# this schedule is very specific to the latent diffusion model. | ||
self.betas = ( | ||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
) | ||
else: | ||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
|
||
self.alphas = 1.0 - self.betas | ||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
|
||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) | ||
self.sigmas = torch.from_numpy(sigmas) | ||
|
||
# standard deviation of the initial noise distribution | ||
self.init_noise_sigma = self.sigmas.max() | ||
|
||
# setable values | ||
self.num_inference_steps = None | ||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() | ||
self.timesteps = torch.from_numpy(timesteps) | ||
self.is_scale_input_called = False | ||
|
||
def scale_model_input( | ||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] | ||
) -> torch.FloatTensor: | ||
""" | ||
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. | ||
Args: | ||
sample (`torch.FloatTensor`): input sample | ||
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain | ||
Returns: | ||
`torch.FloatTensor`: scaled input sample | ||
""" | ||
if isinstance(timestep, torch.Tensor): | ||
timestep = timestep.to(self.timesteps.device) | ||
step_index = (self.timesteps == timestep).nonzero().item() | ||
sigma = self.sigmas[step_index] | ||
sample = sample / ((sigma**2 + 1) ** 0.5) | ||
self.is_scale_input_called = True | ||
return sample | ||
|
||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): | ||
""" | ||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
Args: | ||
num_inference_steps (`int`): | ||
the number of diffusion steps used when generating samples with a pre-trained model. | ||
device (`str` or `torch.device`, optional): | ||
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. | ||
""" | ||
self.num_inference_steps = num_inference_steps | ||
|
||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() | ||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) | ||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) | ||
self.sigmas = torch.from_numpy(sigmas).to(device=device) | ||
self.timesteps = torch.from_numpy(timesteps).to(device=device) | ||
|
||
def step( | ||
self, | ||
model_output: torch.FloatTensor, | ||
timestep: Union[float, torch.FloatTensor], | ||
sample: torch.FloatTensor, | ||
generator: Optional[torch.Generator] = None, | ||
return_dict: bool = True, | ||
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: | ||
""" | ||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
process from the learned model outputs (most often the predicted noise). | ||
Args: | ||
model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
timestep (`float`): current timestep in the diffusion chain. | ||
sample (`torch.FloatTensor`): | ||
current instance of sample being created by diffusion process. | ||
generator (`torch.Generator`, optional): Random number generator. | ||
return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class | ||
Returns: | ||
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: | ||
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise | ||
a `tuple`. When returning a tuple, the first element is the sample tensor. | ||
""" | ||
|
||
if ( | ||
isinstance(timestep, int) | ||
or isinstance(timestep, torch.IntTensor) | ||
or isinstance(timestep, torch.LongTensor) | ||
): | ||
raise ValueError( | ||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" | ||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" | ||
" one of the `scheduler.timesteps` as a timestep.", | ||
) | ||
|
||
if not self.is_scale_input_called: | ||
logger.warn( | ||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. " | ||
"See `StableDiffusionPipeline` for a usage example." | ||
) | ||
|
||
if isinstance(timestep, torch.Tensor): | ||
timestep = timestep.to(self.timesteps.device) | ||
|
||
step_index = (self.timesteps == timestep).nonzero().item() | ||
sigma = self.sigmas[step_index] | ||
|
||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise | ||
pred_original_sample = sample - sigma * model_output | ||
sigma_from = self.sigmas[step_index] | ||
sigma_to = self.sigmas[step_index + 1] | ||
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 | ||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 | ||
|
||
# 2. Convert to an ODE derivative | ||
derivative = (sample - pred_original_sample) / sigma | ||
|
||
dt = sigma_down - sigma | ||
|
||
prev_sample = sample + derivative * dt | ||
|
||
device = model_output.device if torch.is_tensor(model_output) else "cpu" | ||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) | ||
prev_sample = prev_sample + noise * sigma_up | ||
|
||
if not return_dict: | ||
return (prev_sample,) | ||
|
||
return EulerAncestralDiscreteSchedulerOutput( | ||
prev_sample=prev_sample, pred_original_sample=pred_original_sample | ||
) | ||
|
||
def add_noise( | ||
self, | ||
original_samples: torch.FloatTensor, | ||
noise: torch.FloatTensor, | ||
timesteps: torch.FloatTensor, | ||
) -> torch.FloatTensor: | ||
# Make sure sigmas and timesteps have the same device and dtype as original_samples | ||
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) | ||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): | ||
# mps does not support float64 | ||
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) | ||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32) | ||
else: | ||
self.timesteps = self.timesteps.to(original_samples.device) | ||
timesteps = timesteps.to(original_samples.device) | ||
|
||
schedule_timesteps = self.timesteps | ||
|
||
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): | ||
deprecate( | ||
"timesteps as indices", | ||
"0.8.0", | ||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" | ||
" `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to" | ||
" pass values from `scheduler.timesteps` as timesteps.", | ||
standard_warn=False, | ||
) | ||
step_indices = timesteps | ||
else: | ||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
|
||
sigma = self.sigmas[step_indices].flatten() | ||
while len(sigma.shape) < len(original_samples.shape): | ||
sigma = sigma.unsqueeze(-1) | ||
|
||
noisy_samples = original_samples + noise * sigma | ||
return noisy_samples | ||
|
||
def __len__(self): | ||
return self.config.num_train_timesteps |
Oops, something went wrong.