Cannot run a FP8 quantized model with LoraX #671

Aktsvigun opened this issue Nov 11, 2024 · 2 comments

@Aktsvigun

System Info

Lorax version:

Name: lorax-client
Version: 0.6.3
Summary: LoRAX Python Client
Home-page: https://github.com/predibase/lorax
Author: Travis Addair
Author-email: [email protected]
License: Apache-2.0
Location: /mnt/share/ai_studio/.venv/lib/python3.11/site-packages
Requires: aiohttp, certifi, huggingface-hub, pydantic
Required-by:

Platform:
linux, x86_64

nvidia-smi output:

Mon Nov 11 14:54:42 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8B:00.0 Off |                    0 |
| N/A   28C    P0             75W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:8C:00.0 Off |                    0 |
| N/A   29C    P0             71W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Information

  • Docker
  • The CLI directly

Tasks

  • An officially supported command
  • My own modifications

Reproduction

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/predibase/lorax:latest --model-id neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8

Expected behavior

I am using the official command for running LoRAX via Docker from the LoRAX documentation (section "Launch LoRAX Server"). The only modification is the model ID, since I'm using an FP8-quantized Llama-3.1-8B.
However, LoRAX's backend does not seem to support FP8 models, as I get an FP8-related error:

2024-11-11T14:47:20.230969Z ERROR lorax_launcher: server.py:311 Error when initializing model
Traceback (most recent call last):
  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 439, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 296, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1128, in __init__
    model = model_cls(prefix, config, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 194, in _load_gqa
    weight = weights.get_multi_weights_col(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/weights.py", line 141, in get_multi_weights_col
    weight = torch.cat(weight_list, dim=dim)
RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'

2024-11-11T14:47:21.122584Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

2024-11-11 14:47:10.890 | INFO     | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-11-11 14:47:10.890 | INFO     | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-11-11 14:47:10.890 | INFO     | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
  return func(*args, **kwargs)
Traceback (most recent call last):

  File "/opt/conda/bin/lorax-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
    server.serve(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 439, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 296, in serve_inner
    model = get_model(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/__init__.py", line 186, in get_model
    return FlashLlama(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in __init__
    super().__init__(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1128, in __init__
    model = model_cls(prefix, config, weights)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in __init__
    self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in __init__
    [

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in <listcomp>
    create_layer_fn(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in __init__
    self.self_attn = FlashLlamaAttention(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in __init__
    self.query_key_value = load_attention(config, prefix, weights, layer_id)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
    base_layer = load_attention_multi(config, prefix, weights)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
    return _load_gqa(config, prefix, weights)

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 194, in _load_gqa
    weight = weights.get_multi_weights_col(

  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/weights.py", line 141, in get_multi_weights_col
    weight = torch.cat(weight_list, dim=dim)

RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'

Could you please investigate?

@ajtejankar
Contributor

You need to add the following flag to your LoRAX launch command: --quantize fp8

@ajtejankar
Contributor

So, this command:

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/predibase/lorax:latest --model-id neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8

should become:

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/predibase/lorax:latest --model-id neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --quantize fp8
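If you script the launch from Python, the corrected command can be assembled as an argument list, appending --quantize fp8 only for FP8 checkpoints. This is a minimal sketch, not part of LoRAX itself: the helper name build_lorax_command and the host path /path/to/data are hypothetical, and the docker arguments are copied verbatim from the command in this thread.

```python
import shlex
from typing import List, Optional


def build_lorax_command(model_id: str, volume: str, quantize: Optional[str] = None) -> List[str]:
    """Hypothetical helper: build the docker argv for the LoRAX launch command in this issue.

    `quantize` is forwarded to the launcher's --quantize flag; pass "fp8" for
    FP8 checkpoints such as neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8.
    """
    cmd = [
        "docker", "run", "--gpus", "all", "--shm-size", "1g",
        "-p", "8080:80", "-v", f"{volume}:/data",
        "ghcr.io/predibase/lorax:latest",
        "--model-id", model_id,
    ]
    # Without this flag the FP8 weights are loaded as plain tensors, which is
    # where the torch.cat call in the traceback above fails.
    if quantize is not None:
        cmd += ["--quantize", quantize]
    return cmd


if __name__ == "__main__":
    # "/path/to/data" is a placeholder for your own $volume directory.
    cmd = build_lorax_command(
        "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", "/path/to/data", quantize="fp8"
    )
    # Print a shell-ready string; subprocess.run(cmd, check=True) would launch it.
    print(shlex.join(cmd))
```

Keeping the command as a list avoids shell quoting issues when it is passed to subprocess.run, while shlex.join gives a copy-pasteable version for the terminal.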
