Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable Diffusion image-to-image and inpainting using ONNX. #552

Merged
merged 7 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions scripts/convert_stable_diffusion_checkpoint_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
},
opset=opset,
)
del pipeline.text_encoder

# UNET
unet_path = output_path / "unet" / "model.onnx"
Expand Down Expand Up @@ -125,6 +126,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
location="weights.pb",
convert_attribute=False,
)
del pipeline.unet

# VAE ENCODER
vae_encoder = pipeline.vae
Expand Down Expand Up @@ -157,6 +159,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
},
opset=opset,
)
del pipeline.vae

# SAFETY CHECKER
safety_checker = pipeline.safety_checker
Expand All @@ -173,8 +176,10 @@ def convert_models(model_path: str, output_path: str, opset: int):
},
opset=opset,
)
del pipeline.safety_checker

onnx_pipeline = StableDiffusionOnnxPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
Expand All @@ -187,6 +192,8 @@ def convert_models(model_path: str, output_path: str, opset: int):
onnx_pipeline.save_pretrained(output_path)
print("ONNX pipeline saved to", output_path)

del pipeline
del onnx_pipeline
_ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
print("ONNX pipeline is loadable")

Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@
from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import StableDiffusionOnnxPipeline
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
StableDiffusionOnnxPipeline,
)
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
)

if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
StableDiffusionOnnxPipeline,
)

if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .safety_checker import StableDiffusionSafetyChecker

if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion import StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline

if is_transformers_available() and is_flax_available():
import flax
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):

def __init__(
self,
vae_encoder: OnnxRuntimeModel,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
Expand All @@ -36,6 +37,7 @@ def __init__(
):
super().__init__()
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
LDMPipeline,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
Expand Down Expand Up @@ -2025,6 +2027,72 @@ def test_stable_diffusion_onnx(self):
expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@slow
def test_stable_diffusion_img2img_onnx(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))

pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)

prompt = "A fantasy landscape, trending on artstation"

np.random.seed(0)
output = pipe(
prompt=prompt,
init_image=init_image,
strength=0.75,
guidance_scale=7.5,
num_inference_steps=8,
output_type="np",
)
images = output.images
image_slice = images[0, 255:258, 383:386, -1]

assert images.shape == (1, 512, 768, 3)
expected_slice = np.array([[0.4806, 0.5125, 0.5453, 0.4846, 0.4984, 0.4955, 0.4830, 0.4962, 0.4969]])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@slow
def test_stable_diffusion_inpaint_onnx(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)

prompt = "A red cat sitting on a park bench"

np.random.seed(0)
output = pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
num_inference_steps=8,
output_type="np",
)
images = output.images
image_slice = images[0, 255:258, 255:258, -1]

assert images.shape == (1, 512, 512, 3)
expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_intermediate_state(self):
Expand Down