[LoRA] log a warning when there are missing keys in the LoRA loading. #9622

Merged: 20 commits, Oct 16, 2024
Changes from 15 commits
104 changes: 84 additions & 20 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1358,14 +1358,30 @@ def load_lora_into_transformer(
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

warn_msg = ""
if incompatible_keys is not None:
# check only for unexpected keys
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys not found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)

# Offload back.
if is_model_cpu_offload:
@@ -1932,14 +1948,30 @@ def load_lora_into_transformer(
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

warn_msg = ""
if incompatible_keys is not None:
# check only for unexpected keys
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys not found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)

# Offload back.
if is_model_cpu_offload:
@@ -2279,14 +2311,30 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

warn_msg = ""
if incompatible_keys is not None:
# check only for unexpected keys
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys not found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)

# Offload back.
if is_model_cpu_offload:
@@ -2717,14 +2765,30 @@ def load_lora_into_transformer(
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)

warn_msg = ""
if incompatible_keys is not None:
# check only for unexpected keys
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys not found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)

# Offload back.
if is_model_cpu_offload:
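For readers skimming the hunks above: the same change is repeated once per loader class. A minimal standalone sketch of the idea follows (illustrative only, not code from the PR; the function name `build_lora_warning` is made up): only keys containing `lora_` and the current adapter name are reported, and the missing-key and unexpected-key reports are folded into a single `logger.warning` call, which is also why the new tests further down assert on the key name appearing in the captured log output.

# Illustrative sketch (not from the PR): per-adapter filtering of incompatible keys.
import logging
from typing import Iterable, Optional

logger = logging.getLogger(__name__)

def build_lora_warning(
    unexpected_keys: Optional[Iterable[str]],
    missing_keys: Optional[Iterable[str]],
    adapter_name: str,
) -> str:
    """Build one warning message covering keys that belong to `adapter_name`."""
    warn_msg = ""
    # Report only LoRA keys of the adapter that was just loaded.
    lora_unexpected = [k for k in (unexpected_keys or []) if "lora_" in k and adapter_name in k]
    if lora_unexpected:
        warn_msg = (
            "Loading adapter weights from state_dict led to unexpected keys not found in the model:"
            f" {', '.join(lora_unexpected)}. "
        )
    lora_missing = [k for k in (missing_keys or []) if "lora_" in k and adapter_name in k]
    if lora_missing:
        warn_msg += (
            "Loading adapter weights from state_dict led to missing keys in the model:"
            f" {', '.join(lora_missing)}."
        )
    return warn_msg

# Example: only the key tagged with the current adapter ("default_0") is reported.
msg = build_lora_warning(
    unexpected_keys=["transformer.x.to_q.lora_A.default_0.weight.diffusers_cat"],
    missing_keys=["transformer.y.to_k.lora_B.other_adapter.weight"],
    adapter_name="default_0",
)
if msg:
    logger.warning(msg)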
26 changes: 21 additions & 5 deletions src/diffusers/loaders/unet.py
@@ -354,14 +354,30 @@ def _process_lora(
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

warn_msg = ""
if incompatible_keys is not None:
# check only for unexpected keys
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys not found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)

return is_model_cpu_offload, is_sequential_cpu_offload

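From a user's point of view, the change means a tampered or truncated LoRA checkpoint is now reported instead of being silently partially applied. A hedged sketch of how the warning would surface (the model id, file path, and pipeline class are assumptions for illustration; the checkpoint is assumed to use peft-style `lora_A`/`lora_B` keys):

import torch
from diffusers import StableDiffusionXLPipeline

# Assumed setup: any pipeline with LoRA support works; ids and paths are placeholders.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load the LoRA weights into a plain dict so they can be tampered with.
state_dict = torch.load("my_lora/pytorch_lora_weights.bin", weights_only=True)

# Drop one lora_A key to simulate a truncated or mismatched checkpoint.
missing_key = next(k for k in state_dict if "lora_A" in k)
del state_dict[missing_key]

# With this PR, load_lora_weights logs a warning that names the missing key;
# previously, missing keys were not surfaced at all.
pipe.load_lora_weights(state_dict)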
25 changes: 17 additions & 8 deletions tests/lora/test_lora_layers_flux.py
@@ -27,6 +27,7 @@
from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
slow,
@@ -166,7 +167,7 @@ def test_modify_padding_mode(self):
@slow
@require_torch_gpu
@require_peft_backend
@unittest.skip("We cannot run inference on this model with the current CI hardware")
# @unittest.skip("We cannot run inference on this model with the current CI hardware")
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -208,9 +209,11 @@ def test_flux_the_last_ben(self):
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
Author comment on this change: This is because of the recent fixes in peft. The tests would pass with peft==0.12.0.

assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3

def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
@@ -230,7 +233,9 @@ def test_flux_kohya(self):
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])

assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3

def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
@@ -248,9 +253,11 @@ def test_flux_kohya_with_text_encoder(self):
).images

out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])

assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert max_diff < 1e-3

def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
@@ -268,6 +275,8 @@ def test_flux_xlabs(self):
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
assert max_diff < 1e-3
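The flux integration tests above also move from exact np.allclose checks to a cosine-similarity distance with a 1e-3 bound, which tolerates the small numerical drift introduced by the peft fixes while still catching genuinely different outputs. Roughly, the comparison amounts to the following (my own reimplementation for illustration; the real `numpy_cosine_similarity_distance` helper in `diffusers.utils.testing_utils` may differ in details):

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: 0.0 for identical directions, larger for diverging outputs.
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
out_slice = expected_slice + 1e-3  # small uniform drift still passes the 1e-3 bound
assert cosine_similarity_distance(expected_slice, out_slice) < 1e-3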
91 changes: 89 additions & 2 deletions tests/lora/utils.py
@@ -27,8 +27,10 @@
LCMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
CaptureLogger,
floats_tensor,
require_peft_backend,
require_peft_version_greater,
@@ -219,10 +221,18 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules

if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
if (
"text_encoder" in lora_loadable_modules
and hasattr(pipe, "text_encoder")
and getattr(pipe.text_encoder, "peft_config", None) is not None
):
modules_to_save["text_encoder"] = pipe.text_encoder

if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
if (
"text_encoder_2" in lora_loadable_modules
and hasattr(pipe, "text_encoder_2")
and getattr(pipe.text_encoder_2, "peft_config", None) is not None
):
modules_to_save["text_encoder_2"] = pipe.text_encoder_2

if has_denoiser:
@@ -1747,6 +1757,83 @@ def test_simple_inference_with_dora(self):
"DoRA lora should change the output",
)

def test_missing_keys_warning(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
pipe.unload_lora_weights()
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

# To make things dynamic since we cannot settle with a single key for all the models where we
# offer PEFT support.
missing_key = [k for k in state_dict if "lora_A" in k][0]
del state_dict[missing_key]

logger = (
logging.get_logger("diffusers.loaders.unet")
if self.unet_kwargs is not None
else logging.get_logger("diffusers.loaders.lora_pipeline")
)
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)

# Since the missing key won't contain the adapter name ("default_0").
# Also strip out the component prefix (such as "unet." from `missing_key`).
component = list({k.split(".")[0] for k in state_dict})[0]
self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))

def test_unexpected_keys_warning(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
pipe.unload_lora_weights()
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)

unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

logger = (
logging.get_logger("diffusers.loaders.unet")
if self.unet_kwargs is not None
else logging.get_logger("diffusers.loaders.lora_pipeline")
)
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)

self.assertTrue(".diffusers_cat" in cap_logger.out)

@unittest.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
"""