NPU Adaptation for Sana #10409

Open · wants to merge 13 commits into main
7 changes: 5 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sana.py
@@ -63,6 +63,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module


@@ -74,6 +75,9 @@

 logger = get_logger(__name__)

+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+

 def save_model_card(
     repo_id: str,
@@ -920,8 +924,7 @@ def main(args):
                     image.save(image_filename)

             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
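For context, here is a hedged sketch of the pattern these hunks follow in a standalone script: guard the NPU-specific configuration behind `is_torch_npu_available()` and rely on a backend-agnostic cleanup helper instead of CUDA-only calls. It assumes `free_memory` is importable from `diffusers.training_utils`, as in other diffusers training scripts, and the comments are my reading rather than the PR author's.

```python
# Hedged sketch of the device-aware setup/cleanup used in the hunks above;
# assumes diffusers.training_utils.free_memory exists, as in other training scripts.
import torch
from diffusers.training_utils import free_memory
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    # Keep tensors in the standard layout on Ascend NPUs; private internal
    # formats can interact badly with some fused kernels (assumption).
    torch.npu.config.allow_internal_format = False

# ... build the pipeline and generate the class images here ...

# Backend-agnostic replacement for the old CUDA-only cleanup
# (`if torch.cuda.is_available(): torch.cuda.empty_cache()`):
free_memory()
```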
11 changes: 10 additions & 1 deletion src/diffusers/models/attention_processor.py
@@ -3147,7 +3147,16 @@ def __call__(
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attn_mask = attention_mask[0]
+            seq_len = hidden_states.shape[1]
+            attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0)
+            attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
+
+            if attention_mask.dtype != torch.uint8:
+                if attention_mask.dtype == torch.bool:
+                    attention_mask = torch.logical_not(attention_mask.bool())
+                else:
+                    attention_mask = attention_mask.to(torch.uint8)

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
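To make the mask handling easier to follow, below is a self-contained sketch of the same transformation pulled out as a helper function. The name `to_npu_attention_mask` is mine, and the stated mask convention (nonzero or True entries mark positions to drop, the inverse of SDPA's boolean convention) is an assumption about the NPU fused-attention kernel, not something the diff spells out.

```python
import torch


def to_npu_attention_mask(attention_mask: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
    """Illustrative mirror of the reshaping done in AttnProcessorNPU above.

    `attention_mask` is the tensor returned by `attn.prepare_attention_mask`;
    the result has shape (batch_size, 1, seq_len, key_len).
    """
    # Take the first row's mask and broadcast it over every query position and
    # batch element, giving a (batch, 1, query_len, key_len) layout.
    attn_mask = attention_mask[0]
    attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0)
    attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])

    # The NPU fused-attention kernel is assumed to read the mask as
    # "nonzero = drop", so boolean SDPA masks (True = keep) are inverted;
    # anything else is cast to uint8.
    if attention_mask.dtype != torch.uint8:
        if attention_mask.dtype == torch.bool:
            attention_mask = torch.logical_not(attention_mask.bool())
        else:
            attention_mask = attention_mask.to(torch.uint8)
    return attention_mask
```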
11 changes: 9 additions & 2 deletions src/diffusers/models/transformers/sana_transformer.py
@@ -19,11 +19,12 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -119,6 +120,12 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+            if is_torch_npu_available():
(Review thread on the `if is_torch_npu_available():` check added above.)

Collaborator: same comment as in the other PR - let's not update the default attn processor logic for now; we can manually set it.

Contributor (author), replying: I've updated the new one, please take a look. This can just set up NPU FA directly.

Contributor (author): I will let you know when the full test is complete.

Contributor (author): @yiyixuxu It still needs to modify the sana_transformer file, so I think checking in the init is the best option for now.
+                attn_processor = AttnProcessorNPU()
+            else:
+                attn_processor = AttnProcessor2_0()
+
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -127,7 +134,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=AttnProcessor2_0(),
+                processor=attn_processor,
             )

         # 3. Feed-forward
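For reference, here is a hedged sketch of the alternative the collaborator raised ("we can manually set it"): leave the block's default processor logic alone and swap only the cross-attention processors on an already-constructed model. The checkpoint path is a placeholder, and the isinstance-based filtering is my own way of restricting the swap to `attn2` (which currently uses `AttnProcessor2_0`) while leaving Sana's linear self-attention untouched.

```python
# Hedged sketch: manually route only Sana's cross-attention through the NPU
# fused-attention processor after loading the model. The checkpoint id is a
# placeholder; adjust it to your setup.
import torch
from diffusers import SanaTransformer2DModel
from diffusers.models.attention_processor import AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils.import_utils import is_torch_npu_available

transformer = SanaTransformer2DModel.from_pretrained(
    "path/to/sana-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

if is_torch_npu_available():
    # Swap only the processors that are currently AttnProcessor2_0 (the attn2 /
    # cross-attention modules); the linear self-attention processors stay as-is.
    processors = {
        name: AttnProcessorNPU() if isinstance(proc, AttnProcessor2_0) else proc
        for name, proc in transformer.attn_processors.items()
    }
    transformer.set_attn_processor(processors)
```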