diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index b02d4c2441..e07a3e6799 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -1,3 +1,4 @@ +llama 2 20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 @@ -8,6 +9,7 @@ 20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +llama 3 20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 @@ -17,3 +19,11 @@ 20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +kv cache quantization: +20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 +20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 +20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 +20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 20a7bf1103..6dd9c10d94 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -22,3 +22,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt + +export MODEL_REPO=meta-llama/Meta-Llama-3-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192 diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 34ff9abb12..6e6db90571 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) + next_token, next_prob = next_token.clone(), next_prob.clone() input_pos += 1 - new_tokens.append(next_token.clone()) + new_tokens.append(next_token) callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) + new_probs.append(next_prob) cur_token = next_token.view(1, -1) return new_tokens, new_probs @@ -88,6 +89,7 @@ def generate( *, interactive: bool, callback = lambda x: x, + kv_cache_quantization: bool = False, **sampling_kwargs ) -> torch.Tensor: """ @@ -97,14 +99,27 @@ def generate( # create an empty tensor of the expected final shape and fill in the current tokens device = prompt.device T = prompt.numel() - T_new = T + max_new_tokens - seq = torch.empty(T_new, dtype=prompt.dtype, device=device) + + # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # full prompt+output will be stored in seq + seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device) seq[:T] = prompt.view(-1) - # setup model cache - max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 + # setup model caches with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if kv_cache_quantization: + from model import AffineQuantizedKVCache + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + _replace_with_custom_fn_if_matches_filter( + model, + AffineQuantizedKVCache.from_float, + lambda x, y: isinstance(x, torchao._models.llama.model.KVCache), + ) + # format model input x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) @@ -113,8 +128,9 @@ def generate( next_token = prefill(model, x, input_pos, **sampling_kwargs).clone() seq[T] = next_token + # execute token generation input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) seq[T + 1:] = torch.cat(generated_tokens) return seq @@ -147,6 +163,7 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, + kv_cache_quantization: bool = False, compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, @@ -276,6 +293,7 @@ def callback(x): callback=callback, temperature=temperature, top_k=top_k, + kv_cache_quantization=kv_cache_quantization, ) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") @@ -286,7 +304,10 @@ def callback(x): t = time.perf_counter() - t0 if not interactive: - print(tokenizer.decode(y.tolist())) + tok_list = y.tolist() + # truncate text after end of string token + tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())] + print(tokenizer.decode(tokens)) else: print() tokens_generated = y.size(0) - prompt_length @@ -305,12 +326,13 @@ def callback(x): print(f"Model Size: {model_size:.02f} GB") if write_result: result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " result_txt += f"repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " + result_txt += f"--kv_cache_quantization " if kv_cache_quantization else "" result_txt += f"--compile " if compile else "" result_txt += f"--compile_prefill " if compile_prefill else "" result_txt += f"--profile {profile} " if profile else "" @@ -337,6 +359,7 @@ def callback(x): parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') parser.add_argument('--profile', type=Path, default=None, help='Profile path.') @@ -347,5 +370,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result + args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result ) diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 7204b0e387..58a1709642 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -12,6 +12,7 @@ from torch.nn import functional as F from torchao.utils import find_multiple +# TODO remove suplerfluous arg def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d if inps.dim() > 2: @@ -97,6 +98,43 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out + +from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine +from torchao.quantization.utils import quantize_activation_per_token_absmax + +class AffineQuantizedKVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + scale_shape = (max_batch_size, n_heads, max_seq_length, 1) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8)) + self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) + self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) + + def update(self, input_pos, k_val, v_val): + # quantize current k_val and store it in the cache + q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) + self.k_cache[:, :, input_pos] = q_k_val + self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1) + k_out = self.k_cache*self.k_cache_scale + k_out[:, :, input_pos] = k_val + + q_v_val, v_scale = quantize_activation_per_token_absmax(v_val) + self.v_cache[:, :, input_pos] = q_v_val + self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1) + v_out = self.v_cache*self.v_cache_scale + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + @classmethod + def from_float(cls, kv_cache): + cache_shape = kv_cache.k_cache.shape + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape + scale_dtype = kv_cache.k_cache.dtype + return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype) + class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__()