Reduce the likelihood of encountering #7513 by eliminating places where the door was left open for this to happen.
RyanJDick committed Jan 3, 2025
1 parent d1da699 commit caab87b
Showing 7 changed files with 36 additions and 63 deletions.
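Every file below applies the same fix: rather than calling context.models.load(...) well before the model is needed and entering the returned handle much later, the load now happens directly inside the with statement (or exit_stack.enter_context call) that locks the model on-device. The following is a minimal, self-contained sketch of why that gap matters; ToyLoadedModel and load() are toy stand-ins rather than InvokeAI APIs, and only the shape of the handle-as-context-manager pattern is mirrored from the diff.

    # Toy sketch (not InvokeAI code): a long gap between load() and entering
    # the handle leaves the door open for a cache eviction (#7513).
    class ToyLoadedModel:
        """Stand-in for the handle returned by context.models.load()."""

        def __init__(self, name: str):
            self.name = name

        def __enter__(self) -> str:
            # In the real system, entering the handle is what locks the model
            # into VRAM; if the cache dropped the weights after load() but
            # before this point, they must be fetched again.
            return f"{self.name} (on device)"

        def __exit__(self, *exc) -> None:
            pass


    def load(name: str) -> ToyLoadedModel:
        return ToyLoadedModel(name)


    def before() -> None:
        handle = load("text_encoder")   # handle acquired early...
        # ...unrelated work here can evict the cached weights...
        with handle as model:           # ...and only now is the model locked.
            print(model)


    def after() -> None:
        # Load and enter in one step, minimizing the eviction window.
        with load("text_encoder") as model:
            print(model)


    before()
    after()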
11 changes: 3 additions & 8 deletions invokeai/app/invocations/compel.py
@@ -63,9 +63,6 @@ class CompelInvocation(BaseInvocation):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)

def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
@@ -76,12 +73,13 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:

# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

text_encoder_info = context.models.load(self.clip.text_encoder)
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)

with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
context.models.load(self.clip.tokenizer) as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
@@ -140,9 +138,7 @@ def run_clip_compel(
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)

# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.model
@@ -180,7 +176,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
context.models.load(clip_field.tokenizer) as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
@@ -226,7 +222,6 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:

del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info

c = c.detach().to("cpu")
14 changes: 4 additions & 10 deletions invokeai/app/invocations/denoise_latents.py
@@ -547,7 +547,6 @@ def prep_ip_adapter_image_prompts(
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -556,7 +555,7 @@
single_ipa_images = [
context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
]
with image_encoder_model_info as image_encoder_model:
with context.models.load(single_ip_adapter.image_encoder_model) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
@@ -621,7 +620,6 @@ def run_t2i_adapters(
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")

# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -637,7 +635,7 @@
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")

t2i_adapter_model: T2IAdapter
with t2i_adapter_loaded_model as t2i_adapter_model:
with context.models.load(t2i_adapter_field.t2i_adapter_model) as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor

# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
@@ -926,10 +924,8 @@ def step_callback(state: PipelineIntermediateState) -> None:
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
@@ -995,11 +991,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
del lora_info
return

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (cached_weights, unet),
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
27 changes: 15 additions & 12 deletions invokeai/app/invocations/flux_denoise.py
@@ -199,8 +199,8 @@ def _run_diffusion(
else None
)

transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")

# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -299,9 +299,11 @@ def _run_diffusion(
)

# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
(cached_weights, transformer) = exit_stack.enter_context(
context.models.load(self.transformer.transformer).model_on_device()
)
assert isinstance(transformer, Flux)
config = transformer_info.config
config = transformer_config
assert config is not None

# Determine if the model is quantized.
@@ -512,15 +514,18 @@ def _prep_controlnet_extensions(
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.

# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]

# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
for controlnet in controlnets:
image = context.images.get_pil(controlnet.image.image_name)

# HACK(ryand): We have to load the ControlNet model to determine whether the VAE needs to be run. We really
# shouldn't have to load the model here. There's a risk that the model will be dropped from the model cache
# before we load it into VRAM and thus we'll have to load it again (context:
# https://github.com/invoke-ai/InvokeAI/issues/7513).
controlnet_model = context.models.load(controlnet.control_model)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -550,10 +555,8 @@

# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
for controlnet, controlnet_cond in zip(controlnets, controlnet_conds, strict=True):
model = exit_stack.enter_context(context.models.load(controlnet.control_model))

if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
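The ControlNet handling above now computes the VAE-dependent conditioning tensors first and only afterwards enters each ControlNet through exit_stack.enter_context, so the models stay locked on-device for the remainder of the denoise without raising peak memory earlier (at the cost of possibly loading each model twice, as the HACK comment notes). A rough sketch of that ExitStack idiom, using the same kind of toy stand-ins as the sketch above rather than InvokeAI's real classes:

    from contextlib import ExitStack


    class ToyLoadedModel:
        def __init__(self, name: str):
            self.name = name

        def __enter__(self) -> str:
            print(f"{self.name}: locked on device")
            return self.name

        def __exit__(self, *exc) -> None:
            print(f"{self.name}: released")


    def load(name: str) -> ToyLoadedModel:
        return ToyLoadedModel(name)


    with ExitStack() as stack:
        # Expensive preprocessing (VAE encoding in the real code) runs first,
        # before any ControlNet is locked on-device, keeping peak memory down.
        conds = [f"cond_{i}" for i in range(2)]

        # Each model is then loaded and entered in one step, mirroring
        # exit_stack.enter_context(context.models.load(...)); it stays locked
        # until the stack unwinds at the end of the block.
        models = [stack.enter_context(load(f"controlnet_{i}")) for i in range(2)]
        print(list(zip(conds, models)))
    # Both ControlNets are released here.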
19 changes: 7 additions & 12 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -69,14 +69,11 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
)

def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

prompt = [self.prompt]

with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
@@ -90,22 +87,20 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
return prompt_embeds

def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)

prompt = [self.prompt]

clip_text_encoder_info = context.models.load(self.clip.text_encoder)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None

with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
context.models.load(self.clip.tokenizer) as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)

clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None

# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
13 changes: 4 additions & 9 deletions invokeai/app/invocations/sd3_text_encoder.py
@@ -87,14 +87,11 @@ def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:

def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

prompt = [self.prompt]

with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
@@ -129,14 +126,12 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)

prompt = [self.prompt]

clip_text_encoder_info = context.models.load(clip_model.text_encoder)
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
context.models.load(clip_model.tokenizer) as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
10 changes: 2 additions & 8 deletions invokeai/app/invocations/spandrel_image_to_image.py
@@ -157,17 +157,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")

# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)

def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing tile {step}/{total_steps}",
percentage=step / total_steps,
)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
with context.models.load(self.image_to_image_model) as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

# Upscale the image
@@ -206,9 +203,6 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")

# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)

# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
@@ -221,7 +215,7 @@ def step_callback(iteration: int, step: int, total_steps: int) -> None:
)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
with context.models.load(self.image_to_image_model) as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

iteration = 1
@@ -201,12 +201,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
yield (lora_info.model, lora.weight)
del lora_info

# Load the UNet model.
unet_info = context.models.load(self.unet.unet)

with (
ExitStack() as exit_stack,
unet_info as unet,
context.models.load(self.unet.unet) as unet,
LayerPatcher.apply_smart_model_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
