diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a0358c32e4ec..43a4fc4bf11e 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -92,5 +92,7 @@ title: "Stable Diffusion" - local: api/pipelines/stochastic_karras_ve title: "Stochastic Karras VE" + - local: api/pipelines/dance_diffusion + title: "Dance Diffusion" title: "Pipelines" title: "API" diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index c92fdccb8333..c3f5e65edfbd 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -22,6 +22,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DOutput [[autodoc]] models.unet_2d.UNet2DOutput +## UNet1DModel +[[autodoc]] UNet1DModel + ## UNet2DModel [[autodoc]] UNet2DModel diff --git a/docs/source/api/pipelines/dance_diffusion.mdx b/docs/source/api/pipelines/dance_diffusion.mdx new file mode 100644 index 000000000000..4d969bf6f032 --- /dev/null +++ b/docs/source/api/pipelines/dance_diffusion.mdx @@ -0,0 +1,33 @@ + + +# Dance Diffusion + +## Overview + +[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) by Zach Evans. + +Dance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai. +For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page. + +The original codebase of this implementation can be found [here](https://github.com/Harmonai-org/sample-generator). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_dance_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py) | *Unconditional Audio Generation* | - | + + +## DanceDiffusionPipeline +[[autodoc]] DanceDiffusionPipeline + - __call__ diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 12a6b5c587bc..80fbfdcb2321 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -95,6 +95,10 @@ Original paper can be found [here](https://arxiv.org/abs/2011.13456). [[autodoc]] ScoreSdeVeScheduler +#### improved pseudo numerical methods for diffusion models (iPNDM) + +Original implementation can be found [here](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296). + #### variance preserving stochastic differential equation (SDE) scheduler Original paper can be found [here](https://arxiv.org/abs/2011.13456). diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py new file mode 100755 index 000000000000..db183bce9262 --- /dev/null +++ b/scripts/convert_dance_diffusion_to_diffusers.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 +import argparse +import math +import os +from copy import deepcopy + +import torch +from torch import nn + +from audio_diffusion.models import DiffusionAttnUnet1D +from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel +from diffusion import sampling + + +MODELS_MAP = { + "gwf-440k": { + "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt", + "sample_rate": 48000, + "sample_size": 65536, + }, + "jmann-small-190k": { + "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt", + "sample_rate": 48000, + "sample_size": 65536, + }, + "jmann-large-580k": { + "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt", + "sample_rate": 48000, + "sample_size": 131072, + }, + "maestro-uncond-150k": { + "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, + "unlocked-uncond-250k": { + "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, + "honk-140k": { + "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt", + "sample_rate": 16000, + "sample_size": 65536, + }, +} + + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + + +def get_crash_schedule(t): + sigma = torch.sin(t * math.pi / 2) ** 2 + alpha = (1 - sigma**2) ** 0.5 + return alpha_sigma_to_t(alpha, sigma) + + +class Object(object): + pass + + +class DiffusionUncond(nn.Module): + def __init__(self, global_args): + super().__init__() + + self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) + self.diffusion_ema = deepcopy(self.diffusion) + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + +def download(model_name): + url = MODELS_MAP[model_name]["url"] + os.system(f"wget {url} ./") + + return f"./{model_name}.ckpt" + + +DOWN_NUM_TO_LAYER = { + "1": "resnets.0", + "2": "attentions.0", + "3": "resnets.1", + "4": "attentions.1", + "5": "resnets.2", + "6": "attentions.2", +} +UP_NUM_TO_LAYER = { + "8": "resnets.0", + "9": "attentions.0", + "10": "resnets.1", + "11": "attentions.1", + "12": "resnets.2", + "13": "attentions.2", +} +MID_NUM_TO_LAYER = { + "1": "resnets.0", + "2": "attentions.0", + "3": "resnets.1", + "4": "attentions.1", + "5": "resnets.2", + "6": "attentions.2", + "8": "resnets.3", + "9": "attentions.3", + "10": "resnets.4", + "11": "attentions.4", + "12": "resnets.5", + "13": "attentions.5", +} +DEPTH_0_TO_LAYER = { + "0": "resnets.0", + "1": "resnets.1", + "2": "resnets.2", + "4": "resnets.0", + "5": "resnets.1", + "6": "resnets.2", +} + +RES_CONV_MAP = { + "skip": "conv_skip", + "main.0": "conv_1", + "main.1": "group_norm_1", + "main.3": "conv_2", + "main.4": "group_norm_2", +} + +ATTN_MAP = { + "norm": "group_norm", + "qkv_proj": ["query", "key", "value"], + "out_proj": ["proj_attn"], +} + + +def convert_resconv_naming(name): + if name.startswith("skip"): + return name.replace("skip", RES_CONV_MAP["skip"]) + + # name has to be of format main.{digit} + if not name.startswith("main."): + raise ValueError(f"ResConvBlock error with {name}") + + return name.replace(name[:6], RES_CONV_MAP[name[:6]]) + + +def convert_attn_naming(name): + for key, value in ATTN_MAP.items(): + if name.startswith(key) and not isinstance(value, list): + return name.replace(key, value) + elif name.startswith(key): + return [name.replace(key, v) for v in value] + raise ValueError(f"Attn error with {name}") + + +def rename(input_string, max_depth=13): + string = input_string + + if string.split(".")[0] == "timestep_embed": + return string.replace("timestep_embed", "time_proj") + + depth = 0 + if string.startswith("net.3."): + depth += 1 + string = string[6:] + elif string.startswith("net."): + string = string[4:] + + while string.startswith("main.7."): + depth += 1 + string = string[7:] + + if string.startswith("main."): + string = string[5:] + + # mid block + if string[:2].isdigit(): + layer_num = string[:2] + string_left = string[2:] + else: + layer_num = string[0] + string_left = string[1:] + + if depth == max_depth: + new_layer = MID_NUM_TO_LAYER[layer_num] + prefix = "mid_block" + elif depth > 0 and int(layer_num) < 7: + new_layer = DOWN_NUM_TO_LAYER[layer_num] + prefix = f"down_blocks.{depth}" + elif depth > 0 and int(layer_num) > 7: + new_layer = UP_NUM_TO_LAYER[layer_num] + prefix = f"up_blocks.{max_depth - depth - 1}" + elif depth == 0: + new_layer = DEPTH_0_TO_LAYER[layer_num] + prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0" + + if not string_left.startswith("."): + raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.") + + string_left = string_left[1:] + + if "resnets" in new_layer: + string_left = convert_resconv_naming(string_left) + elif "attentions" in new_layer: + new_string_left = convert_attn_naming(string_left) + string_left = new_string_left + + if not isinstance(string_left, list): + new_string = prefix + "." + new_layer + "." + string_left + else: + new_string = [prefix + "." + new_layer + "." + s for s in string_left] + return new_string + + +def rename_orig_weights(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k.endswith("kernel"): + # up- and downsample layers, don't have trainable weights + continue + + new_k = rename(k) + + # check if we need to transform from Conv => Linear for attention + if isinstance(new_k, list): + new_state_dict = transform_conv_attns(new_state_dict, new_k, v) + else: + new_state_dict[new_k] = v + + return new_state_dict + + +def transform_conv_attns(new_state_dict, new_k, v): + if len(new_k) == 1: + if len(v.shape) == 3: + # weight + new_state_dict[new_k[0]] = v[:, :, 0] + else: + # bias + new_state_dict[new_k[0]] = v + else: + # qkv matrices + trippled_shape = v.shape[0] + single_shape = trippled_shape // 3 + for i in range(3): + if len(v.shape) == 3: + new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0] + else: + new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape] + return new_state_dict + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model_name = args.model_path.split("/")[-1].split(".")[0] + if not os.path.isfile(args.model_path): + assert ( + model_name == args.model_path + ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" + args.model_path = download(model_name) + + sample_rate = MODELS_MAP[model_name]["sample_rate"] + sample_size = MODELS_MAP[model_name]["sample_size"] + + config = Object() + config.sample_size = sample_size + config.sample_rate = sample_rate + config.latent_dim = 0 + + diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate) + diffusers_state_dict = diffusers_model.state_dict() + + orig_model = DiffusionUncond(config) + orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"]) + orig_model = orig_model.diffusion_ema.eval() + orig_model_state_dict = orig_model.state_dict() + renamed_state_dict = rename_orig_weights(orig_model_state_dict) + + renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys()) + diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys()) + + assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}" + assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" + + for key, value in renamed_state_dict.items(): + assert ( + diffusers_state_dict[key].squeeze().shape == value.squeeze().shape + ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" + if key == "time_proj.weight": + value = value.squeeze() + + diffusers_state_dict[key] = value + + diffusers_model.load_state_dict(diffusers_state_dict) + + steps = 100 + seed = 33 + + diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps) + + generator = torch.manual_seed(seed) + noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device) + + t = torch.linspace(1, 0, steps + 1, device=device)[:-1] + step_list = get_crash_schedule(t) + + pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler) + + generator = torch.manual_seed(33) + audio = pipe(num_inference_steps=steps, generator=generator).audios + + generated = sampling.iplms_sample(orig_model, noise, step_list, {}) + generated = generated.clamp(-1, 1) + + diff_sum = (generated - audio).abs().sum() + diff_max = (generated - audio).abs().max() + + if args.save: + pipe.save_pretrained(args.checkpoint_path) + + print("Diff sum", diff_sum) + print("Diff max", diff_max) + + assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/" + + print(f"Conversion for {model_name} successful!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not." + ) + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a0dbc9de125..2c531cf8cee0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -18,7 +18,7 @@ if is_torch_available(): from .modeling_utils import ModelMixin - from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel + from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -29,10 +29,19 @@ get_scheduler, ) from .pipeline_utils import DiffusionPipeline - from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline + from .pipelines import ( + DanceDiffusionPipeline, + DDIMPipeline, + DDPMPipeline, + KarrasVePipeline, + LDMPipeline, + PNDMPipeline, + ScoreSdeVePipeline, + ) from .schedulers import ( DDIMScheduler, DDPMScheduler, + IPNDMScheduler, KarrasVeScheduler, PNDMScheduler, SchedulerMixin, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1242ad6fca7f..c5d53b2feb4b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 06b814e2bbcd..35715e17fc47 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -101,17 +101,28 @@ def forward(self, timesteps): class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" - def __init__(self, embedding_size: int = 256, scale: float = 1.0): + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos - # to delete later - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + if set_W_to_weight: + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.weight = self.W + self.weight = self.W def forward(self, x): - x = torch.log(x) + if self.log: + x = torch.log(x) + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return out diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py new file mode 100644 index 000000000000..73cf0785ec1e --- /dev/null +++ b/src/diffusers/models/unet_1d.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..utils import BaseOutput +from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): + Hidden states output. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet1DModel(ModelMixin, ConfigMixin): + r""" + UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(32, 32, 64)`): Tuple of block output channels. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + mid_block_type: str = "UNetMidBlock1D", + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + block_out_channels: Tuple[int] = (32, 32, 64), + ): + super().__init__() + + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + + # down + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + down_block = get_down_block( + down_block_type, + in_channels=input_channel, + out_channels=output_channel, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type=mid_block_type, + mid_channels=block_out_channels[-1], + in_channels=block_out_channels[-1], + out_channels=None, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels + + up_block = get_up_block( + up_block_type, + in_channels=prev_output_channel, + out_channels=output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # TODO(PVP, Nathan) placeholder for RL application to be merged shortly + # Totally fine to add another layer with a if statement - no need for nn.Identity here + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet1DOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): `(batch_size, sample_size, num_channels)` noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # 1. time + if len(timestep.shape) == 0: + timestep = timestep[None] + + timestep_embed = self.time_proj(timestep)[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]) + + # 2. down + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples + + # 3. mid + sample = self.mid_block(sample) + + # 4. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_samples) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py new file mode 100644 index 000000000000..9009071d1e78 --- /dev/null +++ b/src/diffusers/models/unet_1d_blocks.py @@ -0,0 +1,384 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn.functional as F +from torch import nn + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states): + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv1d(hidden_states, weight, stride=2) + + +class Upsample1d(nn.Module): + def __init__(self, kernel="linear", pad_mode="reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states): + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + weight[indices, indices] = self.kernel.to(weight) + return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels, n_head=1, dropout_rate=0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + + self.proj_attn = nn.Linear(self.channels, self.channels, 1) + + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores, dim=-1) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + + output = hidden_states + residual + + return output + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, is_last=False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states): + residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +def get_down_block(down_block_type, out_channels, in_channels): + if down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block(up_block_type, in_channels, out_channels): + if up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels): + if mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels, in_channels, out_channels=None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states): + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, res_hidden_states_tuple): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, res_hidden_states_tuple): + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index b580b4ed4056..641c253c86f8 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -21,7 +21,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_2d_blocks.py similarity index 100% rename from src/diffusers/models/unet_blocks.py rename to src/diffusers/models/unet_2d_blocks.py diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py similarity index 100% rename from src/diffusers/models/unet_blocks_flax.py rename to src/diffusers/models/unet_2d_blocks_flax.py diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a65e238969dd..d271b78a6525 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -22,7 +22,7 @@ from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import ( +from .unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 5411855f79d5..f0e721826bd7 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -23,7 +23,7 @@ from ..modeling_flax_utils import FlaxModelMixin from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .unet_blocks_flax import ( +from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, FlaxDownBlock2D, diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index f1c0fc1416ee..5f5a47dada0f 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput -from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 894505654e47..5c08a2db39fa 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -93,6 +93,20 @@ class ImagePipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] +@dataclass +class AudioPipelineOutput(BaseOutput): + """ + Output class for audio pipelines. + + Args: + audios (`np.ndarray`) + List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the + denoised audio samples of the diffusion pipeline. + """ + + audios: np.ndarray + + class DiffusionPipeline(ConfigMixin): r""" Base class for all models. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3c4f877fa830..b3124af39077 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -2,6 +2,7 @@ if is_torch_available(): + from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline diff --git a/src/diffusers/pipelines/dance_diffusion/__init__.py b/src/diffusers/pipelines/dance_diffusion/__init__.py new file mode 100644 index 000000000000..2ad34fc52aaa --- /dev/null +++ b/src/diffusers/pipelines/dance_diffusion/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline_dance_diffusion import DanceDiffusionPipeline diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py new file mode 100644 index 000000000000..dfbb980be8d8 --- /dev/null +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -0,0 +1,113 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import torch + +from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DanceDiffusionPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`IPNDMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 100, + generator: Optional[torch.Generator] = None, + sample_length_in_s: Optional[float] = None, + return_dict: bool = True, + ) -> Union[AudioPipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of audio samples to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at + the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + if sample_length_in_s is None: + sample_length_in_s = self.unet.sample_size / self.unet.sample_rate + + sample_size = sample_length_in_s * self.unet.sample_rate + + down_scale_factor = 2 ** len(self.unet.up_blocks) + if sample_size < 3 * down_scale_factor: + raise ValueError( + f"{sample_length_in_s} is too small. Make sure it's bigger or equal to" + f" {3 * down_scale_factor / self.unet.sample_rate}." + ) + + original_sample_size = int(sample_size) + if sample_size % down_scale_factor != 0: + sample_size = ((sample_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor + logger.info( + f"{sample_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled" + f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising" + " process." + ) + sample_size = int(sample_size) + + audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, device=audio.device) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(audio, t).sample + + # 2. compute previous image: x_t -> t_t-1 + audio = self.scheduler.step(model_output, t, audio).prev_sample + + audio = audio.clamp(-1, 1).cpu().numpy() + + audio = audio[:, :, :original_sample_size] + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 91026b7d487f..f179acde09e2 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,6 +19,7 @@ if is_torch_available(): from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler + from .scheduling_ipndm import IPNDMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py new file mode 100644 index 000000000000..fb413a280530 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -0,0 +1,152 @@ +# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class IPNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion + [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296) + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + """ + + @register_to_config + def __init__(self, num_train_timesteps: int = 1000): + # set `betas`, `alphas`, `timesteps` + self.set_timesteps(num_train_timesteps) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + # running values + self.ets = [] + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1] + steps = torch.cat([steps, torch.tensor([0.0])]) + + self.betas = torch.sin(steps * math.pi / 2) ** 2 + self.alphas = (1.0 - self.betas**2) ** 0.5 + + timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1] + self.timesteps = timesteps.to(device) + + self.ets = [] + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep_index = (self.timesteps == timestep).nonzero().item() + prev_timestep_index = timestep_index + 1 + + ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index] + self.ets.append(ets) + + if len(self.ets) == 1: + ets = self.ets[-1] + elif len(self.ets) == 2: + ets = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): + alpha = self.alphas[timestep_index] + sigma = self.betas[timestep_index] + + next_alpha = self.alphas[prev_timestep_index] + next_sigma = self.betas[prev_timestep_index] + + pred = (sample - sigma * ets) / max(alpha, 1e-8) + prev_sample = next_alpha * pred + ets * next_sigma + + return prev_sample + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ee748f5b1d6a..93834c76550b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UNet1DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] @@ -122,6 +137,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DanceDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMPipeline(metaclass=DummyObject): _backends = ["torch"] @@ -242,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class IPNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class KarrasVeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py new file mode 100644 index 000000000000..b1aa33fcff7e --- /dev/null +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -0,0 +1,101 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel +from diffusers.utils import slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class PipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_unet(self): + torch.manual_seed(0) + model = UNet1DModel( + block_out_channels=(32, 32, 64), + extra_in_channels=16, + sample_size=512, + sample_rate=16_000, + in_channels=2, + out_channels=2, + down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"], + up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"], + ) + return model + + def test_dance_diffusion(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + scheduler = IPNDMScheduler() + + pipe = DanceDiffusionPipeline(unet=self.dummy_unet, scheduler=scheduler) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=4) + audio = output.audios + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=4, return_dict=False) + audio_from_tuple = output[0] + + audio_slice = audio[0, -3:, -3:] + audio_from_tuple_slice = audio_from_tuple[0, -3:, -3:] + + assert audio.shape == (1, 2, self.dummy_unet.sample_size) + expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000]) + assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(audio_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + +@slow +@require_torch_gpu +class PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_dance_diffusion(self): + device = torch_device + + pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k") + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096) + audio = output.audios + + audio_slice = audio[0, -3:, -3:] + + assert audio.shape == (1, 2, pipe.unet.sample_size) + expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487]) + assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/test_models_unet_1d.py b/tests/test_models_unet_1d.py new file mode 100644 index 000000000000..c274ce419211 --- /dev/null +++ b/tests/test_models_unet_1d.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import UNet1DModel +from diffusers.utils import slow, torch_device + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class UnetModel1DTests(unittest.TestCase): + @slow + def test_unet_1d_maestro(self): + model_id = "harmonai/maestro-150k" + model = UNet1DModel.from_pretrained(model_id, subfolder="unet") + model.to(torch_device) + + sample_size = 65536 + noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device) + timestep = torch.tensor([1]).to(torch_device) + + with torch.no_grad(): + output = model(noise, timestep).sample + + output_sum = output.abs().sum() + output_max = output.abs().max() + + assert (output_sum - 224.0896).abs() < 4e-2 + assert (output_max - 0.0607).abs() < 4e-4 diff --git a/tests/test_models_unet.py b/tests/test_models_unet_2d.py similarity index 99% rename from tests/test_models_unet.py rename to tests/test_models_unet_2d.py index 7a16a8592644..393d47887e6d 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet_2d.py @@ -29,7 +29,7 @@ torch.backends.cuda.matmul.allow_tf32 = False -class UnetModelTests(ModelTesterMixin, unittest.TestCase): +class Unet2DModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @property diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c3d4b9bc76f9..899870a74a3f 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -19,7 +19,14 @@ import numpy as np import torch -from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + IPNDMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + ScoreSdeVeScheduler, +) torch.backends.cuda.matmul.allow_tf32 = False @@ -203,6 +210,10 @@ def recursive_check(tuple_object, dict_object): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", 50) + timestep = 0 + if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler: + timestep = 1 + for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -215,14 +226,14 @@ def recursive_check(tuple_object, dict_object): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - outputs_dict = scheduler.step(residual, 0, sample, **kwargs) + outputs_dict = scheduler.step(residual, timestep, sample, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs) + outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs) recursive_check(outputs_tuple, outputs_dict) @@ -901,3 +912,157 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3 + + +class IPNDMSchedulerTest(SchedulerCommonTest): + scheduler_classes = (IPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = {"num_train_timesteps": 1000} + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + scheduler.ets = dummy_past_residuals[:] + + if time_step is None: + time_step = scheduler.timesteps[len(scheduler.timesteps) // 2] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residuals + new_scheduler.ets = dummy_past_residuals[:] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_pretrained_save_pretrained(self): + pass + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + if time_step is None: + time_step = scheduler.timesteps[len(scheduler.timesteps) // 2] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + # copy over dummy past residuals + new_scheduler.set_timesteps(num_inference_steps) + + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.ets = dummy_past_residuals[:] + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + scheduler.ets = dummy_past_residuals[:] + + time_step_0 = scheduler.timesteps[5] + time_step_1 = scheduler.timesteps[6] + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps, time_step=None) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps, time_step=None) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 2540529) < 10