-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor rotary embedding 3: so it is not on cpu #9307
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@yiyixuxu possible to trigger a torch.compile()
with PyTorch nightly to verify if this helps with the CUDAGraph issue? Code is in #9299 (comment).
Ccing @cpuhrsch maybe you would like to review it?
@sayakpaul yeah tested it is fine |
@sayakpaul import torch
import torch.utils.benchmark as benchmark
import gc
import time
torch.set_float32_matmul_precision("high")
torch._inductor.conv_1x1_as_mm = True
torch._inductor.coordinate_descent_tuning = True
torch._inductor.epilogue_fusion = False
torch._inductor.coordinate_descent_check_all_directions = True
import diffusers
from platform import python_version
from diffusers import DiffusionPipeline
print(diffusers.__version__)
print(torch.__version__)
print(python_version())
def benchmark_fn(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return f"{(t0.blocked_autorange().mean):.3f}"
def bytes_to_giga_bytes(bytes):
return f"{(bytes / 1024 / 1024 / 1024):.3f}"
def flush():
"""Wipes off memory."""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")
def run_inference(pipe):
_ = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=5,
guidance_scale=3.5,
max_sequence_length=512,
generator=torch.manual_seed(42),
height=1024,
width=1024,
)
flush()
time = benchmark_fn(run_inference)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
print(f" Execution time: {time} sec")
print(f" Memory: {memory} gib") |
theta = theta * ntk_factor | ||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] | ||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] | ||
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] | ||
freqs = freqs.to(pos.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd expect this to cause a sync as well since by default arange
allocates on the CPU. One way to mitigate could be to
a) use pin_memory()
on freqs
ahead of time and set non_blocking=True
b) do arange
on the GPU right away (i.e. torch.arange([...], device=pos.device)
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohhh let's do torch.arange([...], device=pos.device)
@yiyixuxu that looks reasonable but I'd call |
@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed( | |||
assert dim % 2 == 0 | |||
|
|||
if isinstance(pos, int): | |||
pos = np.arange(pos) | |||
pos = torch.arange(pos) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should also be passed a device
argument to allocate it on the GPU. If this isn't on the GPU, then neither will the following Tensors.
change get_1d_rotary to accept pos as torch tensors
fix #9299