refactor rotary embedding 3: so it is not on cpu #9307

Merged
merged 2 commits into from
Aug 29, 2024

Conversation

yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented Aug 28, 2024

fix #9299

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu yiyixuxu changed the title refactor rotary embedding so it is not on cpu refactor rotary embedding 3: so it is not on cpu Aug 29, 2024
@yiyixuxu yiyixuxu requested a review from sayakpaul August 29, 2024 19:08
Member

@sayakpaul sayakpaul left a comment

Thanks!

@yiyixuxu possible to trigger a torch.compile() with PyTorch nightly to verify if this helps with the CUDAGraph issue? Code is in #9299 (comment).

Ccing @cpuhrsch maybe you would like to review it?

@yiyixuxu
Collaborator Author

@sayakpaul yeah, tested it, and it is fine

@sayakpaul sayakpaul merged commit 61d96c3 into main Aug 29, 2024
18 checks passed
@sayakpaul sayakpaul deleted the fix-torch-rope branch August 29, 2024 19:37
@yiyixuxu
Collaborator Author

@sayakpaul
Is this a reasonable script? I want to compare performance against 0.30.1-patch, i.e. the state before we introduced the rotary embedding refactor.

import torch
import torch.utils.benchmark as benchmark
import gc

torch.set_float32_matmul_precision("high")
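# Inductor flags live on the config submodule; setting them on torch._inductor itself has no effect.
import torch._inductor.config
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True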

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())

def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"

def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")

def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )

flush()

exec_time = benchmark_fn(run_inference)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GiB
print(f" Execution time: {exec_time} sec")
print(f" Memory: {memory} GiB")

theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
freqs = freqs.to(pos.device)

I'd expect this to cause a sync as well since by default arange allocates on the CPU. One way to mitigate could be to
a) use pin_memory() on freqs ahead of time and set non_blocking=True
b) do arange on the GPU right away (i.e. torch.arange([...], device=pos.device)).
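A minimal sketch of option (b), for illustration only: the helper name `rotary_freqs` and its default `theta` are assumptions, not the code in this PR. It builds the angle table directly on `pos.device`, so no CPU tensor has to be copied to the GPU afterwards.

```python
import torch


def rotary_freqs(pos: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: build the [S, D/2] rotary angle table on pos.device."""
    assert dim % 2 == 0
    # Allocate the exponent range on the same device as `pos`, not on the CPU default.
    exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    inv_freq = 1.0 / (theta ** exponents)  # [D/2]
    return torch.outer(pos.to(torch.float32), inv_freq)  # [S, D/2]


# Falls back to CPU so the sketch also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pos = torch.arange(16, device=device)
freqs = rotary_freqs(pos, dim=8)
```

Because every intermediate tensor inherits its device from `pos`, the result lands wherever the caller placed the positions.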

Collaborator Author

ohhh let's do torch.arange([...], device=pos.device)

@sayakpaul
Member

@yiyixuxu that looks reasonable but I'd call run_inference() maybe 2/3 times for warmups.
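One way the warmup suggestion could look, as a rough sketch: `time_with_warmup` is a hypothetical helper, not part of the script above, and it measures plain wall-clock time rather than using `torch.utils.benchmark`. The untimed calls absorb compilation and autotuning cost so that only steady-state runs are measured.

```python
import time

import torch


def time_with_warmup(fn, warmup: int = 3, runs: int = 5) -> float:
    """Hypothetical helper: mean wall-clock seconds per call, measured after warmup."""
    for _ in range(warmup):
        fn()  # untimed: absorbs torch.compile compilation and autotuning
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure warmup kernels have finished
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / runs


# Stand-in workload; in the script above this would be a call to run_inference(pipe).
calls = []
mean_s = time_with_warmup(lambda: calls.append(1), warmup=3, runs=5)
```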

@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
     assert dim % 2 == 0

     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)

This should also be passed a device argument to allocate it on the GPU. If this isn't on the GPU, then neither will the following Tensors.
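A small sketch of what passing a device could look like. The function name `normalize_pos` and its `device` parameter are assumptions made for illustration; the merged signature of `get_1d_rotary_pos_embed` may differ.

```python
import torch


def normalize_pos(pos, device=None) -> torch.Tensor:
    """Hypothetical helper: turn an int length into a position tensor on `device`."""
    if isinstance(pos, int):
        # Allocate on the requested device so tensors derived from it stay there too.
        return torch.arange(pos, device=device)
    # A tensor is returned unchanged and keeps whatever device it already has.
    return pos


p = normalize_pos(4, device="cpu")
t = torch.tensor([5, 6])
same = normalize_pos(t)
```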

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
change get_1d_rotary to accept pos as torch tensors
Successfully merging this pull request may close these issues.

CUDAGRAPHs for Flux position embeddings
4 participants