
Add llama implementation with no tensor parallel linears #1561

Merged
merged 5 commits into sgl-project:main on Oct 5, 2024

Conversation

Contributor

@jerryzh168 jerryzh168 commented Oct 3, 2024

Summary:
Demo that llama with normal linears + quantized model + tensor parallelism works

  • verified correctness against the original llama3 model
  • supported --json-model-override-args in the bench_latency script

Next: add PyTorch-native tensor parallelism test code for int8 weight-only quantization in torchao. Diff from the current llama model def: https://gist.github.com/jerryzh168/692ff83735d4ca298c1aad2424b2c225
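For reference, the core change is that the model definition uses plain torch.nn.Linear layers instead of vLLM-style tensor-parallel linears (QKVParallelLinear / MergedColumnParallelLinear / RowParallelLinear), so torchao quantization and PyTorch-native tensor parallelism can wrap them directly. A minimal sketch of the MLP shape, with illustrative names and layout rather than the exact PR code:

# sketch only: a plain-linear llama MLP, analogous to the PR's model def
import torch
from torch import nn

class TorchNativeLlamaMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # ordinary linears instead of MergedColumnParallelLinear / RowParallelLinear
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))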

Test Plan:

Using --json-model-override-args to override the model architecture:

python3 -m sglang.bench_latency --correct --model meta-llama/Meta-Llama-3-8B --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}'
Init nccl begin.
Load weight begin. avail mem=94.48 GB
INFO 10-04 15:00:53 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  2.46it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:00<00:00,  2.26it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  3.11it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  2.75it/s]

Load weight end. type=TorchNativeLlamaForCausalLM, dtype=torch.bfloat16, avail mem=79.41 GB
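For context on what --json-model-override-args does: the JSON string is parsed and each key is applied on top of the checkpoint's HF config before the model class is resolved. A rough sketch of that behavior (not the exact SGLang code path):

import json
from transformers import AutoConfig

# the same override string passed via --json-model-override-args
override = json.loads('{"architectures": ["TorchNativeLlamaForCausalLM"]}')

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
for key, value in override.items():
    setattr(config, key, value)

# config.architectures now selects TorchNativeLlamaForCausalLM
# instead of the default LlamaForCausalLM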

Performance check:

python3 -m sglang.bench_latency --model jerryzh168/llama3-8B --batch-size 1 --input 128 --output 8
python3 -m sglang.bench_latency --model jerryzh168/llama3-8B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128

max_total_num_tokens=631444
Warmup ...
Prefill. latency: 0.09536 s, throughput:   1342.32 token/s
Decode.  latency: 0.00538 s, throughput:    185.80 token/s
Decode.  latency: 0.00476 s, throughput:    209.91 token/s
Decode.  latency: 0.00466 s, throughput:    214.38 token/s
Decode.  median latency: 0.00476 s, median throughput:    209.91 token/s
Total. latency:  0.110 s, throughput:   1198.18 token/s
Benchmark ...
Prefill. latency: 0.06534 s, throughput:   1958.93 token/s
Decode.  latency: 0.00502 s, throughput:    199.16 token/s
Decode.  latency: 0.00476 s, throughput:    210.03 token/s
Decode.  latency: 0.00469 s, throughput:    213.19 token/s
Decode.  latency: 0.00466 s, throughput:    214.77 token/s
Decode.  latency: 0.00466 s, throughput:    214.74 token/s
Decode.  median latency: 0.00469 s, median throughput:    213.19 token/s
Total. latency:  0.098 s, throughput:   1381.24 token/s
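The second command above applies torchao int4 weight-only quantization at group size 128. A minimal sketch of the equivalent torchao call on a plain linear, assuming torchao's quantize_ API; exact import paths may differ across torchao versions:

import torch
from torch import nn
from torchao.quantization import quantize_, int4_weight_only

# int4 weight-only kernels expect bf16 CUDA weights
model = nn.Sequential(nn.Linear(4096, 14336, bias=False)).to(torch.bfloat16).cuda()

# swaps every nn.Linear weight in-place for an int4 grouped-quantized tensor;
# this is why the model def needs plain nn.Linear rather than tensor-parallel linears
quantize_(model, int4_weight_only(group_size=128))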

Accuracy check:


# python3 scripts/playground/reference_hf.py --model meta-llama/Meta-Llama-3-8B
========== Prompt 0 ==========
prefill logits (final) tensor([ 5.0195,  3.0801,  0.7422,  ..., -7.4805, -7.4805, -7.4805],
       device='cuda:0')
<|begin_of_text|>The capital of France is Paris. It is located in the north of the country. The city is situated
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.

========== Prompt 1 ==========
prefill logits (final) tensor([ 5.2109,  4.2344,  1.8408,  ..., -7.5195, -7.5195, -7.5195],
       device='cuda:0')
<|begin_of_text|>The capital of the United Kindom is London. It is the largest city in the UK and the largest city in the
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.

========== Prompt 2 ==========
prefill logits (final) tensor([ 9.5391,  3.1914,  0.8188,  ..., -7.0469, -7.0469, -7.0469],
       device='cuda:0')
<|begin_of_text|>Today is a sunny day and I like to go out and enjoy the sun. I am going to the beach with my

# python3 scripts/playground/reference_hf.py --model jerryzh168/llama3-8B
========== Prompt 0 ==========
prefill logits (final) tensor([ 5.0195,  3.0801,  0.7422,  ..., -7.4805, -7.4805, -7.4805],
       device='cuda:0')
<|begin_of_text|>The capital of France is Paris. It is located in the north of the country. The city is situated
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.

========== Prompt 1 ==========
prefill logits (final) tensor([ 5.2109,  4.2344,  1.8408,  ..., -7.5195, -7.5195, -7.5195],
       device='cuda:0')
<|begin_of_text|>The capital of the United Kindom is London. It is the largest city in the UK and the largest city in the
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.

========== Prompt 2 ==========
prefill logits (final) tensor([ 9.5391,  3.1914,  0.8188,  ..., -7.0469, -7.0469, -7.0469],
       device='cuda:0')
<|begin_of_text|>Today is a sunny day and I like to go out and enjoy the sun. I am going to the beach with my


# python3 -m sglang.bench_latency --correct --model meta-llama/Meta-Llama-3-8B
max_total_num_tokens=557684

input_ids=[[128000, 791, 6864, 315, 9822, 374], [128000, 791, 6864, 315, 279, 3723, 17262, 316, 374], [128000, 15724, 374, 264, 40798, 1938, 323, 358, 1093]]

prefill logits (first half): tensor([[ 1.9609,  2.1094, -1.2500,  ..., -5.5000, -5.5000, -5.5000],
        [ 1.9609,  2.1094, -1.2500,  ..., -5.5000, -5.5000, -5.5000],
        [ 2.2969,  2.9531,  2.1406,  ..., -8.3750, -8.3750, -8.3750]],
       device='cuda:0')

prefill logits (final): tensor([[ 5.0312,  3.1094,  0.7500,  ..., -7.4375, -7.4375, -7.4375],
        [ 5.2188,  4.2188,  1.8359,  ..., -7.5312, -7.5312, -7.5312],
        [ 9.5000,  3.1406,  0.7891,  ..., -7.0938, -7.0938, -7.0938]],
       device='cuda:0')

========== Prompt 0 ==========
<|begin_of_text|>The capital of France is Paris. It is located in the north of the country. It is the largest

========== Prompt 1 ==========
<|begin_of_text|>The capital of the United Kindom is London. It is the largest city in the UK and the largest city in the

========== Prompt 2 ==========
<|begin_of_text|>Today is a sunny day and I like to go out for a walk. I am going to the park. I am


# python3 -m sglang.bench_latency --correct --model jerryzh168/llama3-8B
Load weight end. type=TorchNativeLlamaForCausalLM, dtype=torch.bfloat16, avail mem=79.41 GB
Memory pool end. avail mem=11.16 GB
Capture cuda graph begin. This can take up to several minutes.
max_total_num_tokens=557684

input_ids=[[128000, 791, 6864, 315, 9822, 374], [128000, 791, 6864, 315, 279, 3723, 17262, 316, 374], [128000, 15724, 374, 264, 40798, 1938, 323, 358, 1093]]

prefill logits (first half): tensor([[ 1.9609,  2.1094, -1.2500,  ..., -5.5000, -5.5000, -5.5000],
        [ 1.9609,  2.1094, -1.2500,  ..., -5.5000, -5.5000, -5.5000],
        [ 2.2969,  2.9531,  2.1406,  ..., -8.3750, -8.3750, -8.3750]],
       device='cuda:0')

prefill logits (final): tensor([[ 5.0312,  3.1094,  0.7500,  ..., -7.4375, -7.4375, -7.4375],
        [ 5.2188,  4.2188,  1.8359,  ..., -7.5312, -7.5312, -7.5312],
        [ 9.5000,  3.1406,  0.7891,  ..., -7.0938, -7.0938, -7.0938]],
       device='cuda:0')

========== Prompt 0 ==========
<|begin_of_text|>The capital of France is Paris. It is located in the north of the country. Paris is the largest

========== Prompt 1 ==========
<|begin_of_text|>The capital of the United Kindom is London. It is the largest city in the UK and the largest city in the

========== Prompt 2 ==========
<|begin_of_text|>Today is a sunny day and I like to go out for a walk. I am going to the park. I am
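The runs above match: the HF reference and SGLang produce near-identical prefill logits (up to bf16 rounding) and closely matching greedy continuations, with small late-decode divergences expected from rounding. A small sketch of how such a comparison could be automated; ref_logits and sgl_logits are hypothetical tensors captured from the two runs:

import torch

def logits_match(ref_logits: torch.Tensor, sgl_logits: torch.Tensor,
                 atol: float = 0.1) -> bool:
    # bf16 execution order differs between the two stacks, so allow a
    # loose absolute tolerance rather than exact equality
    return torch.allclose(ref_logits.float(), sgl_logits.float(), atol=atol)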

Reviewers:

Subscribers:

Tasks:

Tags:

Contributor

@merrymercy merrymercy left a comment

  1. Is `TorchNativeLlamaForCausalLM` a better name?
  2. Did you test the correctness?
    - Get the reference output by `python3 scripts/playground/reference_hf.py --model [new model]`
    - Get the SGLang output by `python3 -m sglang.bench_latency --correct --model [new model]`
  3. Maybe we can add some arguments that allow using this model implementation without using a new checkpoint. We have some arguments like

         # Model override args
         parser.add_argument(
             "--json-model-override-args",
             type=str,
             help="A dictionary in JSON string format used to override default model configurations.",
             default=ServerArgs.json_model_override_args,
         )

     to override the model configs. I am not sure whether it works.

Summary:
Trying to demo llama with normal linear + quantized model + tensor parallelism works

Test Plan:
TODO

Reviewers:

Subscribers:

Tasks:

Tags:
@jerryzh168 jerryzh168 requested a review from merrymercy October 5, 2024 00:00
@merrymercy merrymercy merged commit 9b0926c into sgl-project:main Oct 5, 2024
2 of 11 checks passed
@merrymercy
Contributor

@jerryzh168 Thanks! It is merged.
