From f9aaa73ff8c8ce35fb4b2906ce60a6c8a17c1a7f Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Thu, 29 Feb 2024 11:32:28 +0530
Subject: [PATCH 1/2] fix: Assertion issue with SDXL Compel

---
 invokeai/app/invocations/compel.py | 26 ++++++++------------------
 invokeai/backend/model_patcher.py  |  8 ++++----
 2 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 771c811eea0..ff136580523 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
-
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@
 )
 from invokeai.backend.util.devices import torch_dtype
 
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ def run_clip_compel(
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,7 +185,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index bee8909c311..473a0883085 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -4,12 +4,12 @@
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel, UNet2DConditionModel
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager import AnyModel
@@ -168,7 +168,7 @@ def apply_lora(
     def apply_ti(
         cls,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
         init_tokens_count = None
@@ -265,7 +265,7 @@ def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionMod
     @contextmanager
     def apply_clip_skip(
         cls,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         clip_skip: int,
     ) -> None:
         skipped_layers = []

From ea35ec413c4cfb1a9e14f69c7f809fa06c1867b1 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 29 Feb 2024 20:05:11 +1100
Subject: [PATCH 2/2] chore: ruff

---
 invokeai/backend/model_manager/search.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py
index b1ee0c22de9..7e89c394b1c 100644
--- a/invokeai/backend/model_manager/search.py
+++ b/invokeai/backend/model_manager/search.py
@@ -118,7 +118,7 @@ class ModelSearch(ModelSearchBase):
     """
 
     models_found: Set[Path] = Field(default_factory=set)
-    config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
+    config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
 
     def search_started(self) -> None:
         self.models_found = set()
@@ -147,9 +147,11 @@ def search(self, directory: Union[Path, str]) -> Set[Path]:
 
     def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
         absolute_path = Path(path)
-        if len(absolute_path.parts) - len(self._directory.parts) > max_depth \
-            or not absolute_path.exists() \
-            or absolute_path.parent in self.models_found:
+        if (
+            len(absolute_path.parts) - len(self._directory.parts) > max_depth
+            or not absolute_path.exists()
+            or absolute_path.parent in self.models_found
+        ):
             return
         entries = os.scandir(absolute_path.as_posix())
         entries = [entry for entry in entries if not entry.name.startswith(".")]