[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (open-mmlab#334)

* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers work

* finish

* correct readme

* make bcp compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>

* Update src/diffusers/models/vae.py

Co-authored-by: Pedro Cuenca <[email protected]>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
3 people authored Sep 5, 2022
1 parent daddd98 commit cc59b05
Showing 39 changed files with 893 additions and 247 deletions.
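In short, code that indexed outputs with `["sample"]` now uses attribute access, and model, scheduler, and pipeline calls gain a `return_dict` flag for plain tuples. A minimal sketch of the migration, using the `google/ddpm-celebahq-256` checkpoint from the README diff below (signatures are as of this commit; verify against your installed version):

```python
from diffusers import DDPMPipeline

ddpm = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")

# Before this commit, outputs were plain dicts:
#   image = ddpm()["sample"][0]
# After it, outputs are dataclasses with attribute access:
image = ddpm().images[0]

# return_dict=False opts out of the dataclass and returns a plain tuple:
(images,) = ddpm(return_dict=False)
```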
16 changes: 8 additions & 8 deletions README.md
@@ -80,7 +80,7 @@ pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
```

**Note**: If you don't want to use the token, you can also simply download the model weights
@@ -101,7 +101,7 @@ pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
```

If you are limited by GPU memory, you might want to consider using the model in `fp16`.
@@ -117,7 +117,7 @@ pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
```

Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -143,7 +143,7 @@ pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```
@@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"

with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images

images[0].save("fantasy_landscape.png")
```
@@ -228,7 +228,7 @@ pipe = pipe.to(device)

prompt = "a cat sitting on a bench"
with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images

images[0].save("cat_on_bench.png")
```
@@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
-images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
+images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images

# save images
for idx, image in enumerate(images):
@@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference

# run pipeline in inference (sample random noise and denoise)
-image = ddpm()["sample"]
+image = ddpm().images

# save image
image[0].save("ddpm_generated_image.png")
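Per the "make bcp compatible" commit message above, the new output classes are intended to stay backward compatible. A hedged sketch of the access patterns, assuming `BaseOutput` keeps dict- and tuple-style indexing alongside the new attribute access:

```python
from diffusers import DDPMPipeline

ddpm = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
out = ddpm()

image = out.images[0]     # new attribute access
image = out["images"][0]  # dict-style access, assumed kept for backward compatibility
image = out[0][0]         # tuple-style indexing into the first field
```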
2 changes: 1 addition & 1 deletion examples/textual_inversion/README.md
@@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
prompt = "A <cat-toy> backpack"

with autocast("cuda"):
-    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
+    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("cat-backpack.png")
```
4 changes: 2 additions & 2 deletions examples/textual_inversion/textual_inversion.py
@@ -498,7 +498,7 @@ def main():
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(text_encoder):
        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).sample().detach()
+        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
@@ -515,7 +515,7 @@
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

# Predict the noise residual
-noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
+noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
@@ -139,7 +139,7 @@ def transforms(examples):

with accelerator.accumulate(model):
    # Predict the noise residual
-    noise_pred = model(noisy_images, timesteps)["sample"]
+    noise_pred = model(noisy_images, timesteps).sample
    loss = F.mse_loss(noise_pred, noise)
    accelerator.backward(loss)

@@ -174,7 +174,7 @@ def transforms(examples):

generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
-images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
+images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images

# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
2 changes: 1 addition & 1 deletion scripts/generate_logits.py
@@ -119,7 +119,7 @@
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
-    logits = model(noise, time_step)["sample"]
+    logits = model(noise, time_step).sample

assert torch.allclose(
    logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
2 changes: 1 addition & 1 deletion src/diffusers/hub_utils.py
@@ -19,9 +19,9 @@
from pathlib import Path
from typing import Optional

-from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami

+from .pipeline_utils import DiffusionPipeline
from .utils import is_modelcards_available, logging


27 changes: 22 additions & 5 deletions src/diffusers/models/unet_2d.py
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


+@dataclass
+class UNet2DOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states output. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor


class UNet2DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
@@ -118,8 +131,11 @@ def __init__(
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

    def forward(
-        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
-    ) -> Dict[str, torch.FloatTensor]:
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+    ) -> Union[UNet2DOutput, Tuple]:
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
@@ -181,6 +197,7 @@ def forward(
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)

-        return output
+        return UNet2DOutput(sample=sample)
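The same pattern at the model level. A minimal sketch mirroring `scripts/generate_logits.py` above; `google/ddpm-cat-256` is an assumed example checkpoint:

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cat-256")
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10])

with torch.no_grad():
    sample = model(noise, time_step).sample                 # UNet2DOutput field
    (sample,) = model(noise, time_step, return_dict=False)  # plain tuple instead
```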
23 changes: 19 additions & 4 deletions src/diffusers/models/unet_2d_condition.py
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block


+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor


class UNet2DConditionModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
@@ -125,7 +138,8 @@ def forward(
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
-    ) -> Dict[str, torch.FloatTensor]:
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
@@ -183,6 +197,7 @@ def forward(
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)

-        return output
+        return UNet2DConditionOutput(sample=sample)
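For the conditioned variant, the only difference is the extra `encoder_hidden_states` input. A sketch assuming access to the Stable Diffusion weights (`use_auth_token` was how the checkpoint was gated at the time; the random embeddings stand in for the CLIP text encoder's output):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True
)
latents = torch.randn(1, unet.config.in_channels, 64, 64)
text_embeddings = torch.randn(1, 77, 768)  # stand-in for CLIP text-encoder hidden states

noise_pred = unet(latents, torch.tensor([10]), encoder_hidden_states=text_embeddings).sample
(noise_pred,) = unet(
    latents, torch.tensor([10]), encoder_hidden_states=text_embeddings, return_dict=False
)
```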
104 changes: 87 additions & 17 deletions src/diffusers/models/vae.py
@@ -1,14 +1,56 @@
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


+@dataclass
+class DecoderOutput(BaseOutput):
+    """
+    Output of decoding method.
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Decoded output sample of the model. Output of the last layer of the model.
+    """
+
+    sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Encoded output sample of the model. Output of the last layer of the model.
+    """
+
+    latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+    """
+    Output of AutoencoderKL encoding method.
+
+    Args:
+        latent_dist (`DiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+    """
+
+    latent_dist: "DiagonalGaussianDistribution"


class Encoder(nn.Module):
    def __init__(
        self,
@@ -369,26 +411,40 @@ def __init__(
            act_fn=act_fn,
        )

-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
        h = self.encoder(x)
        h = self.quant_conv(h)
-        return h

-    def decode(self, h, force_not_quantize=False):
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    def decode(
+        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
-        return dec

-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
-        h = self.encode(x)
-        dec = self.decode(h)
-        return dec
+        h = self.encode(x).latents
+        dec = self.decode(h).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)


class AutoencoderKL(ModelMixin, ConfigMixin):
@@ -431,23 +487,37 @@ def __init__(
        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
-        return posterior

-    def decode(self, z):
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
-        return dec

-    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
-        posterior = self.encode(x)
+        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
-        dec = self.decode(z)
-        return dec
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
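Putting the reworked VAE API together. A sketch assuming the standard Stable Diffusion `vae` subfolder (token-gated at the time); `encode` now wraps the posterior in `AutoencoderKLOutput` and `decode` wraps the reconstruction in `DecoderOutput`:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True
)
images = torch.randn(1, 3, 256, 256)

posterior = vae.encode(images).latent_dist   # DiagonalGaussianDistribution
latents = posterior.sample()
reconstruction = vae.decode(latents).sample  # DecoderOutput.sample

# Tuple path for both steps:
(posterior,) = vae.encode(images, return_dict=False)
(reconstruction,) = vae.decode(latents, return_dict=False)
```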