[Bug]: Using FlashInfer with FP8 model with FP8 KV cache produces an error #8641

Closed
Syst3m1cAn0maly opened this issue Sep 19, 2024 · 3 comments · Fixed by #9861
Labels
bug Something isn't working

Comments


Syst3m1cAn0maly commented Sep 19, 2024

Your current environment

The output of `python collect_env.py`

Model Input Dumps

No response

🐛 Describe the bug

When launching vLLM 0.6.1.post2 via Docker with an FP8-quantized model containing k_scale and v_scale, using the FlashInfer backend (selected via the VLLM_ATTENTION_BACKEND=FLASHINFER environment variable), I get this error:

```
INFO 09-19 10:08:02 multiproc_worker_utils.py:123] Killing local vLLM worker processes
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner_base.py", line 112, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1546, in execute_model
    hidden_or_intermediate_states = model_executable(
                                    ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 448, in forward
    model_output = self.model(input_ids, positions, kv_caches,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 329, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 251, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/llama.py", line 181, in forward
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 98, in forward
    return self.impl.forward(query,
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/flashinfer.py", line 749, in forward
    assert k_scale == 1.0 and v_scale == 1.0, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: key/v_scale is not supported in FlashInfer.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 236, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 34, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 573, in from_engine_args
    engine = cls(
             ^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 473, in __init__
    self.engine = self._engine_class(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 257, in __init__
    super().__init__(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 331, in __init__
    self._initialize_kv_caches()
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 460, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/distributed_gpu_executor.py", line 39, in determine_num_available_blocks
    num_blocks = self._run_workers("determine_num_available_blocks", )
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 199, in _run_workers
    driver_worker_output = driver_worker_method(*args, **kwargs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 223, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1218, in profile_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner_base.py", line 125, in _wrapper
    pickle.dump(dumped_inputs, filep)
TypeError: cannot pickle 'flashinfer._prefill.BatchPrefillWithPagedKVCachePyTorchWrapper' object
[rank0]:[W919 10:08:03.463523741 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
[rank0]:[W919 10:08:03.186378249 CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
ERROR 09-19 10:08:05 api_server.py:188] RPCServer process died before responding to readiness probe
```

Is this a bug?
How can I use an FP8 KV cache with scales together with FlashInfer?
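
For reference, a minimal offline sketch of the configuration that triggers this (the model name is a placeholder; `kv_cache_dtype="fp8"` and `VLLM_ATTENTION_BACKEND` are the relevant vLLM options):

```python
import os

# Select the FlashInfer attention backend before vLLM initializes.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

llm = LLM(
    model="my-org/Llama-3-8B-FP8",  # hypothetical FP8 checkpoint shipping k_scale/v_scale
    kv_cache_dtype="fp8",           # FP8 KV cache; scales are read from the checkpoint
)
print(llm.generate("Hello")[0].outputs[0].text)
```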

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Syst3m1cAn0maly added the bug label on Sep 19, 2024
pavanimajety (Contributor) commented

Hi @Syst3m1cAn0maly, this is currently not supported, as the assertion suggests. The cause of the error is that the reshape function doesn't work with torch.float8 tensors and scales. I have a follow-up to-do to fix this.
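
A rough sketch of the kind of failure being described (illustrative only, not vLLM's actual code; exact behavior depends on the PyTorch version):

```python
import torch

# Illustrative sketch: view ops such as reshape() work on float8 tensors,
# but many eager-mode elementwise ops are not implemented for float8,
# so applying a scale directly to an FP8 KV-cache view can fail.
kv = torch.zeros(4, 8, dtype=torch.float8_e4m3fn)
k_scale = 0.5
flat = kv.reshape(-1)  # reshape itself is fine
try:
    scaled = flat * k_scale  # may raise: mul not implemented for Float8_e4m3fn
except RuntimeError as err:
    print(err)
    # Common workaround: upcast, scale, then cast back to float8.
    scaled = (flat.float() * k_scale).to(torch.float8_e4m3fn)
```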

Syst3m1cAn0maly (Author) commented

Thanks for the quick answer.
Looking forward to the fix.

desimonemike123 commented

Hi @pavanimajety, I'm wondering if I'm encountering the same issue as @Syst3m1cAn0maly described above; hoping you can provide guidance, thanks.

I'm trying to load a DeepSeek FP8 model and encountering issues.

Initially I received:

```
Cannot use FlashAttention-2 backend for FP8 KV cache.
WARNING 09-27 17:43:56 selector.py:229] Please use FlashInfer backend with FP8 KV Cache for better performance by setting environment variable VLLM_ATTENTION_BACKEND=FLASHINFER
```

which is then followed by:

```
AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
```

I then set the VLLM_ATTENTION_BACKEND environment variable and now receive:

```
Failed to pickle inputs of failed execution: cannot pickle 'flashinfer._prefill.BatchPrefillWithPagedKVCachePyTorchWrapper' object
```

which is eventually followed by the same error:

```
AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
```

I did review issue #7714, which appears related as well, but I'm stuck at this point.
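
A quick way to check whether the GPU meets the FP8 requirement (a minimal sketch; the `arch < 89` in the assertion refers to CUDA compute capability 8.9, i.e. Ada Lovelace or Hopper class GPUs):

```python
import torch

# fp8e4nv (float8_e4m3fn) kernels require compute capability >= 8.9
# (e.g. L4/L40, RTX 4090, H100). Older GPUs hit the "arch < 89" assertion.
major, minor = torch.cuda.get_device_capability()
print(f"GPU compute capability: {major}.{minor}")
if (major, minor) < (8, 9):
    print("This GPU cannot run fp8e4nv kernels.")
```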
