[Dance Diffusion] Add dance diffusion #803

Merged
merged 49 commits into from Oct 25, 2022

Changes from 3 commits
635f005
start
patrickvonplaten Oct 11, 2022
3103f9f
add more logic
patrickvonplaten Oct 11, 2022
2d6d178
Update src/diffusers/models/unet_2d_condition_flax.py
patrickvonplaten Oct 11, 2022
3f0a8f8
match weights
patrickvonplaten Oct 18, 2022
f882c33
up
patrickvonplaten Oct 18, 2022
a353452
up
patrickvonplaten Oct 18, 2022
37addaa
make model work
patrickvonplaten Oct 18, 2022
05b4a0b
making class more general, fixing missed file rename
natolambert Oct 18, 2022
1697eec
small fix
patrickvonplaten Oct 20, 2022
1a019c3
make new conversion work
patrickvonplaten Oct 20, 2022
a2bf35b
up
patrickvonplaten Oct 20, 2022
f7220cf
finalize conversion
patrickvonplaten Oct 20, 2022
5b1e292
up
patrickvonplaten Oct 20, 2022
a9f111b
first batch of variable renamings
natolambert Oct 20, 2022
8320ff6
remove c and c_prev var names
natolambert Oct 20, 2022
3cd030f
add mid and out block structure
natolambert Oct 20, 2022
0303a4d
add pipeline
patrickvonplaten Oct 24, 2022
c9ea2c1
Merge branch 'add_dance_diffusion' of https://github.com/huggingface/…
patrickvonplaten Oct 24, 2022
20dee8d
up
patrickvonplaten Oct 24, 2022
a5764dc
finish conversion
patrickvonplaten Oct 24, 2022
077406c
finish
patrickvonplaten Oct 24, 2022
031e9da
upload
patrickvonplaten Oct 24, 2022
e052890
more fixes
patrickvonplaten Oct 24, 2022
686ba12
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Oct 24, 2022
d800579
Apply suggestions from code review
patrickvonplaten Oct 24, 2022
21f13e5
add attr
patrickvonplaten Oct 24, 2022
aa5b3ed
fix
patrickvonplaten Oct 24, 2022
26ae749
up
patrickvonplaten Oct 24, 2022
49efa62
uP
patrickvonplaten Oct 24, 2022
a4abcf9
up
patrickvonplaten Oct 24, 2022
64a8805
finish tests
patrickvonplaten Oct 24, 2022
510d615
finish
patrickvonplaten Oct 24, 2022
9cf65c3
uP
patrickvonplaten Oct 24, 2022
6773edc
finish
patrickvonplaten Oct 24, 2022
78bfc9d
fix test
patrickvonplaten Oct 24, 2022
d1ec608
up
patrickvonplaten Oct 24, 2022
9ce38e6
naming consistency in tests
natolambert Oct 24, 2022
f4d3e59
Apply suggestions from code review
patrickvonplaten Oct 25, 2022
b763d80
remove hardcoded 16
patrickvonplaten Oct 25, 2022
e0744ee
Remove bogus
patrickvonplaten Oct 25, 2022
8539305
Merge branch 'add_dance_diffusion' of https://github.com/huggingface/…
patrickvonplaten Oct 25, 2022
6b2196a
fix some stuff
patrickvonplaten Oct 25, 2022
48648a4
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Oct 25, 2022
fbeeeaf
finish
patrickvonplaten Oct 25, 2022
4b0cc18
improve logging
patrickvonplaten Oct 25, 2022
5bea0a2
Merge branch 'add_dance_diffusion' of https://github.com/huggingface/…
patrickvonplaten Oct 25, 2022
58d7f16
docs
patrickvonplaten Oct 25, 2022
63c1e41
Merge branch 'add_dance_diffusion' of https://github.com/huggingface/…
patrickvonplaten Oct 25, 2022
cf79361
upload
patrickvonplaten Oct 25, 2022
113 changes: 113 additions & 0 deletions scripts/convert_dance_diffusion_to_diffusers.py
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
import argparse
import math
import os
from copy import deepcopy

import torch
from torch import nn

from audio_diffusion.models import DiffusionAttnUnet1D


MODELS_MAP = {
    "gwf-440k": {
        "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-small-190k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-large-580k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
        "sample_rate": 48000,
        "sample_size": 131072,
    },
    "maestro-uncond-150k": {
        "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "unlocked-uncond-250k": {
        "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "honk-140k": {
        "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
}


def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2
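
# Note added for clarity (not in the original script): these helpers follow the
# v-diffusion time convention, where a timestep t in [0, 1] corresponds to
# alpha = cos(t * pi / 2) and sigma = sin(t * pi / 2); atan2(sigma, alpha)
# inverts that mapping. get_crash_schedule below warps a linear t grid by
# taking sin(t * pi / 2) ** 2 as the new sigma and recomputing the matching
# alpha before converting back to a timestep.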


def get_crash_schedule(t):
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma ** 2) ** 0.5
    return alpha_sigma_to_t(alpha, sigma)


class Object(object):
    pass


class DiffusionUncond(nn.Module):
    def __init__(self, global_args):
        super().__init__()

        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
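
# Note added for clarity (not in the original script): main() below runs
# model.diffusion_ema rather than model.diffusion when checking the output,
# so the converted weights should come from the EMA copy of the model.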


def download(model_name):
    # The original diff left this as a stub; a minimal wget-based sketch
    # (assumption) is filled in so the script is runnable end to end.
    url = MODELS_MAP[model_name]["url"]
    os.system(f"wget {url} ./")
    return f"./{model_name}.ckpt"


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert model_name == args.model_path, f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusion_model = DiffusionUncond(config)
    diffusion_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    model = diffusion_model.eval()

    steps = 100
    step_index = 2

    generator = torch.manual_seed(33)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    output = model.diffusion_ema(noise, step_list[step_index : step_index + 1])
    # Compare the absolute difference so the check catches deviations in both
    # directions; the original signed comparison passed for any smaller value.
    assert (output.abs().sum() - 4550.5430).abs() < 1e-3


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
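For orientation, here is a minimal sketch of how a converted checkpoint is meant to be consumed once the pipeline added later in this PR is available. The checkpoint path is a placeholder, and the DanceDiffusionPipeline call signature is an assumption based on the pipeline commits above, not on this diff:

# Hypothetical usage sketch (see assumptions above).
import torch
from diffusers import DanceDiffusionPipeline

pipe = DanceDiffusionPipeline.from_pretrained("./converted_checkpoint")
generator = torch.manual_seed(33)
output = pipe(num_inference_steps=100, generator=generator)
audio = output.audios[0]  # waveform array of shape (channels, sample_size)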
166 changes: 166 additions & 0 deletions src/diffusers/models/unet_1d.py
@@ -0,0 +1,166 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding
from .unet_blocks_1d import UNetMidBlock1D, get_down_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            Hidden states output. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*): Default length of the input sample.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*): Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*): Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*): Tuple of block output channels.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        down_block_types: Tuple[str] = ["DownBlock1DNoSkip"] + 7 * ["DownBlock1D"] + 6 * ["AttnDownBlock1D"],
        up_block_types: Tuple[str] = 6 * ["AttnUpBlock1D"] + 7 * ["UpBlock1D"] + ["UpBlock1DNoSkip"],
        block_out_channels: Tuple[int] = [128, 128, 256, 256] + [512] * 10,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # time
        self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0])
        timestep_input_dim = 2 * block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i] if i < len(down_block_types) - 1 else out_channels

            down_block = get_down_block(
                down_block_type,
                c=input_channel,
                c_prev=output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock1D(
            c=block_out_channels[-1],
            c_prev=block_out_channels[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            up_block = get_up_block(
                up_block_type,
                c_prev=prev_output_channel,
                c=output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        # 3. mid
        sample = self.mid_block(sample, emb)

        # 4. up
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
            sample = upsample_block(sample, res_samples, emb)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)
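To make the forward pass above concrete, here is a minimal smoke-test sketch. It assumes the blocks imported from unet_blocks_1d.py (not part of this diff) exist and preserve the sample length, and that a continuous timestep is passed as a tensor, as in the conversion script's crash schedule:

# Hypothetical smoke test for the UNet1DModel defined above (assumptions noted
# in the text; unet_blocks_1d.py is not shown in this diff).
import torch

from diffusers.models.unet_1d import UNet1DModel

model = UNet1DModel(sample_size=65536)
sample = torch.randn(1, 2, 65536)  # (batch_size, num_channels, sample_size)
timestep = torch.tensor([0.5])     # continuous t, matching the crash schedule
with torch.no_grad():
    output = model(sample, timestep).sample
assert output.shape == sample.shape  # the UNet returns a sample-shaped output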