fix: Assertion issue with SDXL Compel #5827

Merged (2 commits) on Feb 29, 2024
Changes from all commits
26 changes: 8 additions & 18 deletions invokeai/app/invocations/compel.py
@@ -1,17 +1,11 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTextModel, CLIPTokenizer
-
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -25,12 +19,7 @@
 )
 from invokeai.backend.util.devices import torch_dtype
 
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
 # unconditioned: Optional[torch.Tensor]
@@ -149,7 +138,7 @@ def run_clip_compel(
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
         text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,7 +185,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
-            assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
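Note on the compel.py change: SDXL's second text encoder loads as a transformers CLIPTextModelWithProjection rather than a plain CLIPTextModel, which is why the old assert tripped when encoding SDXL prompts. The widened isinstance check accepts both classes, and the cast() only narrows the static type to satisfy Compel's CLIPTextModel annotation; it does nothing at runtime. A minimal standalone sketch of the pattern (the helper name is illustrative, not part of the codebase):

from typing import Union, cast

from transformers import CLIPTextModel, CLIPTextModelWithProjection


def ensure_clip_text_encoder(
    text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
) -> CLIPTextModel:
    # Accept either CLIP text encoder variant at runtime ...
    assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
    # ... then narrow the static type for callers annotated with CLIPTextModel.
    # cast() is a typing no-op; it does not convert or copy the model.
    return cast(CLIPTextModel, text_encoder)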
10 changes: 6 additions & 4 deletions invokeai/backend/model_manager/search.py
@@ -118,7 +118,7 @@ class ModelSearch(ModelSearchBase):
     """
 
     models_found: Set[Path] = Field(default_factory=set)
-    config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
+    config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
 
     def search_started(self) -> None:
         self.models_found = set()
@@ -147,9 +147,11 @@ def search(self, directory: Union[Path, str]) -> Set[Path]:
 
     def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
         absolute_path = Path(path)
-        if len(absolute_path.parts) - len(self._directory.parts) > max_depth \
-            or not absolute_path.exists() \
-            or absolute_path.parent in self.models_found:
+        if (
+            len(absolute_path.parts) - len(self._directory.parts) > max_depth
+            or not absolute_path.exists()
+            or absolute_path.parent in self.models_found
+        ):
             return
         entries = os.scandir(absolute_path.as_posix())
         entries = [entry for entry in entries if not entry.name.startswith(".")]
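Note on the search.py change: this is largely behavior-preserving cleanup. The backslash-continued condition becomes a parenthesized multi-line expression, and the config line change appears whitespace-only in this view. A standalone sketch of the same guard, with the instance attributes replaced by hypothetical parameters:

import os
from pathlib import Path
from typing import Set, Union


def should_skip(path: Union[Path, str], root: Path, models_found: Set[Path], max_depth: int = 20) -> bool:
    # Skip paths that are too deep, no longer exist, or sit under a directory
    # already recorded as a model: the same three conditions as _walk_directory.
    absolute_path = Path(path)
    return (
        len(absolute_path.parts) - len(root.parts) > max_depth
        or not absolute_path.exists()
        or absolute_path.parent in models_found
    )


def visible_entries(path: Path) -> list:
    # Mirror of the dotfile filter in the diff: ignore hidden entries like ".git".
    return [entry for entry in os.scandir(path) if not entry.name.startswith(".")]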
8 changes: 4 additions & 4 deletions invokeai/backend/model_patcher.py
@@ -4,12 +4,12 @@
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from diffusers import OnnxRuntimeModel, UNet2DConditionModel
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager import AnyModel
@@ -168,7 +168,7 @@ def apply_lora(
     def apply_ti(
         cls,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
         init_tokens_count = None
@@ -265,7 +265,7 @@ def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw):
     @contextmanager
     def apply_clip_skip(
         cls,
-        text_encoder: CLIPTextModel,
+        text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         clip_skip: int,
     ) -> None:
         skipped_layers = []
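Note on the model_patcher.py change: apply_ti and apply_clip_skip now accept either encoder class. That works because both CLIPTextModel and CLIPTextModelWithProjection wrap the same CLIPTextTransformer under .text_model and share the token-embedding accessors that textual-inversion logic typically uses. An illustrative check (not InvokeAI code) that exercises those shared attributes on either variant:

from typing import Union

from transformers import CLIPTextModel, CLIPTextModelWithProjection


def describe_text_encoder(text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection]) -> str:
    # Layers a CLIP-skip patch would trim live at the same path on both classes.
    n_layers = len(text_encoder.text_model.encoder.layers)
    # Embedding rows a textual-inversion patch would extend are likewise shared.
    n_tokens = text_encoder.get_input_embeddings().num_embeddings
    return f"{type(text_encoder).__name__}: {n_layers} layers, {n_tokens} token embeddings"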