Skip to content

Commit

Permalink
Merge pull request #1402 from d8ahazard/dev
Browse files Browse the repository at this point in the history
Fix broken diffusers import
  • Loading branch information
d8ahazard authored Dec 6, 2023
2 parents c548ede + e07dea9 commit d4c534f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
35 changes: 33 additions & 2 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
DEISMultistepScheduler,
UniPCMultistepScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline
)
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor2_0, LoRAAttnProcessor
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import logging as dl
from diffusers.utils.torch_utils import randn_tensor
Expand Down Expand Up @@ -102,6 +102,37 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
    """Exit the context manager by delegating to the wrapped stack.

    The exception triple is forwarded unchanged to ``self.stack.__exit__``
    (presumably a ``contextlib.ExitStack`` — confirm against ``__enter__``).
    """
    # NOTE(review): the inner __exit__'s return value is discarded, so this
    # wrapper never suppresses exceptions even if the wrapped stack would.
    self.stack.__exit__(exc_type, exc_value, traceback)

def text_encoder_lora_state_dict(text_encoder):
    """Collect the LoRA linear-layer weights from a CLIP text encoder.

    Local replacement for the helper that newer ``diffusers`` releases no
    longer export from ``diffusers.loaders``.

    Args:
        text_encoder: A ``transformers`` ``CLIPTextModel`` or
            ``CLIPTextModelWithProjection`` whose attention projections have
            been wrapped with a ``lora_linear_layer``. Any other type
            produces an empty dict.

    Returns:
        dict: ``{"text_model.encoder.layers.<i>.self_attn.<proj>."
        "lora_linear_layer.<param>": tensor}`` for the q/k/v/out
        projections of every encoder layer (each projection is expected to
        carry a ``lora_linear_layer``; a missing one raises
        ``AttributeError``, matching the original behavior).
    """
    # Imported lazily so the isinstance check only needs transformers when
    # the function is actually called.
    from transformers import CLIPTextModel, CLIPTextModelWithProjection

    state_dict = {}

    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            prefix = f"text_model.encoder.layers.{i}.self_attn"
            attn = layer.self_attn
            # One loop over the four projections instead of four
            # copy-pasted blocks; also avoids re-defining a nested helper
            # on every call as the original did.
            for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
                lora_layer = getattr(attn, proj_name).lora_linear_layer
                for k, v in lora_layer.state_dict().items():
                    state_dict[f"{prefix}.{proj_name}.lora_linear_layer.{k}"] = v

    return state_dict


def check_and_patch_scheduler(scheduler_class):
if not hasattr(scheduler_class, 'get_velocity'):
Expand Down
20 changes: 10 additions & 10 deletions postinstall.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def check_bitsandbytes():
else:
win_dll = os.path.join(venv_path, "lib", "site-packages", "bitsandbytes", "libbitsandbytes_cuda118.dll")
print(f"Checking for {win_dll}")
if not os.path.exists(win_dll):
if not os.path.exists(win_dll) or "0.41.2" not in bitsandbytes_version:
print("Can't find bitsandbytes CUDA dll. Installing bitsandbytes")
try:
pip_uninstall("bitsandbytes")
Expand All @@ -232,23 +232,23 @@ def check_bitsandbytes():
print("Installing bitsandbytes")
try:
pip_install(
"--prefer-binary", "https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl")
"--prefer-binary", "https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl")
except Exception as e:
print("Bitsandbytes 0.41.1 installation failed")
print("Bitsandbytes 0.41.2.post2 installation failed")
print("Some features such as 8bit optimizers will be unavailable")
print_bitsandbytes_installation_error(str(e))
pass
else:
print("Checking bitsandbytes (Linux)")
if bitsandbytes_version != "0.41.1":
if "0.41.2" not in bitsandbytes_version:
try:
print("Installing bitsandbytes")
pip_install("--force-install","--prefer-binary","https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl")
pip_install("bitsandbytes==0.41.2.post2","--prefer-binary","https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl")
except:
print("Bitsandbytes 0.41.1 installation failed")
print("Bitsandbytes 0.41.2 installation failed")
print("Some features such as 8bit optimizers will be unavailable")
print("Install manually with")
print("'python -m pip install bitsandbytes==0.41.1 --prefer-binary --force-install'")
print("'python -m pip install bitsandbytes==0.41.2.post2 --prefer-binary --force-install'")
pass


Expand All @@ -273,7 +273,7 @@ def check_versions():
Dependency(module="accelerate", version="0.21.0"),
Dependency(module="diffusers", version="0.22.1"),
Dependency(module="transformers", version="4.30.2"),
Dependency(module="bitsandbytes", version="0.41.1", required=False),
Dependency(module="bitsandbytes", version="0.41.1.post2", required=False),
]

launch_errors = []
Expand Down Expand Up @@ -347,7 +347,7 @@ def print_bitsandbytes_installation_error(err):
print("cd ../..")
print("# WINDOWS ONLY: ")
print(
"pip install --prefer-binary --force-reinstall https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl")
"pip install --prefer-binary --force-reinstall https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl")
print("#######################################################################################################")

def print_xformers_installation_error(err):
Expand Down Expand Up @@ -388,7 +388,7 @@ def print_launch_errors(launch_errors):
print("activate")
print("cd ../..")
print("pip install -r ./extensions/sd_dreambooth_extension/requirements.txt")
print("pip install --prefer-binary --force-reinstall https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl")
print("pip install --prefer-binary --force-reinstall https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl")
print("#######################################################################################################")


Expand Down
5 changes: 4 additions & 1 deletion scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,9 @@ def toggle_advanced():
c4_save_sample_prompt,
c4_save_sample_template,
]
for element in params_to_save:
setattr(element, "do_not_save_to_config", True)

# Do not load these values when 'load settings' is clicked
params_to_exclude = [
db_model_name,
Expand Down Expand Up @@ -1931,7 +1934,7 @@ def set_gen_sample():
outputs=[],
)

return ((dreambooth_interface, "Dreambooth", "dreambooth_interface"),)
return ((dreambooth_interface, "Dreambooth", "dreambooth_v2"),)


def build_concept_panel(concept: int):
Expand Down

0 comments on commit d4c534f

Please sign in to comment.