Add performance data sdxl for RTX 4090 24G and 48G #1041
Conversation
| OneDiff Max reserved CUDA memory Used | | | 14.873 GiB | 14.859 GiB | 35.666 GiB | | | |
| PyTorch Warmup with Run time | | | | | | | | |
| OneDiff Warmup with Compilation time | 474.36 s <sup>1</sup> | 236.54 s <sup>2</sup> | 142.691 s <sup>3</sup> | 287.011 s <sup>3</sup> | 502.223 s <sup>3</sup> | | | |
| OneDiff Warmup with Cache time | 306.84 s | 104.57 s | 142.992 s | 132.207 s | 363.051 s | | |
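For context on how the two warmup rows above are usually measured, here is a minimal sketch (not the actual `benchmarks/text_to_image.py` script; the model id and prompt are placeholders): the first pipeline call after compilation corresponds to "Warmup with Compilation time", and repeating the run in a fresh process with the compiler cache already populated corresponds to "Warmup with Cache time".

```python
import time
import torch
from diffusers import StableDiffusionXLPipeline

# Placeholder pipeline setup; the real benchmark compiles the pipeline with
# onediff/nexfort before timing the first call.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

def timed_call(prompt):
    torch.cuda.synchronize()
    t0 = time.time()
    image = pipe(prompt).images[0]
    torch.cuda.synchronize()
    return image, time.time() - t0

# First call after compiling: "Warmup with Compilation time".
# Re-running the script in a new process that reuses the compiler cache
# yields the smaller "Warmup with Cache time" number.
_, warmup_s = timed_call("a photo of a cat")
print(f"warmup: {warmup_s:.2f} s")
```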
What is the error message when the OOM happens? Could you post it here?
The error message is as follows:
[2024-07-26 16:43:01,384] [INFO] [graphs.py:34:dynamic_graphed_callable] Dynamically CUDA graphing ModuleToBeGraphed
/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/cuda/graphs.py:83: UserWarning: The CUDA Graph is empty. This usually means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:222.)
super().capture_end()
[2024-07-26 16:48:29,566] [ERROR] [graphs.py:112:make_graphed_callable] Failed to capture CUDA Graph, please try without it
[2024-07-26 16:48:29,567] [ERROR] [graphs.py:38:dynamic_graphed_callable] Failed to dynamically CUDA graph ModuleToBeGraphed
Traceback (most recent call last):
File "/root/project/nexfort/src/nexfort/cuda/graphs.py", line 110, in make_graphed_callable
static_outputs = func(*static_inputs, **static_kwarg_inputs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/project/nexfort/src/nexfort/fx_compiler/fx_compiler.py", line 88, in forward
return self.compiled_fn(*args)
File "/root/project/nexfort/src/nexfort/fx_compiler/overrides.py", line 74, in wrapper
return func(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 987, in forward
return compiled_fn(full_args)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 217, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 451, in wrapper
return compiled_fn(runtime_args)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1131, in __call__
return self.current_callable(inputs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 944, in run
return model(new_inputs)
File "/tmp/torchinductor_root/cb/ccbplrs7ajzxwnuf4q4zztbyhyulafddo6say73bhvlyhhozur3c.py", line 2917, in call
buf351 = torch.ops.nexfort_cuda.cudnn_convolution_bias_add_act.default(buf350, arg102_1, arg103_1, None, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, None)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_ops.py", line 667, in __call__
return self_._op(*args, **kwargs)
RuntimeError: FIND was unable to find an engine to execute this computation
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/project/nexfort/src/nexfort/cuda/graphs.py", line 36, in dynamic_graphed_callable
cached_callable = simple_make_graphed_callable(func, args, kwargs, warmups=warmups)
File "/root/project/nexfort/src/nexfort/cuda/graphs.py", line 58, in simple_make_graphed_callable
return make_graphed_callable(
File "/root/project/nexfort/src/nexfort/cuda/graphs.py", line 109, in make_graphed_callable
with torch.cuda.graph(fwd_graph, pool=execution_env.mempool, stream=execution_env.stream):
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/cuda/graphs.py", line 185, in __exit__
self.cuda_graph.capture_end()
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/cuda/graphs.py", line 83, in capture_end
super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Traceback (most recent call last):
File "/root/project/onediff/benchmarks/text_to_image.py", line 428, in <module>
main()
File "/root/project/onediff/benchmarks/text_to_image.py", line 360, in main
pipe(**get_kwarg_inputs())
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1289, in __call__
image = self.vae.decode(latents, return_dict=False)[0]
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 314, in decode
decoded = self._decode(z).sample
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 285, in _decode
dec = self.decoder(z)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/project/onediff/src/onediff/infer_compiler/backends/nexfort/deployable_module.py", line 27, in forward
return self._deployable_module_model(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/diffusers/models/autoencoders/vae.py", line 284, in forward
def forward(
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/root/project/nexfort/src/nexfort/cuda/graphs.py", line 43, in dynamic_graphed_callable
return cached_callable(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/project/nexfort/src/nexfort/fx_compiler/fx_compiler.py", line 88, in forward
return self.compiled_fn(*args)
File "/root/project/nexfort/src/nexfort/fx_compiler/overrides.py", line 74, in wrapper
return func(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 987, in forward
return compiled_fn(full_args)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 217, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 451, in wrapper
return compiled_fn(runtime_args)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1131, in __call__
return self.current_callable(inputs)
File "/root/anaconda3/envs/sd2/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 944, in run
return model(new_inputs)
File "/tmp/torchinductor_root/cb/ccbplrs7ajzxwnuf4q4zztbyhyulafddo6say73bhvlyhhozur3c.py", line 2726, in call
buf267 = empty_strided_cuda((1, 512, 1024, 1024), (536870912, 1, 524288, 512), torch.float32)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 31.51 GiB of which 196.06 MiB is free. Including non-PyTorch memory, this process has 31.31 GiB memory in use. Of the allocated memory 28.34 GiB is allocated by PyTorch, with 19.83 GiB allocated in private pools (e.g., CUDA Graphs), and 2.63 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
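As a side note, the allocator hint at the end of the OOM message can be applied like this (a minimal sketch; the variable can equally be exported in the shell before launching `benchmarks/text_to_image.py`):

```python
# Set the allocator option suggested in the OOM message before torch
# initializes its CUDA caching allocator.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported after the env var so the setting takes effect
```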
> torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 31.51 GiB of which 196.06 MiB is free. Including non-PyTorch memory, this process has 31.31 GiB memory in use. Of the allocated memory 28.34 GiB is allocated by PyTorch, with 19.83 GiB allocated in private pools (e.g., CUDA Graphs), and 2.63 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
>
> this process has 31.31 GiB memory in use
It looks like Max reserved CUDA memory Used is the more useful reference for judging whether a run will OOM.
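A minimal sketch of reading that number with stock PyTorch APIs (the pipeline call is a placeholder):

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run the SDXL pipeline here ...
peak_reserved_gib = torch.cuda.max_memory_reserved() / 1024**3
print(f"Max reserved CUDA memory: {peak_reserved_gib:.3f} GiB")
```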
sdxl on 4090 (32GB), 4090 (48GB)