The Modular Diffusers #9672

Open · wants to merge 68 commits into main
Conversation

@yiyixuxu (Collaborator) commented Oct 14, 2024

Getting Started with Modular Diffusers

With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating a separate pipeline for each task, Modular Diffusers lets you:

Write Only What's New: You don't need to rewrite the entire pipeline from scratch. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for the rest.

Assemble Like LEGO®: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows and then assemble them into a single pipeline that conveniently serves multiple workflows. Below we walk you through how to use a pipeline like this that we built with Modular Diffusers. In later sections, we will also go over how to assemble and build new pipelines!

Quick Start with StableDiffusionXLAutoPipeline

import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline, ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

# Create pipeline
auto_pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
auto_pipe.update_states(**components.components)

Auto Workflow Selection

The pipeline automatically adapts to your inputs (see the sketch after this list):

  • Basic text-to-image: Just provide a prompt
  • Image-to-image: Add an image input
  • Inpainting: Add both image and mask_image
  • ControlNet: Add a control_image
  • And more!
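For example, here is a rough sketch of how the same `auto_pipe` from the quick start handles each workflow. It assumes `init_image`, `mask_img`, and `pose_image` are images you load yourself, and that any extra components such as the controlnet are already added:

# text-to-image: only a prompt
images = auto_pipe(prompt="an astronaut", output="images").images

# image-to-image: add an image
images = auto_pipe(prompt="an astronaut", image=init_image, output="images").images

# inpainting: add both image and mask_image
images = auto_pipe(prompt="an astronaut", image=init_image, mask_image=mask_img, output="images").images

# controlnet: add a control_image (requires a controlnet component, see "Advanced Workflows" below)
images = auto_pipe(prompt="an astronaut", control_image=pose_image, output="images").images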

Auto Documentation

We care a great deal about documentation here at Diffusers, and Modular Diffusers carries this mission forward. All our pipeline blocks come with complete docstrings that automatically compose as you build your pipelines. This means:

  • Every pipeline you build with Modular Diffusers comes with complete documentation automatically
  • Input/output signatures are dynamically generated, and the same goes for components and configurations
  • Parameter descriptions and types are included
  • Block relationships and dependencies are documented as well

Inspect your pipeline

# get pipeline info: components/configurations/pipeline blocks/docstring
print(auto_pipe)
see an example of output
ModularPipeline:
==============================

Pipeline Block:
--------------
StableDiffusionXLAutoPipeline
 (Class: SequentialPipelineBlocks)
  • text_encoder (StableDiffusionXLTextEncoderStep)
  • ip_adapter (StableDiffusionXLAutoIPAdapterStep)
  • image_encoder (StableDiffusionXLAutoVaeEncoderStep)
  • before_denoise (StableDiffusionXLAutoBeforeDenoiseStep)
  • denoise (StableDiffusionXLAutoDenoiseStep)
  • decode (StableDiffusionXLAutoDecodeStep)

Registered Components:
----------------------
text_encoder: CLIPTextModel (dtype=torch.float16, device=cpu)
text_encoder_2: CLIPTextModelWithProjection (dtype=torch.float16, device=cpu)
tokenizer: CLIPTokenizer
tokenizer_2: CLIPTokenizer
image_encoder: CLIPVisionModelWithProjection (dtype=torch.float16, device=cpu)
feature_extractor: CLIPImageProcessor
unet: UNet2DConditionModel (dtype=torch.float16, device=cpu)
vae: AutoencoderKL (dtype=torch.float16, device=cpu)
scheduler: EulerDiscreteScheduler
controlnet: ControlNetModel (dtype=torch.float16, device=cpu)
guider: CFGGuider
controlnet_guider: CFGGuider

Registered Configs:
------------------
force_zeros_for_empty_prompt: True
requires_aesthetics_score: False

------------------
This pipeline contains blocks that are selected at runtime based on inputs.

Trigger Inputs: {'control_image', 'control_mode', 'image_latents', 'padding_mask_crop', 'mask_image', 'ip_adapter_image', 'image', 'mask'}
  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_image')`).
Check `.doc` of returned object for more information.

  Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.
  - for image-to-image generation, you need to provide either `image` or `image_latents`
  - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` 
  - to run the controlnet workflow, you need to provide `control_image`
  - to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`
  - to run the ip_adapter workflow, you need to provide `ip_adapter_image`
  - for text-to-image generation, all you need to provide is `prompt`

  Args:

      prompt (`Union[str, List]`, *optional*):
          The prompt or prompts to guide the image generation.

      prompt_2 (`Union[str, List]`, *optional*):
          The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in
          both text-encoders

      negative_prompt (`Union[str, List]`, *optional*):
          The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if
          `guidance_scale` is less than `1`).

      negative_prompt_2 (`Union[str, List]`, *optional*):
          The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not
          defined, `negative_prompt` is used in both text-encoders

      cross_attention_kwargs (`Union[dict, NoneType]`, *optional*):
          A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor`
          in [diffusers.models.attention_processor]

      guidance_scale (`float`, *optional*, defaults to 5.0):
          Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
          `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance
          scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are
          closely linked to the text `prompt`, usually at the expense of lower image quality.

      clip_skip (`Union[int, NoneType]`, *optional*):

      ip_adapter_image (`Union[Image, ndarray, Tensor, List, List, List]`):
          The image(s) to be used as ip adapter

      height (`Union[int, NoneType]`, *optional*):
          The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below
          512 pixels won't work well for
          [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and
          checkpoints that are not specifically fine-tuned on low resolutions.

      width (`Union[int, NoneType]`, *optional*):
          The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512
          pixels won't work well for
          [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and
          checkpoints that are not specifically fine-tuned on low resolutions.

      generator (`Union[Generator, List, NoneType]`, *optional*):
          One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
          generation deterministic.

      image (`Union[Image, ndarray, Tensor, List, List, List]`):
          The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of
          the image will be masked out with `mask_image` and repainted according to `prompt`.

      mask_image (`Union[Image, ndarray, Tensor, List, List, List]`, *optional*):
          `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while
          black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel
          (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected
          shape would be `(B, H, W, 1)`.

      padding_mask_crop (`Union[Tuple, NoneType]`, *optional*):
          The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and
          mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect
          ratio of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image
          and mask_image will then be cropped based on the expanded area before resizing to the original image size for
          inpainting. This is useful when the masked area is small while the image is large and contain information
          irrelevant for inpainting, such as background.

      num_images_per_prompt (`int`, *optional*, defaults to 1):
          The number of images to generate per prompt.

      num_inference_steps (`int`, *optional*, defaults to 50):
          The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower
          inference.

      timesteps (`Union[Tensor, NoneType]`, *optional*):
          Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their
          `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.
          Must be in descending order.

      sigmas (`Union[Tensor, NoneType]`, *optional*):
          Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their
          `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.

      denoising_end (`Union[float, NoneType]`, *optional*):
          When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before
          it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount
          of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should
          ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup.

      strength (`float`, *optional*, defaults to 0.3):
          Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting).
          Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
          `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1,
          added noise will be maximum and the denoising process will run for the full number of iterations specified in
          `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
          `denoising_start` being declared as an integer, the value of `strength` will be ignored.

      denoising_start (`Union[float, NoneType]`, *optional*):
          The denoising start value to use for the scheduler. Determines the starting point of the denoising process.

      latents (`Union[Tensor, NoneType]`, *optional*):
          Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can
          be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by
          sampling using the supplied random `generator`.

      original_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The original size (height, width) of the image that conditions the generation process. If different from
          target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in
          section 2.2 of https://huggingface.co/papers/2307.01952

      target_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The target size (height, width) of the generated image. For most cases, this should be set to the desired output
          dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of
          https://huggingface.co/papers/2307.01952

      negative_original_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained
          in section 2.2 of https://huggingface.co/papers/2307.01952. See:
          https://github.com/huggingface/diffusers/issues/4208

      negative_target_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's
          micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See:
          https://github.com/huggingface/diffusers/issues/4208

      crops_coords_top_left (`Tuple`, *optional*, defaults to (0, 0)):
          `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
          `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning

      negative_crops_coords_top_left (`Tuple`, *optional*, defaults to (0, 0)):
          To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
          micro-conditioning

      aesthetic_score (`float`, *optional*, defaults to 6.0):
          Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of
          SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952

      negative_aesthetic_score (`float`, *optional*, defaults to 2.0):
          Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. Can be
          used to simulate an aesthetic score of the generated image by influencing the negative text condition.

      control_image (`Union[Image, ndarray, Tensor, List, List, List]`, *optional*):
          The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is
          used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass
          images as a list for proper batching.

      control_guidance_start (`Union[float, List]`, *optional*, defaults to 0.0):
          The percentage of total steps at which the ControlNet starts applying.

      control_guidance_end (`Union[float, List]`, *optional*, defaults to 1.0):
          The percentage of total steps at which the ControlNet stops applying.

      control_mode (`List`, *optional*):
          The control mode for union controlnet, 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for
          canny/lineart/anime_lineart/mlsd, 4 for normal and 5 for segment

      controlnet_conditioning_scale (`Union[float, List]`, *optional*, defaults to 1.0):
          Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list
          of scales.

      guess_mode (`bool`, *optional*, defaults to False):
          Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0.

      guidance_rescale (`float`, *optional*, defaults to 0.0):
          Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion
          Noise Schedules and Sample Steps are Flawed'.

      eta (`float`, *optional*, defaults to 0.0):
          Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others.

      guider_kwargs (`Union[Dict, NoneType]`, *optional*):
          Optional kwargs dictionary passed to the Guider.

      output_type (`str`, *optional*, defaults to pil):
          The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array.

      return_dict (`bool`, *optional*, defaults to True):
          Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple.

      dtype (`dtype`, *optional*):
          The dtype of the model inputs

      preprocess_kwargs (`Union[dict, NoneType]`, *optional*):
          A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under
          `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]

      ip_adapter_embeds (`List`, *optional*):
          Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.

      negative_ip_adapter_embeds (`List`, *optional*):
          Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.

      image_latents (`Tensor`, *optional*):
          The latents representing the reference image for image-to-image/inpainting generation. Can be generated in
          vae_encode step.

      mask (`Tensor`, *optional*):
          The mask for the inpainting generation. Can be generated in vae_encode step.

      masked_image_latents (`Tensor`, *optional*):
          The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in
          vae_encode step.

      image_latents (`Union[Tensor, NoneType]`, *optional*):
          The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in
          vae_encode or prepare_latent step.

      crops_coords (`Union[Tuple, NoneType]`, *optional*):
          The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be
          generated in vae_encode or prepare_latent step.

      crops_coords (`Tuple`, *optional*):
          The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be
          generated in vae_encode step.

  Returns:

      images (`Union[List, List, List]`):
          The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array

Use `get_execution_blocks()` to see which blocks will run for your inputs/workflow. For example, if you want to run a text-to-image ControlNet workflow, you can do this:

print(auto_pipe.get_execution_blocks("control_image"))

To see the docstring relevant to your inputs/workflow:

print(auto_pipe.get_execution_blocks("control_image").doc)

Advanced Workflows

Once you've created the auto pipeline, you can use it for different features as long as you add the required components and pass the required inputs.

# Add ControlNet
auto_pipe.update_states(controlnet=controlnet)

# Enable IP-Adapter
auto_pipe.update_states(image_encoder=..., feature_extractor=...)
auto_pipe.load_ip_adapter("h94/IP-Adapter")

# Add LoRA
auto_pipe.load_lora_weights(...)

# at inference time, pass all the inputs required for your workflow
images = auto_pipe(
    prompt="..",
    control_image=pose_image,        # this triggers the ControlNet workflow
    ip_adapter_image=style_image,    # this triggers the IP-Adapter workflow
    ...
).images

Here is an example you can run for a more complex workflow using ControlNet/IP-Adapter/LoRA/PAG:

import torch
from diffusers import ControlNetModel
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from diffusers.utils import load_image
from diffusers.guider import PAGGuider

# continuing from the quick start above: `components` and `auto_pipe` are already created
dtype = torch.float16

# load controlnet
controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=dtype)
components.add("controlnet", controlnet)

# load image_encoder for ip adapter
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)

# load additional components into the pipeline
auto_pipe.update_states(**components.get(["controlnet", "image_encoder", "feature_extractor"]))

# load ip adapter
auto_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipe.set_ip_adapter_scale(0.6)

# let's also load a lora while we're at it
auto_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")

# let's also throw PAG in there because why not!
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
auto_pipe.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)

# prepare inputs
prompt = "an astronaut"
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/person_pose.png")
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")

# Run pipeline with everything combined
images = auto_pipe(
    prompt=prompt,
    control_image=control_image,
    ip_adapter_image=ip_adapter_image,
    output="images"
).images
images[0]

(example output image: yiyi_modular_out)

Check out more usage examples here:

test1: complete testing script for `StableDiffusionXLAutoPipeline`
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import StableDiffusionXLAutoPipeline, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs_0131_auto_pipeline"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1) Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes (builders)

auto_pipeline_block = StableDiffusionXLAutoPipeline()
auto_pipeline = ModularPipeline.from_block(auto_pipeline_block)
refiner_pipeline = ModularPipeline.from_block(auto_pipeline_block)



# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
components.add("controlnet", controlnet)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)


# load components/config into nodes
auto_pipeline.update_states(**components.components)


# load other components to swap in later
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there's no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()


# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()



# using auto_pipeline to generate images

# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")


# since we want the text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.get_execution_blocks())
print(" ")

# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what auto_pipeline told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()


# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
auto_pipeline.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    auto_pipeline.unload_lora_weights()

auto_pipeline.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)


# test4: SDXL(text2img) with ip_adapter + pag
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")

auto_pipeline.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.set_ip_adapter_scale(0.6)

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    ip_adapter_image=ip_adapter_image,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test  4_out_text2img_ip_adapter_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")

auto_pipeline.unload_ip_adapter()
clear_memory()

# test5: SDXL(text2img) with controlnet

if not test_pag:
    auto_pipeline.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's check out the info for the controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_text2img_control.png")

clear_memory()

# test6: SDXL(img2img)

print(f" ")
print(f" running test6: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)

# let's check out the auto_pipeline info for the img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.get_execution_blocks("image"))
print(" ")

images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img.png")

clear_memory()


# test7: SDXL(img2img) with controlnet
# let's check out the auto_pipeline info for the img2img controlnet use case
print(f" auto_pipeline info (img2img controlnet use case)")
print(auto_pipeline.get_execution_blocks("image", "control_image"))
print(" ")

print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_control.png")

clear_memory()

# test8: img2img with refiner

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
# let's check out the refiner_pipeline
print(f" refiner_pipeline info")
print(refiner_pipeline)
print(f" ")

print(f" refiner_pipeline: triggered by `image_latents`")
print(refiner_pipeline.get_execution_blocks("image_latents"))
print(" ")

print(f" running test8: img2img with refiner")


generator = torch.Generator(device="cuda").manual_seed(0)
latents = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = refiner_pipeline(
    image_latents=latents,  
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_img2img_refiner.png")

clear_memory()

# test9: SDXL(inpainting)
# let's check out the auto_pipeline info for the inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting.png")

clear_memory()

# test10: SDXL(inpainting) with controlnet
# let's check out the auto_pipeline info for the inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "control_image"))
print(" ")

print(f" ") 
print(f" running test10: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    image=init_image,
    height=1024,
    width=1024,
    mask_image=inpaint_mask, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_control.png")

clear_memory()

# test11: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet")

auto_pipeline.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")

clear_memory()


# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    padding_mask_crop=33, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test13: apg

print(f" ")
print(f" running test13: apg")

apg_guider = APGGuider()
auto_pipeline.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
images_output = auto_pipeline(
  prompt=prompt, 
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_apg.png")

clear_memory()


# test14: SDXL(text2img) with controlnet_union

auto_pipeline.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), vae=components.get("vae_fix"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet union use case
# let's check out the info for the controlnet union use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_text2img_control_union.png")

clear_memory()


# test15: SDXL(img2img) with controlnet_union

print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.get_execution_blocks("image", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    generator=generator, 
    control_mode=[3], 
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt, 
    height=1024, 
    width=1024, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_img2img_control_union.png")

clear_memory()

# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    mask_image=inpaint_mask, 
    control_image=controlnet_union_image,
    control_mode=[3],
    height=1024, 
    width=1024, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test16_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

Modular Setup

StableDiffusionXLAutoPipeline is a very convenient preset; just like a LEGO set, you can break it down, reassemble, and rearrange the pipeline blocks however you want. A more modular setup would look like this:

# AUTO_BLOCKS is a map of all the blocks we used to assemble `StableDiffusionXLAutoPipeline`
from diffusers import ModularPipeline
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    AUTO_BLOCKS,
    StableDiffusionXLLoraStep,
    StableDiffusionXLIPAdapterStep,
)


# step1: create separate nodes to encode text/image/ip-adapter inputs
text_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("text_encoder")()) 
image_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("image_encoder")()) 
decoder_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("decode")()) 

# make a node for "denoising", here we just use the leftover blocks
class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(AUTO_BLOCKS.values())
    block_names = list(AUTO_BLOCKS.keys())

sdxl_node = ModularPipeline.from_block(SDXLAutoBlocks())
# we can also use the same block to make a refiner node, but you need to load a different unet/config into it later
refiner_node = ModularPipeline.from_block(SDXLAutoBlocks())

# lora_node for LoRA-related things
lora_node = ModularPipeline.from_block(StableDiffusionXLLoraStep())
# ip_adapter_node for IP-Adapter-related things
ip_adapter_node = ModularPipeline.from_block(StableDiffusionXLIPAdapterStep())

# step2: load models into the nodes (sdxl_node and refiner_node are made from the same block but need different components)
...
sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
...

# step3: generate embeddings so they can be reused
text_state = text_node(prompt=...)
image_state = image_node(image=...)
ip_adapter_state = ip_adapter_node(...)

# step4: reuse embeddings in different workflows, change call parameters, or take the latents to use in a different workflow before decoding
latents_img2img = sdxl_node(**text_state.intermediates, **image_state.intermediates, output="latents")
latents_text2img_28steps = sdxl_node(**text_state.intermediates, num_inference_steps=28, ..., output="latents")
latents_text2img_ipa = sdxl_node(**text_state.intermediates, **ip_adapter_state.intermediates, ..., output="latents")
latents_refined = refiner_node(**text_state.intermediates, image_latents=latents_xx, output="latents")
...

# step5: decode once the latents are ready
images = decoder_node(latents=latents_refined, output="images").images
images[0]

With this setup, you precompute embeddings and reuse them across different denoise backends, with different inference parameters such as guidance_scale or num_inference_steps, or with different schedulers. You can modify your workflow by simply adding/removing/swapping blocks, without recomputing the entire pipeline over and over again.
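Here is a rough sketch of that reuse, assuming the `text_node`, `sdxl_node`, and `decoder_node` from the snippet above are already set up and loaded with components:

# encode the prompt once
text_state = text_node(prompt="an astronaut riding a horse")

# reuse the same embeddings with different inference parameters
latents_20_steps = sdxl_node(**text_state.intermediates, num_inference_steps=20, output="latents")
latents_50_steps = sdxl_node(**text_state.intermediates, num_inference_steps=50, guidance_scale=7.0, output="latents")

# decode only the latents you want to keep
images = decoder_node(latents=latents_50_steps, output="images").images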

Check out the full example script here:

test2: modular setup. This is the full testing script I used for more configurations, including inpainting/refiner/union controlnet/APG
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS, StableDiffusionXLLoraStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1) Define inputs
# for text2img/img2img
prompt = "a photo of an astronaut riding a horse on mars"

# for img2img
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99


# (2) define blocks and nodes (builders)

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()

# lora step
lora_step = StableDiffusionXLLoraStep()


image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for refiner because it takes image_latents (from another pipeline) as input
image_block = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)
lora_node = ModularPipeline.from_block(lora_step)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there's no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

lora_node.update_states(**components.get(["unet", "text_encoder", "text_encoder_2"]))

# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")



# using sdxl_node to generate images

# to get info about sdxl_node and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
# so the information might not be super useful for your specific use case; you will find a "trigger inputs" section that says this:

# Trigger Inputs: {'control_mode', 'control_image', 'image_latents', 'mask'}
#  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_mode')`).
# Check `.doc` of returned object for more information.

print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want the text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
lora_node.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    lora_node.unload_lora_weights()

sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

if not test_pag:
    sdxl_node.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's check out the info for the controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for the img2img use case, we encode the image with image_node first; this way we can use the same image_latents for different workflows
# let's check out the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.get_execution_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.get_execution_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.get_execution_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.get_execution_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    num_images_per_prompt=num_images_per_prompt,
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents= sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet use case
# let's check out the info for the controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.get_execution_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)
test3: modular setup with IPAdapter
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS, StableDiffusionXLLoraStep, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs_0121_ipa"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes(builder)   

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()

# lora step
lora_step = StableDiffusionXLLoraStep()

# ip adapter step
ip_adapter_step = StableDiffusionXLIPAdapterStep()


image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for refiner because it takes image_latents (from another pipeline) as input
image_block = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)
lora_node = ModularPipeline.from_block(lora_step)
ip_adapter_node = ModularPipeline.from_block(ip_adapter_step)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)

# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

lora_node.update_states(**components.get(["unet", "text_encoder", "text_encoder_2"]))
ip_adapter_node.update_states(**components.get(["unet", "image_encoder", "feature_extractor"]))

# (4) enable auto cpu offload: automatically offload models when available gpu memory go below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt, negative_prompt=negative_prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")


# use ip adapter to get image embeddings
print(f" ")
print(f" ip_adapter_node:")
print(ip_adapter_node)
print(f" ")
print(f" generating ip adapter image embeddings with ip_adapter_node")
ip_adapter_node.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
ip_adapter_node.set_ip_adapter_scale(0.6)
ip_adapter_state = ip_adapter_node(ip_adapter_image=ip_adapter_image)
print(f" ")
print(f" ip_adapter_state info")
print(ip_adapter_state)
print(" ")


# using sdxl_node to generate images
print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want the text2img use case, we can run the following to see the components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
lora_node.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    lora_node.unload_lora_weights()

sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

if not test_pag:
    sdxl_node.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's check out the info for the controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for the img2img use case, we encode the image with image_node first; this way we can use the same image_latents for different workflows
# let's check out the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.get_execution_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.get_execution_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.get_execution_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.get_execution_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    num_images_per_prompt=num_images_per_prompt,
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents= sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  **ip_adapter_state.intermediates,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet use case
# let's check out the info for the controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.get_execution_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

Developer Guide: Building with Modular Diffusers

Core Components Overview

The Modular Diffusers architecture consists of four main components:

ModularPipeline

The main interface for creating and running modular pipelines. Unlike traditional pipelines, you don't write it from scratch - it builds itself from pipeline blocks! Example usage:

from diffusers import ModularPipeline
pipe = ModularPipeline.from_block(auto_pipeline_block)
images = pipe(prompt="a cat", num_inference_steps=15, output="images")

PipelineBlock

The fundamental building block, similar to a mellon/comfy node. Each block:

  • Defines required components, inputs, and outputs
  • Implements __call__(pipeline, state) -> (pipeline, state)
  • Can be reused across different pipelines
  • Can be combined with other blocks
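To make this concrete, here is a minimal sketch of a custom block. The block name (RescaleLatentsStep) and its latents_scale input are made up for illustration; the structure simply mirrors the diff-diff blocks shown later in this guide.

import torch
from typing import List

from diffusers.pipelines.modular_pipeline import PipelineBlock, PipelineState, InputParam, OutputParam

class RescaleLatentsStep(PipelineBlock):
    # hypothetical block: multiplies incoming latents by a user-provided factor
    expected_components = []
    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("latents_scale", default=1.0, type_hint=float, description="Factor to rescale the latents by")]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [InputParam("latents", required=True, type_hint=torch.Tensor, description="Latents produced by a previous block")]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The rescaled latents")]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)    # read this block's view of the pipeline state
        data.latents = data.latents * data.latents_scale
        self.add_block_state(state, data)     # write the updated values back
        return pipeline, state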

MultiPipelineBlocks

Combines multiple blocks into a bigger one! These combined blocks behave just like single blocks - with their own inputs, outputs, and components, but they are able to handle more complex workflows!

We have two types of MultiPipelineBlocks available; you can use them to combine individual blocks into ready-to-use sets (like LEGO® presets!)

  1. SequentialPipelineBlocks

    • Chains blocks in sequential order
    class StableDiffusionXLMainSteps(SequentialPipelineBlocks):
        block_classes = [InputStep, SetTimestepsStep, ...]
        block_names = ["input", "set_timesteps", ...]
  2. AutoPipelineBlocks

    • Provides conditional block selection. AutoPipelineBlocks makes the complex if...else... logic in your code disappear! With this, you can write blocks for specific use cases to keep each code path clean, and then use AutoPipelineBlocks to combine them into convenient presets that provide a better user experience :)
    • In this example, the ControlNetDenoiseStep will be dispatched when the user passes "control_image"; otherwise, it will run the default DenoiseStep
    class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
        block_classes = [ ControlNetDenoiseStep, DenoiseStep]
        block_names = [ "controlnet", "unet"]
        block_trigger_inputs = ["control_image", None]
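Once assembled, these presets behave like any other block: you build a ModularPipeline from them, and for an AutoPipelineBlocks preset you can inspect which sub-blocks would actually run for a given set of inputs. A small sketch, assuming an assembled auto preset such as the sdxl_auto_blocks used in the scripts above:

from diffusers import ModularPipeline

# sdxl_auto_blocks is assumed to be an assembled AutoPipelineBlocks-based preset (e.g. the SDXLAutoBlocks from the scripts above)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)

print(sdxl_node.get_execution_blocks())                 # default path (text2img)
print(sdxl_node.get_execution_blocks("control_image"))  # blocks selected when `control_image` is passed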

PipelineState and BlockStates

PipelineState and BlockStates manage the dataflow between and inside blocks; they make debugging really easy! Feel free to print them out at any time to get an overview of all the shapes/types/values of your pipeline/block states.
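For example (a sketch assuming the text_node built in the scripts above), you can run a block, print the returned PipelineState, and pull out individual intermediates:

text_state = text_node(prompt="a bear sitting in a chair drinking a milkshake")
print(text_state)                                              # overview of all inputs/intermediates: names, shapes, dtypes, values
prompt_embeds = text_state.get_intermediate("prompt_embeds")   # fetch a single intermediate by name
# pass everything downstream to the next node with **text_state.intermediates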

Differential Diffusion Example

Here we'll show you a new way to build with Modular Diffusers. Let's look at implementing a Differential Diffusion pipeline (https://differential-diffusion.github.io/) as an example. It is, in a sense, an image-to-image workflow, so we can start with the preset of pipeline blocks we used to build our current img2img pipeline (IMAGE2IMAGE_BLOCKS) and see how we can build this new pipeline from them!

IMAGE2IMAGE_BLOCKS = OrderedDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
    ("image_encoder", StableDiffusionXLVaeEncoderStep),
    ("input", StableDiffusionXLInputStep),
    ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLDenoiseStep),
    ("decode", StableDiffusionXLDecodeStep)
])

It looks like we can reuse the "text_encoder", "ip_adapter", "image_encoder", "input", "prepare_add_cond" and "decode" steps from the img2img workflow out of the box. The "set_timesteps" step in Differential Diffusion is the same as the one we use for text-to-image (i.e. it does not take a strength parameter), so we can just use StableDiffusionXLSetTimestepsStep. Differential Diffusion uses a different denoising method, so we will need to write a new "denoise" step, and the "prepare_latents" step is also a little different, so we will write a new one too.

Here are the changes needed to create the Differential Diffusion version of these blocks:

  1. Modified StableDiffusionXLImg2ImgPrepareLatentsStep:
  class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
      expected_components = ["vae", "scheduler"]
      model_name = "stable-diffusion-xl"
  
      @property
      def description(self) -> str:
          return (
-             "Step that prepares the latents for the image-to-image generation process"
+             "Step that prepares the latents for the differential diffusion generation process"
          )
  
      @property
      def intermediates_inputs(self) -> List[InputParam]:
          return [
-             InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation."),
+             InputParam("timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."),
              InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation."),
              InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt."),
              InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")]
  
      def __call__(self, pipeline, state: PipelineState) -> PipelineState:
          data = self.get_block_state(state)
          data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype
          data.device = pipeline._execution_device
          data.add_noise = True if data.denoising_start is None else False
+         pipeline.scheduler.set_begin_index(None)
          if data.latents is None:
              data.latents = pipeline.prepare_latents_img2img(
                  data.image_latents,
-                 data.latent_timestep,
+                 data.timesteps,
                  data.batch_size,
                  data.num_images_per_prompt,
                  data.dtype,
                  data.device,
                  data.generator,
                  data.add_noise,
              )
  2. Modified StableDiffusionXLDenoiseStep: we removed the inpaint-related logic and added diff-diff-specific logic
  class SDXLDiffDiffDenoiseStep(PipelineBlock):
      expected_components = ["unet", "scheduler", "guider"]
      model_name = "stable-diffusion-xl"
  
      @property
      def description(self) -> str:
          return (
-             "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process"
+             "Step that iteratively denoise the latents for the image generation process using differential diffusion"
          )

      @property
      def inputs(self) -> List[Tuple[str, Any]]:
          return [
              # ... common parameters ...
+             InputParam("diffdiff_map", required=True),
+             InputParam("denoising_start"),
          ]

      def __init__(self):
          super().__init__()
          self.components["guider"] = CFGGuider()
          self.components["scheduler"] = None
          self.components["unet"] = None
+         self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_convert_grayscale=True)

      @torch.no_grad()
      def __call__(self, pipeline, state: PipelineState) -> PipelineState:
          # ... setup code ...

+         # preparations for diff diff
+         data.latent_height = data.image_latents.shape[-2]
+         data.latent_width = data.image_latents.shape[-1]
+         data.diffdiff_map = pipeline.mask_processor.preprocess(data.diffdiff_map, height=data.latent_height, width=data.latent_width)
+         
+         data.diffdiff_map = data.diffdiff_map.squeeze(0).to(data.device)
+         thresholds = torch.arange(data.num_inference_steps, dtype=data.diffdiff_map.dtype) / data.num_inference_steps
+         data.thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(data.device)
+         data.masks = data.diffdiff_map > (data.thresholds + (data.denoising_start or 0))
+
+         data.original_with_noise = data.latents

          with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
              for i, t in enumerate(data.timesteps):
+                 # diff diff
+                 if i == 0 and data.denoising_start is None:
+                     data.latents = data.original_with_noise[:1]
+                 else:
+                     data.mask = data.masks[i].unsqueeze(0)
+                     data.mask = data.mask.to(data.latents.dtype)
+                     data.mask = data.mask.unsqueeze(1)  # fit shape
+                     data.latents = data.original_with_noise[i] * data.mask + data.latents * (1 - data.mask)

                  # ... rest of denoising loop ...
-                 if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
-                     data.init_latents_proper = data.image_latents
-                     if i < len(data.timesteps) - 1:
-                         data.noise_timestep = data.timesteps[i + 1]
-                         data.init_latents_proper = pipeline.scheduler.add_noise(
-                             data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
-                         )
-                     data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents

That's all there is to it! Once you've made these two diff-diff blocks, you can create a preset (a pre-assembled set of blocks) and then build your pipeline from it.

# create diff-diff preset
DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]

class DiffDiffBlocks(SequentialPipelineBlocks):
    block_classes = list(DIFFDIFF_BLOCKS.values())
    block_names = list(DIFFDIFF_BLOCKS.keys())

# create diff-diff pipeline from preset
diffdiff_blocks = DiffDiffBlocks()
dd_node = ModularPipeline.from_block(diffdiff_blocks)

To use it:

dd_node.update_states(**components.components)

prompt = "a green pear"
negative_prompt = "blurry"

image = dd_node(
    prompt=prompt,
    negative_prompt=negative_prompt,
    diffdiff_map=mask,
    image=image,
    output="images"
).images[0]

diff-diff-out

Complete Example: Implementing Differential Diffusion Pipeline
from diffusers.pipelines.modular_pipeline import PipelineBlock, SequentialPipelineBlocks, PipelineState, InputParam, OutputParam
from diffusers.guider import CFGGuider
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import DPMSolverMultistepScheduler

import torch
from typing import List, Tuple, Any, Optional, Dict, Union

class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
    expected_components = ["vae", "scheduler"]
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step that prepares the latents for the differential diffusion generation process"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam(
                "generator", 
                type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], 
                description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) "
                           "to make generation deterministic."
            ),
            InputParam(
                "latents", 
                type_hint=Optional[torch.Tensor], 
                description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`."
            ),
            InputParam(
                "num_images_per_prompt", 
                default=1, 
                type_hint=int, 
                description="The number of images to generate per prompt"
            ),
            InputParam(
                "denoising_start", 
                type_hint=Optional[float], 
                description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups."
            ),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam("timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."), 
            InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), 
            InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), 
            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")]

    def __init__(self):
        super().__init__()
        self.components["scheduler"] = None

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)

        data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype
        data.device = pipeline._execution_device
        data.add_noise = True if data.denoising_start is None else False
        pipeline.scheduler.set_begin_index(None)
        if data.latents is None:
            data.latents = pipeline.prepare_latents_img2img(
                data.image_latents,
                data.timesteps,
                data.batch_size,
                data.num_images_per_prompt,
                data.dtype,
                data.device,
                data.generator,
                data.add_noise,
            )

        self.add_block_state(state, data)

        return pipeline, state


class SDXLDiffDiffDenoiseStep(PipelineBlock):
    expected_components = ["unet", "scheduler", "guider"]
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step that iteratively denoise the latents for the image generation process using differential diffusion"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam(
                "guidance_scale", 
                type_hint=float,
                default=5.0,
                description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1."
            ),
            InputParam(
                "guidance_rescale",
                type_hint=float,
                default=0.0,
                description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."
            ),
            InputParam(
                "cross_attention_kwargs",
                type_hint=Optional[Dict[str, Any]],
                default=None,
                description="Optional kwargs dictionary passed to the AttentionProcessor."
            ),
            InputParam(
                "generator",
                type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
                description="One or a list of torch generator(s) to make generation deterministic."
            ),
            InputParam(
                "eta",
                type_hint=float,
                default=0.0,
                description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others."
            ),
            InputParam(
                "guider_kwargs",
                type_hint=Optional[Dict[str, Any]],
                default=None,
                description="Optional kwargs dictionary passed to the Guider."
            ),
            InputParam(
                "num_images_per_prompt",
                type_hint=int,
                default=1,
                description="The number of images to generate per prompt."
            ),
            InputParam("diffdiff_map",required=True),
            InputParam("denoising_start"),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
            ),
            InputParam(
                "batch_size", 
                required=True, 
                type_hint=int, 
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
            ),
            InputParam(
                "timesteps", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
            ),
            InputParam(
                "num_inference_steps", 
                required=True, 
                type_hint=int, 
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
            ),
            InputParam(
                "pooled_prompt_embeds", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step."
            ),
            InputParam(
                "negative_pooled_prompt_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step.    "
            ),
            InputParam(
                "add_time_ids", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step."
            ),
            InputParam(
                "negative_add_time_ids", 
                type_hint=Optional[torch.Tensor], 
                description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step."
            ),
            InputParam(
                "prompt_embeds", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step."
            ),
            InputParam(
                "negative_prompt_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step.   "
            ),
            InputParam(
                "timestep_cond", 
                type_hint=Optional[torch.Tensor], 
                description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step."
            ),
            InputParam(
                "image_latents", 
                type_hint=Optional[torch.Tensor], 
                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
            ),
            InputParam(
                "ip_adapter_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
            ),
            InputParam(
                "negative_ip_adapter_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
            ),
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    def __init__(self):
        super().__init__()
        self.components["guider"] = CFGGuider()
        self.components["scheduler"] = None
        self.components["unet"] = None
        self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_convert_grayscale=True)
    
    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:

        data = self.get_block_state(state)

        data.num_channels_unet = pipeline.unet.config.in_channels
        data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
        data.device = pipeline._execution_device

        # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale
        data.guider_kwargs = data.guider_kwargs or {}
        data.guider_kwargs = {
            **data.guider_kwargs,
            "disable_guidance": data.disable_guidance,
            "guidance_scale": data.guidance_scale,
            "guidance_rescale": data.guidance_rescale,
            "batch_size": data.batch_size * data.num_images_per_prompt,
        }

        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
        # Prepare conditional inputs using the guider
        data.prompt_embeds = pipeline.guider.prepare_input(
            data.prompt_embeds,
            data.negative_prompt_embeds,
        )
        data.add_time_ids = pipeline.guider.prepare_input(
            data.add_time_ids,
            data.negative_add_time_ids,
        )
        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
            data.pooled_prompt_embeds,
            data.negative_pooled_prompt_embeds,
        )

        data.added_cond_kwargs = {
            "text_embeds": data.pooled_prompt_embeds,
            "time_ids": data.add_time_ids,
        }

        if data.ip_adapter_embeds is not None:
            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta)
        data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)

        # preparations for diff diff
        data.latent_height = data.image_latents.shape[-2]
        data.latent_width = data.image_latents.shape[-1]
        data.diffdiff_map = pipeline.mask_processor.preprocess(data.diffdiff_map, height=data.latent_height, width=data.latent_width)
        
        data.diffdiff_map = data.diffdiff_map.squeeze(0).to(data.device)
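        # build one threshold per inference step, evenly spaced in [0, 1); pixels whose change-map
        # value exceeds the threshold for step i are re-injected from the noised original at that step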
        thresholds = torch.arange(data.num_inference_steps, dtype=data.diffdiff_map.dtype) / data.num_inference_steps
        data.thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(data.device)
        data.masks = data.diffdiff_map > (data.thresholds + (data.denoising_start or 0))

        data.original_with_noise = data.latents

        with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
            for i, t in enumerate(data.timesteps):
    
                # diff diff
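                # at step 0 (when not resuming via denoising_start) start from the original latents
                # noised to the first timestep; afterwards, wherever the change map is still above this
                # step's threshold, replace the running latents with the original noised to timestep i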
                if i == 0 and data.denoising_start is None:
                    data.latents = data.original_with_noise[:1]
                else:
                    data.mask = data.masks[i].unsqueeze(0)
                    # cast mask to the same type as latents etc
                    data.mask = data.mask.to(data.latents.dtype)
                    data.mask = data.mask.unsqueeze(1)  # fit shape
                    data.latents = data.original_with_noise[i] * data.mask + data.latents * (1 - data.mask)
                # end diff diff
        
                # expand the latents if we are doing classifier free guidance
                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)

                # predict the noise residual
                data.noise_pred = pipeline.unet(
                    data.latent_model_input,
                    t,
                    encoder_hidden_states=data.prompt_embeds,
                    timestep_cond=data.timestep_cond,
                    cross_attention_kwargs=data.cross_attention_kwargs,
                    added_cond_kwargs=data.added_cond_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                data.noise_pred = pipeline.guider.apply_guidance(
                    data.noise_pred,
                    timestep=t,
                    latents=data.latents,
                )
                # compute the previous noisy sample x_t -> x_t-1
                data.latents_dtype = data.latents.dtype
                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
                if data.latents.dtype != data.latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        data.latents = data.latents.to(data.latents_dtype)

                if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        self.add_block_state(state, data)

        return pipeline, state



import torch

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import IMAGE2IMAGE_BLOCKS, TEXT2IMAGE_BLOCKS
from diffusers.pipelines.modular_pipeline import ModularPipeline, SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.utils import load_image


DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]


DIFFDIFF_CORE_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()


class DiffDiffBlocks(SequentialPipelineBlocks):
    block_classes = list(DIFFDIFF_BLOCKS.values())
    block_names = list(DIFFDIFF_BLOCKS.keys())


diffdiff_blocks = DiffDiffBlocks()
dd_node = ModularPipeline.from_block(diffdiff_blocks)

components = ComponentsManager()
components.add_from_pretrained("SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16")

components.enable_auto_cpu_offload()

dd_node.update_states(**components.components)

print(dd_node)


image = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true"
)

mask = load_image(
    "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true"
)

prompt = "a green pear"
negative_prompt = "blurry"

image = dd_node(
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    num_inference_steps=25,
    diffdiff_map=mask,
    image=image,
    output="images"
).images[0]

image.save("diffdiff_out.png")

Diffusers as seen in nodes

coming up soon....

Next Steps

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yoland68

Very cool!

@oozzy77

oozzy77 commented Oct 30, 2024

Hi, this is very interesting! I'm making a Python pipeline flow visual scripting tool that can auto-convert functions into visual nodes for fast, modular UI block demos. It's available as a pip package: https://pypi.org/project/nozyio/

I wanted to integrate diffusers with my flow nodes UI project but found it's not very modular. This PR may change that! Looking forward to seeing how this evolves.

github: https://github.com/oozzy77/nozyio happy to connect!

@yiyixuxu
Collaborator Author

@oozzy77 thanks!
Do you want to join a Slack channel with me? If you want to experiment building something with this PR, I'm eager to hear your feedback and iterate based on that.

@oozzy77

oozzy77 commented Oct 31, 2024 via email

@yiyixuxu
Collaborator Author

@oozzy77 I sent an invite!

@yiyixuxu yiyixuxu added the roadmap Add to current release roadmap label Dec 4, 2024
@hlky hlky mentioned this pull request Dec 5, 2024
@yiyixuxu
Collaborator Author

yiyixuxu commented Jan 12, 2025

Testing script to use with the latest commit (will keep this one up to date from now on) cc @hlky @asomoza.
We now have an auto workflow that supports any combination of text2img, img2img, inpaint, controlnet, controlnet-union, PAG, APG, and LoRA.

testing script for modular diffusers (most updated)
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS 

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16


# define output folder
out_folder = "modular_test_outputs_0110"
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for the current device.

    Args:
        message (str, optional): extra text appended to the header line.
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a photo of an astronaut riding a horse on mars"

# for img2img
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99


# (2) define blocks and nodes(builder)   

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()



image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for the refiner because it takes image_latents (from another pipeline) as input
_ = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"

components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)

# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)



# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")


# using sdxl_node to generate images

# to get info about sdxl_node and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
# so the information might not be super useful for your specific use case; you will find a "trigger inputs" section that says this:
#   Trigger Inputs:
#   --------------
#   This pipeline contains dynamic blocks that are selected at runtime based on your inputs.
#   • Trigger inputs: {'control_image', 'image_latents', 'mask'}
#   • Use .pipeline_block.get_triggered_blocks(*inputs) to see which blocks will be used for specific inputs
#   • Use .pipeline_block.get_triggered_blocks() to see which blocks will be used for default inputs (when no trigger inputs are provided)
print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.pipeline_block.get_triggered_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test1_out_text2img.png")
print(f" save modular output to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
sdxl_node.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
latents = sdxl_node(
    **text_state.intermediates, 
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test2_out_text2img_lora.png")
print(f" save modular output to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image without lora again, with pag
print(f" ")
print(f" running test3:text2image without lora again, with pag")
sdxl_node.unload_lora_weights()
sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test3_out_text2img_pag.png")
print(f" save modular output to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# check out the components if you want; the models used are moved to device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

# we are going to pass a new input now `control_image` so the workflow will be automatically converted to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs={"pag_scale": 3.0}, 
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test4_out_text2img_control.png")
print(f" save modular output to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for img2img use case, we encode the image with image_node first, this way we can use the same image_latents for different workflows
# let's checkout the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    strength=strength, 
    guider_kwargs={"pag_scale": 3.0}, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test5_out_img2img.png")
print(f" save modular output to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs={"pag_scale": 3.0}, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test6_out_img2img_control.png")
print(f" save modular output to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    output="images"
)
images_output.images[0].save(f"{out_folder}/test7_out_img2img_refiner.png")
print(f" save modular output to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs={"pag_scale": 3.0}, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test8_out_inpainting.png")
print(f" save modular output to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs={"pag_scale": 3.0}, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test9_out_inpainting_control.png")
print(f" save modular output to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs={"pag_scale": 3.0}, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test10_out_inpainting_inpaint_unet.png")
print(f" save modular output to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs={"pag_scale": 3.0}, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
images_output.images[0].save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")
print(f" save modular output to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents= sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test12_out_apg.png")
print(f" save modular output to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to the controlnet union use case
# let's checkout the info for the controlnet union use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test13_out_text2img_control_union.png")
print(f" save modular output to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test14_out_img2img_control_union.png")
print(f" save modular output to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.pipeline_block.get_triggered_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
images_output.images[0].save(f"{out_folder}/test15_out_inpainting_control_union.png")
print(f" save modular output to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

@yiyixuxu yiyixuxu changed the title [WIP] The Modular Diffusers The Modular Diffusers Feb 4, 2025