diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
index 985bc025c7..27b1b568df 100644
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -40,12 +40,12 @@ def format_value(value):
 
     print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))
 
-def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
-    if compile:
+    if quantization == "autoquant" and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
     if quantization == "int8dq":
@@ -57,6 +57,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
         quantize_(model.to(device=device), int4_weight_only())
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
+
+    if quantization != "autoquant" and compile:
+        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
     with torch.no_grad():
         result = evaluate(
             HFLM(
@@ -70,6 +74,12 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
 
     pretty_print_nested_results(result)
 
+    if save:
+        # This doesn't work yet: https://github.com/huggingface/transformers/issues/32364
+        # model.save_pretrained("quantized_model_test", safe_serialization=False)
+        file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt"
+        torch.save(model.state_dict(), file_name)
+
 
 if __name__ == '__main__':
     import argparse
@@ -81,8 +91,9 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
     parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument('--save', action='store_true', help='Whether to save the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
 
     args = parser.parse_args()
-    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
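A quick illustration of the file name the new `--save` branch in `hf_eval.py` builds, outside the patch itself; the `repo_id` and `quantization` values below are example assumptions, not part of the change:

```python
# Illustration of the file name produced by the --save branch above.
# repo_id and quantization are example assumptions.
repo_id = "meta-llama/Llama-2-7b-chat-hf"
quantization = "int8wo"

file_name = repo_id.split("/")[-1] + "-" + quantization + ".pt"
print(file_name)  # Llama-2-7b-chat-hf-int8wo.pt, written to the current working directory
```

The state dict is written with `torch.save` rather than `model.save_pretrained`; as the commented-out call notes, the `save_pretrained` path doesn't work yet per the linked transformers issue.
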
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index f0efaee020..3af9a156f7 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
 import sys
 import time
 from pathlib import Path
@@ -165,6 +166,7 @@ def main(
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
     kv_cache_quantization: bool = False,
+    save: bool = False,
     compile: bool = True,
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
@@ -238,6 +240,11 @@ def main(
 
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
+    if save:
+        output_dir = str(checkpoint_path.cwd())
+        filename = str(checkpoint_path.name).split(".")[0]
+        torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt"))
+
     if compile:
         print("Compiling Model")
         global decode_one_token, prefill
@@ -362,6 +369,7 @@ def callback(x):
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
+    parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
@@ -372,5 +380,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
     )
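For the `generate.py` change, a minimal sketch of where the saved checkpoint ends up. Note that `Path.cwd()` returns the process's current working directory regardless of the path it is called on, so the file lands wherever the script is launched from; the `checkpoint_path` and `quantization` values below are assumed for illustration:

```python
# Sketch of the output path built by the --save branch above.
# checkpoint_path and quantization are example assumptions.
import os
from pathlib import Path

checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
quantization = "int8wo"

output_dir = str(checkpoint_path.cwd())             # Path.cwd() is the process CWD, not the checkpoint's folder
filename = str(checkpoint_path.name).split(".")[0]  # "model"
print(os.path.join(output_dir, filename + f"-{quantization}.pt"))
# e.g. <current working directory>/model-int8wo.pt
```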