diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index c327db9a95..e07a3e6799 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -1,3 +1,4 @@ +llama 2 20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 @@ -8,6 +9,7 @@ 20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +llama 3 20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 @@ -18,61 +20,10 @@ 20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -# done with quantization of latest token -20240718131341, tok/s=108.87, mem/s=1438.62 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240718131549, tok/s=103.15, mem/s=1363.06 GB/s, peak_mem=13.86 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240718131820, tok/s=163.84, mem/s=1084.89 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240718132103, tok/s=154.76, mem/s=1024.78 GB/s, peak_mem= 8.93 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 - -# done with full accuracy for latest token -20240718150644, tok/s=109.23, mem/s=1443.43 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240718151152, tok/s=100.29, mem/s=1325.29 GB/s, peak_mem=14.14 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240718151349, tok/s=166.08, mem/s=1099.70 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240718152147, tok/s=140.85, mem/s= 932.66 GB/s, peak_mem= 9.21 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 - -20240731133002, tok/s=109.42, mem/s=1446.01 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731135838, tok/s=102.85, mem/s=1359.17 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731140259, tok/s=102.91, mem/s=1359.87 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731140646, tok/s=101.19, mem/s=1337.23 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731194813, tok/s=102.84, mem/s=1358.94 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731195225, tok/s=103.14, mem/s=1362.92 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731200747, tok/s=102.79, mem/s=1358.40 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731201145, tok/s=103.09, mem/s=1362.33 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 - -20240731201438, tok/s=109.42, mem/s=1446.00 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731201739, tok/s=102.58, mem/s=1355.51 GB/s, peak_mem=13.83 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731202121, tok/s=102.86, mem/s=1359.26 GB/s, peak_mem=15.00 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731202505, tok/s=103.44, mem/s=1366.91 GB/s, peak_mem=14.52 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 - -20240731212356, tok/s= 95.36, mem/s=1431.41 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731212822, tok/s= 93.92, mem/s=1409.76 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731213330, tok/s= 89.84, mem/s=1348.49 GB/s, peak_mem=17.28 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731213900, tok/s= 88.34, mem/s=1326.01 GB/s, peak_mem=17.24 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731215214, tok/s= 80.92, mem/s=1214.62 GB/s, peak_mem=19.80 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 -20240731220805, tok/s= 78.54, mem/s=1178.91 GB/s, peak_mem=19.30 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 - -20240731223923, tok/s= 95.52, mem/s=1433.68 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731224400, tok/s= 90.39, mem/s=1356.73 GB/s, peak_mem=16.44 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240731224816, tok/s= 89.89, mem/s=1349.25 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731225411, tok/s= 84.96, mem/s=1275.21 GB/s, peak_mem=17.48 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240731230612, tok/s= 80.91, mem/s=1214.45 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 -20240731232130, tok/s= 69.10, mem/s=1037.25 GB/s, peak_mem=20.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 - -20240801010740, tok/s= 95.45, mem/s=1432.64 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240801011046, tok/s= 94.02, mem/s=1411.28 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240801011513, tok/s= 89.96, mem/s=1350.32 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240801011931, tok/s= 88.11, mem/s=1322.52 GB/s, peak_mem=17.20 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 - -20240801013354, tok/s= 95.45, mem/s=1432.67 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240801013812, tok/s= 92.15, mem/s=1383.16 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240801025927, tok/s= 89.88, mem/s=1349.14 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240801030347, tok/s= 87.32, mem/s=1310.69 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240801031549, tok/s= 80.91, mem/s=1214.39 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 -20240801033011, tok/s= 74.72, mem/s=1121.50 GB/s, peak_mem=19.34 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 - -20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 -20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 -20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 \ No newline at end of file +kv cache quantization: +20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 +20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 +20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 +20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index ed599de914..6dd9c10d94 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -1,34 +1,32 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder -# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt # in readme -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt export MODEL_REPO=meta-llama/Meta-Llama-3-8B -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -# # in readme -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +# in readme +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt -##### +export MODEL_REPO=meta-llama/Meta-Llama-3-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192 -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --max_new_tokens 2048 -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048 diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index beb1569bbb..6e6db90571 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -95,38 +95,40 @@ def generate( """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. """ + # create an empty tensor of the expected final shape and fill in the current tokens device = prompt.device T = prompt.numel() - # max_new_tokens can overflow block_size so we need to cap it + # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 new_tokens = max_seq_length - T # full prompt+output will be stored in seq seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device) seq[:T] = prompt.view(-1) - # setup model cache + + # setup model caches with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if kv_cache_quantization: - from model import QuantizedKVCache - # go through the model and do the swaps + from model import AffineQuantizedKVCache from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter _replace_with_custom_fn_if_matches_filter( model, - QuantizedKVCache.from_float, + AffineQuantizedKVCache.from_float, lambda x, y: isinstance(x, torchao._models.llama.model.KVCache), ) # format model input - x, input_pos = prepare_inputs_for_model(prompt) + x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) # execute prefill next_token = prefill(model, x, input_pos, **sampling_kwargs).clone() seq[T] = next_token + # execute token generation input_pos = torch.tensor([T], device=device, dtype=torch.int) generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) seq[T + 1:] = torch.cat(generated_tokens) @@ -172,7 +174,6 @@ def main( """Generates text samples based on a pre-trained Transformer model and tokenizer. """ - # torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=1000000, trace_alloc_record_context=True) torchao.quantization.utils.recommended_inductor_config_setter() assert checkpoint_path.is_file(), checkpoint_path @@ -294,12 +295,6 @@ def callback(x): top_k=top_k, kv_cache_quantization=kv_cache_quantization, ) - # if i==3: - # snapshot = torch.cuda.memory._snapshot() - # from pickle import dump - # with open("mem_trace_kvq_no_comp" + '.pickle', 'wb') as f: - # dump(snapshot, f) - # breakpoint() if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") continue diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 5c37adcdd7..58a1709642 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -12,7 +12,8 @@ from torch.nn import functional as F from torchao.utils import find_multiple -def prepare_inputs_for_model(inps): +# TODO remove suplerfluous arg +def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d if inps.dim() > 2: raise ValueError(f"Expected input to be of dim 1 or 2, but got {inps.dim()}") @@ -85,6 +86,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torc def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] + if use_index_put_for_kv_cache: k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val) v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val) @@ -97,15 +99,10 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -# (Pdb) p k_val.shape -# torch.Size([1, 32, 6, 128]) -# (Pdb) p self.k_cache.shape -# torch.Size([1, 32, 208, 128]) so want final size to be 1,32,208,[1] - from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine from torchao.quantization.utils import quantize_activation_per_token_absmax -class QuantizedKVCache(nn.Module): +class AffineQuantizedKVCache(nn.Module): def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16): super().__init__() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) @@ -116,6 +113,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtyp self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) def update(self, input_pos, k_val, v_val): + # quantize current k_val and store it in the cache q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) self.k_cache[:, :, input_pos] = q_k_val self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1) @@ -137,8 +135,6 @@ def from_float(cls, kv_cache): scale_dtype = kv_cache.k_cache.dtype return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype) - - class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__()