merge unet-rl formatting
natolambert committed Jun 20, 2022
2 parents 49718b4 + 77aadfe commit 4497e78
Showing 41 changed files with 753 additions and 3,355 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
# Make marked copies of snippets of codes conform to the original

fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_copies.py --fix_and_overwrite

# Run tests for the library

31 changes: 24 additions & 7 deletions README.md
@@ -30,20 +30,32 @@ More precisely, 🤗 Diffusers offers:
**Models**: A neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see figure below) and is trained end-to-end to *denoise* a noisy input into an image.
*Examples*: UNet, Conditioned UNet, 3D UNet, Transformer UNet

![model_diff_1_50](https://user-images.githubusercontent.com/23423619/171610307-dab0cd8b-75da-4d4e-9f5a-5922072e2bb5.png)

<p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174349667-04e9e485-793b-429a-affe-096e8199ad5b.png" width="800"/>
<br>
<em> Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
</p>
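
To make the *denoise* training objective concrete, here is a compact sketch of the standard denoising loss: the network receives a noised sample and a timestep and learns to predict the added noise. The `model` callable and `alphas_cumprod` schedule are assumed names for illustration, not the library's training code.

```python
import torch
import torch.nn.functional as F

def denoising_loss(model, clean_images, alphas_cumprod):
    # Draw a random timestep per image and the Gaussian noise to add.
    t = torch.randint(0, len(alphas_cumprod), (clean_images.shape[0],), device=clean_images.device)
    noise = torch.randn_like(clean_images)

    # Forward diffusion: corrupt the clean images according to the schedule.
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    noisy_images = a_bar.sqrt() * clean_images + (1 - a_bar).sqrt() * noise

    # The model is trained end-to-end to predict the noise that was added.
    return F.mse_loss(model(noisy_images, t), noise)
```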

**Schedulers**: Algorithm class for both **inference** and **training**.
The class provides the functionality to compute the previous, less noisy image according to the alpha/beta schedule, as well as the noise targets needed for training.
*Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902)

![sampling](https://user-images.githubusercontent.com/23423619/171608981-3ad05953-a684-4c82-89f8-62a459147a07.png)
![training](https://user-images.githubusercontent.com/23423619/171608964-b3260cce-e6b4-4841-959d-7d8ba4b8d1b2.png)
<p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174349706-53d58acc-a4d1-4cda-b3e8-432d9dc7ad38.png" width="800"/>
<br>
<em> Sampling and training algorithms. Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
</p>
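
To illustrate what a scheduler encapsulates at inference time, the snippet below performs a single DDPM-style reverse step: given the current sample, the model's noise prediction, and the beta schedule, it computes an estimate of the previous, less noisy sample. Names and the variance choice are illustrative, not the library's scheduler API.

```python
import torch

def ddpm_reverse_step(x_t, predicted_noise, t, betas):
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Posterior mean: remove the predicted noise, rescaled according to the schedule.
    coef = betas[t] / (1.0 - alphas_cumprod[t]).sqrt()
    mean = (x_t - coef * predicted_noise) / alphas[t].sqrt()

    if t > 0:
        # Add fresh noise on every step except the last one (sigma_t^2 = beta_t).
        mean = mean + betas[t].sqrt() * torch.randn_like(x_t)
    return mean  # estimate of x_{t-1}
```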


**Diffusion Pipeline**: An end-to-end pipeline that combines multiple diffusion models, optional text encoders, ...
*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2

![imagen](https://user-images.githubusercontent.com/23423619/171609001-c3f2c1c9-f597-4a16-9843-749bf3f9431c.png)

<p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
<br>
<em> Figure from Imagen (https://imagen.research.google/). </em>
</p>
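
A minimal usage sketch of the pipeline abstraction is shown below; the model id and the call signature are assumptions for illustration and may differ in this version of the library.

```python
from diffusers import DiffusionPipeline

# Load a complete system (model weights + scheduler + any auxiliary encoders)
# and run the full denoising loop end-to-end.
pipeline = DiffusionPipeline.from_pretrained("fusing/ddpm-cifar10")  # hypothetical model id
image = pipeline()  # returns a generated sample
```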

## Philosophy

- Readability and clarity are preferred over highly optimized code. Strong emphasis is placed on providing readable, intuitive, and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original papers.
@@ -147,7 +159,8 @@ eta = 0.0 # <- deterministic sampling

for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual
orig_t = noise_scheduler.get_orig_t(t, num_inference_steps)
orig_t = len(noise_scheduler) // num_inference_steps * t

with torch.inference_mode():
residual = unet(image, orig_t)

@@ -173,6 +186,10 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
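
The snippet above maps the inference-step index `t` onto the original training timesteps via `len(noise_scheduler) // num_inference_steps * t`. A quick illustration of the arithmetic, assuming a 1000-step training schedule thinned down to 50 inference steps:

```python
# Assumed values for illustration: len(noise_scheduler) == 1000 training timesteps.
num_train_timesteps = 1000
num_inference_steps = 50

orig_timesteps = [num_train_timesteps // num_inference_steps * t for t in range(num_inference_steps)]
print(orig_timesteps[:5])  # [0, 20, 40, 60, 80] -> every 20th training timestep is visited
```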

#### **Examples for other modalities:**

[Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)

### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, DALL-E, ...)

For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
14 changes: 11 additions & 3 deletions src/diffusers/__init__.py
@@ -1,16 +1,24 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_transformers_available


__version__ = "0.0.4"

from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
from .pipelines import BDDM, DDIM, DDPM, PNDM
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler


if is_transformers_available():
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .pipelines import GLIDE, GradTTS, LatentDiffusion
else:
from .utils.dummy_transformers_objects import *
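
The module now gates the transformers-dependent models and pipelines behind `is_transformers_available()`, falling back to dummy objects when the optional dependency is missing. A minimal sketch of how such a soft-dependency check and a dummy fallback are commonly implemented (not necessarily the exact code in `diffusers.utils`):

```python
import importlib.util

def is_transformers_available() -> bool:
    # True only if the optional `transformers` package can be imported.
    return importlib.util.find_spec("transformers") is not None

class GLIDE:
    # Dummy stand-in exported when `transformers` is missing: importing it works,
    # but instantiating it fails with an actionable error message.
    def __init__(self, *args, **kwargs):
        raise ImportError("GLIDE requires the `transformers` library: pip install transformers")
```
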
2 changes: 1 addition & 1 deletion src/diffusers/configuration_utils.py
@@ -241,7 +241,7 @@ def to_json_string(self) -> str:
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict = self._internal_dict
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

def to_json_file(self, json_file_path: Union[str, os.PathLike]):
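
The added `hasattr` guard lets `to_json_string` succeed on an instance whose internal dict was never populated. A small illustration of the pattern in isolation (`MinimalConfig` is a hypothetical class, not part of the library):

```python
import json

class MinimalConfig:
    # Same guard as above: fall back to an empty dict when nothing was registered.
    def to_json_string(self) -> str:
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

print(MinimalConfig().to_json_string())  # prints "{}" instead of raising AttributeError
```
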
6 changes: 3 additions & 3 deletions src/diffusers/modeling_utils.py
@@ -490,7 +490,7 @@ def _find_mismatched_keys(
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

if len(unexpected_keys) > 0:
logger.warninging(
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
@@ -502,7 +502,7 @@ def _find_mismatched_keys(
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warninging(
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
@@ -521,7 +521,7 @@ def _find_mismatched_keys(
for key, shape1, shape2 in mismatched_keys
]
)
logger.warninging(
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
8 changes: 4 additions & 4 deletions src/diffusers/models/unet.py
@@ -287,14 +287,14 @@ def __init__(
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

def forward(self, x, t):
def forward(self, x, timesteps):
assert x.shape[2] == x.shape[3] == self.resolution

if not torch.is_tensor(t):
t = torch.tensor([t], dtype=torch.long, device=x.device)
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)

# timestep embedding
temb = get_timestep_embedding(t, self.ch)
temb = get_timestep_embedding(timesteps, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
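
The renamed `timesteps` argument is converted to a tensor and fed to `get_timestep_embedding`. As a reference point, here is a sketch of the standard sinusoidal timestep embedding such a helper typically computes; the exact frequency scaling in this file may differ.

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    # Map integer timesteps to `dim`-dimensional sin/cos features (dim assumed even).
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (len(timesteps), dim)
```
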
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_grad_tts.py
@@ -190,15 +190,15 @@ def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=
self.final_block = Block(dim, dim)
self.final_conv = torch.nn.Conv2d(dim, 1, 1)

def forward(self, x, mask, mu, t, spk=None):
def forward(self, x, timesteps, mu, mask, spk=None):
if self.n_spks > 1:
# Get speaker embedding
spk = self.spk_emb(spk)

if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)

t = self.time_pos_emb(t, scale=self.pe_scale)
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
t = self.mlp(t)

if self.n_spks < 2: