Implement AWQ quantization support for LLaMA #1032

Merged
merged 95 commits into from
Sep 16, 2023

Changes from all commits (95 commits)
fed3d61
add quantisation config
robirv938 Aug 14, 2023
0f936d0
pass down quantisation setings
robirv938 Aug 14, 2023
520394a
american englihs
robirv938 Aug 14, 2023
640cedf
llama add the code for quantization
robirv938 Aug 14, 2023
0437ffa
update
robirv938 Aug 14, 2023
861d3d7
merge in the AWQ code with note saying its source
robirv938 Aug 14, 2023
d659e95
update
robirv938 Aug 14, 2023
c0e4862
update
ri938 Aug 14, 2023
fed311e
update
ri938 Aug 14, 2023
0937043
fix loading of layers
ri938 Aug 14, 2023
7109bd3
update
ri938 Aug 14, 2023
5bd5ed6
update
ri938 Aug 14, 2023
c3cc5ed
quantization config is part of the model config
ri938 Aug 15, 2023
02bdfed
function
ri938 Aug 15, 2023
2f97151
update
ri938 Aug 15, 2023
c39ec2a
Merge pull request #2 from ri938/add_awq_improvements
ri938 Aug 15, 2023
e5434ef
working prototype
ri938 Aug 15, 2023
ff4d693
merge linear layers
ri938 Aug 15, 2023
033e8c1
update
ri938 Aug 15, 2023
a3ac858
Merge pull request #3 from ri938/merge_linear_layers
ri938 Aug 15, 2023
974bf06
Add quant layer in Row and Column parallel.
casper-hansen Aug 15, 2023
fbaf889
fix pylint errors
ri938 Aug 16, 2023
db4db0c
improve the quant weight loaded code
ri938 Aug 16, 2023
73db30f
Merge pull request #5 from ri938/more_improvements_awq
ri938 Aug 16, 2023
ee7116a
Merge remote-tracking branch 'upstream/add_awq_quant_support' into ad…
casper-hansen Aug 16, 2023
8ff92c7
Loading works, Refactored Quant into Row/Column Parallel
casper-hansen Aug 16, 2023
f5e8d15
WIP. Try to load MPT.
casper-hansen Aug 16, 2023
409e290
add quantization utils
julian-q Aug 23, 2023
a6193cc
consolidate quantization operations
julian-q Aug 23, 2023
37839a9
tweak quant config
julian-q Aug 23, 2023
c90cc44
rename to WeightQuantizationConfig
julian-q Aug 23, 2023
e0520fe
unify TP linear layer forward pass
julian-q Aug 23, 2023
b1d2639
fix shape bugs
julian-q Aug 31, 2023
7f1a80a
move quantized linear to funcitonal
julian-q Aug 31, 2023
7277fcb
streamline quantized weight loading
julian-q Aug 31, 2023
d860aa7
clean up packed dimension calculation
julian-q Aug 31, 2023
9735df9
use pack factor member
julian-q Aug 31, 2023
3dc0f59
fix tensor parallelism support
julian-q Sep 1, 2023
ce97430
run autoformat, fix styling
julian-q Sep 4, 2023
ac066f4
revert MPT quantization for now
julian-q Sep 4, 2023
d605a80
add link to citation
julian-q Sep 4, 2023
269ccfa
clean up weightquantconfig
julian-q Sep 4, 2023
d8ee12a
placeholders for transposed, packed in other models
julian-q Sep 4, 2023
d2d10b9
decompose and document special weight configs
julian-q Sep 4, 2023
a2018e2
fix parameter default
julian-q Sep 4, 2023
3e5a3b4
create quantized_linear.py, simplify AWQ param
julian-q Sep 4, 2023
05dfe27
clean up AWQ constants
julian-q Sep 4, 2023
a4ea6ba
Merge branch 'main' into add_awq_quant_support
WoosukKwon Sep 13, 2023
d1864cf
Minor
WoosukKwon Sep 13, 2023
36c447d
Minor
WoosukKwon Sep 13, 2023
0c3ff9a
Move QuantizationConfig to model_executor
WoosukKwon Sep 13, 2023
03c48dd
Minor
WoosukKwon Sep 13, 2023
28aa926
Implement AWQConfig
WoosukKwon Sep 13, 2023
1270771
Minor
WoosukKwon Sep 13, 2023
c66911e
Refactoring
WoosukKwon Sep 13, 2023
a4e6138
Minor
WoosukKwon Sep 13, 2023
6d80e03
Minor
WoosukKwon Sep 13, 2023
70b4f69
Minor
WoosukKwon Sep 13, 2023
1bb701d
Minor
WoosukKwon Sep 13, 2023
9c303ff
Download quant config from HF hub
WoosukKwon Sep 13, 2023
f391e25
Minor
WoosukKwon Sep 13, 2023
0f4a8ee
Fix logic for finding quant config files
WoosukKwon Sep 13, 2023
7698b1d
Minor
WoosukKwon Sep 13, 2023
1fd2a0b
Minor
WoosukKwon Sep 13, 2023
712a2bc
Minor
WoosukKwon Sep 13, 2023
edacf80
Minor
WoosukKwon Sep 14, 2023
d1abfce
Minor
WoosukKwon Sep 14, 2023
a15355c
Refactor
WoosukKwon Sep 14, 2023
005d315
Support *.pt
WoosukKwon Sep 14, 2023
5bedc7c
Revert back
WoosukKwon Sep 14, 2023
f4ebb84
Add AWQ parallel linears
WoosukKwon Sep 14, 2023
4b333a3
Merge branch 'main' into add_awq_quant_support
WoosukKwon Sep 14, 2023
f574f20
Minor
WoosukKwon Sep 14, 2023
93cb2ff
Fix
WoosukKwon Sep 14, 2023
f615a08
Add dtype
WoosukKwon Sep 14, 2023
650bc5c
Add a comment
WoosukKwon Sep 14, 2023
8cc1be6
yapf
WoosukKwon Sep 14, 2023
3f1d71b
Add ParallelLinear
WoosukKwon Sep 14, 2023
18dde56
getattr -> get
WoosukKwon Sep 14, 2023
a3d4d42
Minor bugfix
WoosukKwon Sep 14, 2023
44cb5c4
Requires_grad = False
WoosukKwon Sep 14, 2023
2fcade6
Minor
WoosukKwon Sep 14, 2023
74365a9
Define packed & transposed
WoosukKwon Sep 14, 2023
59707c4
Add docstring
WoosukKwon Sep 14, 2023
7c08f5a
Remove unused
WoosukKwon Sep 14, 2023
1593d0b
Add weight loading logics
WoosukKwon Sep 14, 2023
3ee4a2b
Add awq folder
WoosukKwon Sep 14, 2023
c3e0a9f
Minor
WoosukKwon Sep 14, 2023
4d69e0e
Minor
WoosukKwon Sep 14, 2023
9d60b1f
Fix
WoosukKwon Sep 14, 2023
ffebfbb
Fix TP
WoosukKwon Sep 14, 2023
7df3cd5
Minor
WoosukKwon Sep 14, 2023
745e1f9
Fix bias
WoosukKwon Sep 14, 2023
66f0493
Add -q option for quantization
WoosukKwon Sep 16, 2023
fcd04d1
Add quantization option for benchmark scripts
WoosukKwon Sep 16, 2023
4 changes: 4 additions & 0 deletions .gitignore
@@ -173,3 +173,7 @@ cython_debug/

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp
18 changes: 14 additions & 4 deletions benchmarks/benchmark_latency.py
@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
llm = LLM(
model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
@@ -63,19 +64,28 @@ def run_to_completion(profile: bool = False):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1,
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
parser.add_argument('--num-iters',
type=int,
default=3,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true',
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)
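
For context, the new `--quantization`/`-q` flag is forwarded straight into the `LLM` constructor. Below is a minimal sketch of the same wiring in plain Python; the checkpoint path is a placeholder for any AWQ-quantized LLaMA model, not one referenced in this PR:

```python
from vllm import LLM, SamplingParams

# quantization="awq" mirrors what benchmark_latency.py now forwards from argparse;
# the model path below is a placeholder, not a checkpoint named in this PR.
llm = LLM(model="path/to/llama-awq-checkpoint",
          quantization="awq",
          tensor_parallel_size=1)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)
```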
74 changes: 44 additions & 30 deletions benchmarks/benchmark_throughput.py
@@ -3,7 +3,7 @@
import json
import random
import time
from typing import List, Tuple
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
@@ -22,15 +22,10 @@ def sample_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
@@ -63,6 +58,7 @@ def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
@@ -72,6 +68,7 @@
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
@@ -111,8 +108,8 @@ def run_hf(
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@@ -132,13 +129,14 @@
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) + max(
max_output_len, next_output_len)) <= 2048:
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue

# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
@@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
random.seed(args.seed)

# Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(
requests, args.model, tokenizer, args.n, args.use_beam_search,
args.hf_max_batch_size, args.trust_remote_code)
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(
prompt_len + output_len
for _, prompt_len, output_len in requests
)
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf"],
default="vllm")
parser.add_argument("--dataset", type=str, required=True,
parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", type=int, default=1,
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
@@ -215,6 +227,8 @@ def main(args: argparse.Namespace):
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.tokenizer is None:
args.tokenizer = args.model

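The throughput numbers printed at the end of the script reduce to the bookkeeping below; this is only a restatement of the benchmark's own arithmetic for reference, not new functionality:

```python
from typing import List, Tuple

def summarize(requests: List[Tuple[str, int, int]], elapsed_time: float) -> str:
    # Each request is (prompt, prompt_len, output_len); prompt and generated
    # tokens both count toward tokens/s, matching benchmark_throughput.py.
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    return (f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
            f"{total_num_tokens / elapsed_time:.2f} tokens/s")

print(summarize([("hi", 32, 128), ("hello", 16, 64)], elapsed_time=2.0))
```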
15 changes: 15 additions & 0 deletions csrc/quantization.cpp
@@ -0,0 +1,15 @@
#include <torch/extension.h>

torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
}
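
Once the extension is built, the binding above exposes `awq_gemm` to Python. The sketch below shows how such a call could look; the module name (`vllm.quantization_ops`), the tensor layouts, and the choice of `split_k_iters = 8` are assumptions based on the common AWQ 4-bit packing convention, not details confirmed by this diff:

```python
import torch
# Assumption: setup.py compiles csrc/quantization.cpp (plus the AWQ CUDA kernels)
# into a module importable as vllm.quantization_ops; the name is illustrative.
from vllm import quantization_ops

M, K, N = 4, 4096, 4096
pack_factor, group_size = 8, 128  # assumed: 8 int4 weights per int32, group size 128

x = torch.randn(M, K, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (K, N // pack_factor),
                        dtype=torch.int32, device="cuda")
scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // pack_factor),
                       dtype=torch.int32, device="cuda")

# awq_gemm dequantizes the packed weights on the fly and returns x @ W, shape [M, N].
out = quantization_ops.awq_gemm(x, qweight, scales, qzeros, 8)  # 8 = split_k_iters (assumed)
print(out.shape)
```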
79 changes: 79 additions & 0 deletions csrc/quantization/awq/dequantize.cuh
@@ -0,0 +1,79 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/

#pragma once


__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;

uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.

// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.

// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;

// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

return result;
}
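
The inline PTX above leans on a classic trick: OR-ing an unsigned 4-bit value into the mantissa of the fp16 bit pattern 0x6400 (which encodes 1024.0) produces 1024 + n, so a single fp16 subtract recovers n, and for the nibbles that sit four bits higher an fma by 1/16 with a -64 offset does the same, all without integer-to-float conversion instructions. A small NumPy check of that arithmetic (a reference for the bit trick only; it does not model the kernel's interleaved mapping of nibbles to output elements):

```python
import numpy as np

def fp16_from_bits(u16: int) -> np.float16:
    # Reinterpret a 16-bit pattern as an IEEE half-precision value.
    return np.array([u16], dtype=np.uint16).view(np.float16)[0]

for n in range(16):  # every unsigned 4-bit value
    # Bottom-nibble path: (n | 0x6400) encodes 1024 + n in fp16, so the
    # kernel's sub.f16x2 against 0x6400 (1024.0) recovers n exactly.
    lo = fp16_from_bits(n | 0x6400) - np.float16(1024)

    # Top-nibble path: ((n << 4) | 0x6400) encodes 1024 + 16*n, so the
    # kernel's fma.rn.f16x2 with 1/16 (0x2c00) and -64 (0xd400) recovers n.
    hi = fp16_from_bits((n << 4) | 0x6400) * np.float16(1 / 16) + np.float16(-64)

    assert lo == n and hi == n
```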
