From 9c882fd2e639e16ca4792829cee9f797eb9743aa Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Fri, 28 Jun 2024 11:10:05 +0200 Subject: [PATCH] Fix GPTQ CI (#1878) * fix gptq tests * simplify ci * enable determinism * fix * add more expected outputs * last one * hopefully * more expected putputs with each run * new one * add evaluation * fix * remove gptq extra * style check --- .github/workflows/check_code_quality.yml | 56 +++----- .github/workflows/test_gptq.yml | 59 +++++--- optimum/gptq/eval.py | 41 ++++++ tests/gptq/Dockerfile_quantization_gpu | 26 ---- tests/gptq/test_quantization.py | 169 +++++++++-------------- 5 files changed, 166 insertions(+), 185 deletions(-) create mode 100644 optimum/gptq/eval.py delete mode 100644 tests/gptq/Dockerfile_quantization_gpu diff --git a/.github/workflows/check_code_quality.yml b/.github/workflows/check_code_quality.yml index 660f417019..c429b706bf 100644 --- a/.github/workflows/check_code_quality.yml +++ b/.github/workflows/check_code_quality.yml @@ -1,19 +1,11 @@ -name: check_code_quality +name: Code Quality on: push: - branches: [ main ] - paths: - - "optimum/**.py" - - "tests/**.py" - - "examples/**.py" + branches: [main] pull_request: - branches: [ main ] - paths: - - "optimum/**.py" - - "tests/**.py" - - "examples/**.py" + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -29,25 +21,23 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Create and start a virtual environment - run: | - python -m venv venv - source venv/bin/activate - - name: Install dependencies - run: | - source venv/bin/activate - pip install --upgrade pip - pip install .[quality] - - name: Check style with black - run: | - source venv/bin/activate - black --check . - - name: Check style with ruff - run: | - source venv/bin/activate - ruff . + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[quality] + + - name: Check style with black + run: | + black --check . + + - name: Check style with ruff + run: | + ruff . 
diff --git a/.github/workflows/test_gptq.yml b/.github/workflows/test_gptq.yml index 0f3c31c6d2..7e7d3959a6 100644 --- a/.github/workflows/test_gptq.yml +++ b/.github/workflows/test_gptq.yml @@ -1,29 +1,46 @@ -name: GPTQ Quantization / Test GPU +name: GPTQ / Python - Test on: workflow_dispatch: - schedule: - - cron: 0 1 */3 * * # at 1am every 3 days + push: + branches: [main] + paths: + - tests/gptq/** + - optimum/gptq/** + - .github/workflows/test_gptq.yml pull_request: - types: [opened, synchronize, reopened, labeled] - # uncomment to enable on PR merge on main branch: - #push: - # branches: - # - main + branches: [main] + paths: + - tests/gptq/** + - optimum/gptq/** + - .github/workflows/test_gptq.yml + schedule: + # every day at midnight + - cron: "0 0 * * *" jobs: - do-the-job: - if: ${{ (github.event_name == 'workflow_dispatch') || (github.event_name == 'schedule') || contains( github.event.pull_request.labels.*.name, 'gpu-test') }} - name: Start self-hosted EC2 runner + test_gptq: runs-on: [single-gpu, nvidia-gpu, t4, ci] - env: - AWS_REGION: us-east-1 + steps: - - name: Checkout - uses: actions/checkout@v2 - - name: Build image - run: | - docker build -f tests/gptq/Dockerfile_quantization_gpu -t gptq-gpu . - - name: Test with unittest within docker container - run: | - docker run --rm --gpus all -v $(pwd)/hf_cache:/root/.cache/huggingface --workdir=/workspace/optimum/tests gptq-gpu:latest + - name: Checkout code + uses: actions/checkout@v4 + + - name: Run tests + uses: addnab/docker-run-action@v3 + with: + image: pytorch/pytorch:2.2.2-cuda12.1-cudnn8-runtime + # latest auto-gptq was built with pytorch 2.2 and cuda 12.1 + options: | + --rm + --gpus all + --shm-size 16G + --env RUN_SLOW=1 + --env HF_HOME=/mnt/cache/ + --volume /mnt/cache/:/mnt/cache/ + --volume ${{ github.workspace }}:/workspace + --workdir /workspace + run: | + pip install auto-gptq + pip install -e .[tests] + pytest tests/gptq -s -vvvv --durations=0 diff --git a/optimum/gptq/eval.py b/optimum/gptq/eval.py new file mode 100644 index 0000000000..3ae6e4d7bf --- /dev/null +++ b/optimum/gptq/eval.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from datasets import load_dataset +from tqdm import tqdm + + +def evaluate_perplexity(model, tokenizer): + def _perplexity(nlls, n_samples, seqlen): + return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen)) + + # load and prepare dataset + data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + data = tokenizer("\n\n".join(data["text"]), return_tensors="pt") + data = data.input_ids.to(model.device) + + seqlen = 512 + model = model.eval() + n_samples = data.numel() // seqlen + + nlls = [] + + with tqdm(range(n_samples), desc="Perplexity -") as progress_bar: + for i in progress_bar: + start_index = i * seqlen + end_index = (i + 1) * seqlen + batch = data[:, start_index:end_index].to(model.device) + with torch.no_grad(): + logits = model(batch).logits + shift_logits = logits[:, :-1, :].contiguous().float() + shift_labels = data[:, start_index:end_index][:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + neg_log_likelihood = loss.float() * seqlen + nlls.append(neg_log_likelihood) + + curr_ppl = _perplexity(nlls, i + 1, seqlen) + progress_bar.set_description(f"Perplexity {curr_ppl:.3f}") + + ppl = _perplexity(nlls, n_samples, seqlen) + + return ppl.item() diff --git a/tests/gptq/Dockerfile_quantization_gpu b/tests/gptq/Dockerfile_quantization_gpu deleted file mode 100644 
index 34a2a13552..0000000000 --- a/tests/gptq/Dockerfile_quantization_gpu +++ /dev/null @@ -1,26 +0,0 @@ -FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 -CMD nvidia-smi - -# Ignore interactive questions during `docker build` -ENV DEBIAN_FRONTEND noninteractive - -# Install and update tools to minimize security vulnerabilities -RUN apt-get update -RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \ - bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev python3-pip && \ - apt-get clean -RUN unattended-upgrade -RUN apt-get autoremove -y - -RUN python3 -m pip install -U pip - -RUN pip install torch torchvision torchaudio -RUN pip install transformers accelerate auto-gptq datasets - -# Install Optimum -COPY . /workspace/optimum -RUN pip install /workspace/optimum[tests] - -ENV RUN_SLOW=1 -WORKDIR /workspace/optimum/tests/ -CMD pytest gptq/test_*.py --durations=0 -s -vvvvv diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py index 5ed1619fde..220d023586 100644 --- a/tests/gptq/test_quantization.py +++ b/tests/gptq/test_quantization.py @@ -23,12 +23,19 @@ from optimum.gptq import GPTQQuantizer, load_quantized_model from optimum.gptq.data import get_dataset -from optimum.utils.import_utils import is_auto_gptq_available -from optimum.utils.testing_utils import require_accelerate, require_auto_gptq, require_torch_gpu +from optimum.gptq.eval import evaluate_perplexity +from optimum.gptq.utils import get_block_name_with_pattern, get_preceding_modules, get_seqlen +from optimum.utils import recurse_getattr +from optimum.utils.import_utils import is_accelerate_available, is_auto_gptq_available +from optimum.utils.testing_utils import require_auto_gptq, require_torch_gpu if is_auto_gptq_available(): from auto_gptq import AutoGPTQForCausalLM + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + +if is_accelerate_available(): + from accelerate import init_empty_weights @slow @@ -37,15 +44,10 @@ class GPTQTest(unittest.TestCase): model_name = "bigscience/bloom-560m" - input_text = "Hello my name is" - EXPECTED_OUTPUTS = set() - EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") - EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I") - EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.") - EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of") + expected_fp16_perplexity = 30 + expected_quantized_perplexity = 34 - # this seems a little small considering that we are doing 4bit quant but we have a small model and ww don't quantize the embeddings - EXPECTED_RELATIVE_DIFFERENCE = 1.664253062 + expected_compression_ratio = 1.66 bits = 4 group_size = 128 @@ -53,24 +55,30 @@ class GPTQTest(unittest.TestCase): disable_exllama = True exllama_config = None cache_block_outputs = True - modules_to_quantize_inside_block = None + modules_in_block_to_quantize = None device_map_for_quantization = "cuda" + device_for_inference = 0 dataset = [ "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm." 
] - # called only once for all test in this class + # called only once for all tests in this class @classmethod def setUpClass(cls): """ Setup quantized model """ + + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.model_fp16 = AutoModelForCausalLM.from_pretrained( cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization ) - cls.mem_fp16 = cls.model_fp16.get_memory_footprint() + cls.fp16_mem = cls.model_fp16.get_memory_footprint() + + if cls.device_map_for_quantization != "cpu": + cls.fp16_ppl = evaluate_perplexity(cls.model_fp16, cls.tokenizer) - cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) cls.quantizer = GPTQQuantizer( bits=cls.bits, dataset=cls.dataset, @@ -79,10 +87,13 @@ def setUpClass(cls): disable_exllama=cls.disable_exllama, exllama_config=cls.exllama_config, cache_block_outputs=cls.cache_block_outputs, - modules_to_quantize_inside_block=cls.modules_to_quantize_inside_block, + modules_in_block_to_quantize=cls.modules_in_block_to_quantize, ) + cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer).to(cls.device_for_inference) + cls.quantized_mem = cls.quantized_model.get_memory_footprint() - cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer) + if cls.device_map_for_quantization != "cpu": + cls.quantized_ppl = evaluate_perplexity(cls.quantized_model, cls.tokenizer) def test_memory_footprint(self): """ @@ -90,19 +101,26 @@ def test_memory_footprint(self): memory footprint of the converted model and the class type of the linear layers of the converted models """ - mem_quantized = self.quantized_model.get_memory_footprint() + self.assertAlmostEqual(self.fp16_mem / self.quantized_mem, self.expected_compression_ratio, places=2) - self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + def test_perplexity(self): + """ + A simple test to check if the model conversion has been done correctly by checking on the + the perplexity of the converted models + """ + + self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity) + self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity) def test_quantized_layers_class(self): """ A simple test to check if the model conversion has been done correctly by checking on the the class type of the linear layers of the converted models """ - from auto_gptq.utils.import_utils import dynamically_import_QuantLinear QuantLinear = dynamically_import_QuantLinear( use_triton=False, + use_qigen=False, desc_act=self.desc_act, group_size=self.group_size, bits=self.bits, @@ -114,32 +132,10 @@ def test_quantized_layers_class(self): def check_quantized_layers_type(self, model, value): self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.QUANT_TYPE == value) - def check_inference_correctness(self, model): - """ - Test the generation quality of the quantized model and see that we are matching the expected output. - Given that we are operating on small numbers + the testing model is relatively small, we might not get - the same output across GPUs. So we'll generate few tokens (5-10) and check their output. 
- """ - # Check that inference pass works on the model - encoded_input = self.tokenizer(self.input_text, return_tensors="pt") - - # Get the generation - output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) - - # Check the exactness of the result - self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) - - def test_generate_quality(self): - self.check_inference_correctness(self.quantized_model) - - @require_torch_gpu - @require_accelerate - @slow def test_serialization(self): """ Test the serialization of the model and the loading of the quantized weights """ - from accelerate import init_empty_weights with tempfile.TemporaryDirectory() as tmpdirname: self.quantizer.save(self.quantized_model, tmpdirname) @@ -152,7 +148,7 @@ def test_serialization(self): quantized_model_from_saved = load_quantized_model( empty_model, save_folder=tmpdirname, - device_map={"": 0}, + device_map={"": self.device_for_inference}, disable_exllama=self.disable_exllama, exllama_config=self.exllama_config, ) @@ -161,54 +157,37 @@ def test_serialization(self): else: self.check_quantized_layers_type(quantized_model_from_saved, "exllama") - with torch.device("cuda"): - _ = AutoModelForCausalLM.from_pretrained(tmpdirname) - _ = AutoGPTQForCausalLM.from_quantized(tmpdirname) - - self.check_inference_correctness(quantized_model_from_saved) + # transformers and auto-gptq compatibility + # quantized models are more compatible with device map than + # device context managers (they're never used in transformers testing suite) + _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference}) + _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference}) class GPTQTestCPUInit(GPTQTest): device_map_for_quantization = "cpu" - def test_generate_quality(self): - self.check_inference_correctness(self.quantized_model.to(0)) + def test_perplexity(self): + pass class GPTQTestExllama(GPTQTest): disable_exllama = False exllama_config = {"version": 1} - EXPECTED_OUTPUTS = set() - EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I") - EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.") - EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of") - EXPECTED_OUTPUTS.add("Hello my name is Nate and I am a new member of the") class GPTQTestActOrder(GPTQTest): - EXPECTED_OUTPUTS = set() - EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.") - EXPECTED_OUTPUTS.add("Hello my name is jessie and i am a very sweet and") - EXPECTED_OUTPUTS.add("Hello my name is nathalie, I am a young girl from") - EXPECTED_OUTPUTS.add("Hello my name is\nI am a student of the University of the'") - disable_exllama = True desc_act = True - def test_generate_quality(self): - # act_order don't work with qlinear_cuda kernel - pass - def test_serialization(self): # act_order don't work with qlinear_cuda kernel pass - @require_torch_gpu def test_exllama_serialization(self): """ Test the serialization of the model and the loading of the quantized weights with exllama kernel """ - from accelerate import init_empty_weights with tempfile.TemporaryDirectory() as tmpdirname: self.quantizer.save(self.quantized_model, tmpdirname) @@ -219,21 +198,23 @@ def test_exllama_serialization(self): ) empty_model.tie_weights() quantized_model_from_saved = load_quantized_model( - empty_model, save_folder=tmpdirname, 
device_map={"": 0}, exllama_config={"version": 1} + empty_model, + save_folder=tmpdirname, + device_map={"": self.device_for_inference}, + exllama_config={"version": 1}, ) self.check_quantized_layers_type(quantized_model_from_saved, "exllama") - with torch.device("cuda"): - _ = AutoModelForCausalLM.from_pretrained(tmpdirname) - _ = AutoGPTQForCausalLM.from_quantized(tmpdirname) - - self.check_inference_correctness(quantized_model_from_saved) + # transformers and auto-gptq compatibility + # quantized models are more compatible with device map than + # device context managers (they're never used in transformers testing suite) + _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference}) + _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference}) def test_exllama_max_input_length(self): """ Test if the max_input_length works with exllama + act_order """ - from accelerate import init_empty_weights with tempfile.TemporaryDirectory() as tmpdirname: self.quantizer.save(self.quantized_model, tmpdirname) @@ -246,9 +227,9 @@ def test_exllama_max_input_length(self): quantized_model_from_saved = load_quantized_model( empty_model, save_folder=tmpdirname, - device_map={"": 0}, - max_input_length=4028, + device_map={"": self.device_for_inference}, exllama_config={"version": 1}, + max_input_length=4028, ) self.check_quantized_layers_type(quantized_model_from_saved, "exllama") @@ -268,26 +249,16 @@ def test_exllama_max_input_length(self): class GPTQTestExllamav2(GPTQTest): desc_act = False disable_exllama = True - EXPECTED_OUTPUTS = set() - EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I") - EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.") - EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of") - EXPECTED_OUTPUTS.add("Hello my name is Nate and I am a new member of the") - - def test_generate_quality(self): - # don't need to test - pass + exllama_config = {"version": 2} def test_serialization(self): # don't need to test pass - @require_torch_gpu def test_exllama_serialization(self): """ Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel """ - from accelerate import init_empty_weights with tempfile.TemporaryDirectory() as tmpdirname: self.quantizer.save(self.quantized_model, tmpdirname) @@ -300,24 +271,19 @@ def test_exllama_serialization(self): quantized_model_from_saved = load_quantized_model( empty_model, save_folder=tmpdirname, - device_map={"": 0}, + device_map={"": self.device_for_inference}, ) self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2") - with torch.device("cuda"): - _ = AutoModelForCausalLM.from_pretrained(tmpdirname) - _ = AutoGPTQForCausalLM.from_quantized(tmpdirname) - - self.check_inference_correctness(quantized_model_from_saved) + # transformers and auto-gptq compatibility + # quantized models are more compatible with device map than + # device context managers (they're never used in transformers testing suite) + _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference}) + _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference}) class GPTQTestNoBlockCaching(GPTQTest): cache_block_outputs = False - EXPECTED_OUTPUTS = set() - EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I") - EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at 
university.") - EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of") - EXPECTED_OUTPUTS.add("Hello my name is Aiden and I am a very good looking") class GPTQTestModuleQuant(GPTQTest): @@ -327,7 +293,7 @@ class GPTQTestModuleQuant(GPTQTest): ["mlp.dense_h_to_4h"], ["mlp.dense_4h_to_h"], ] - EXPECTED_RELATIVE_DIFFERENCE = 1.57705236164535 + expected_compression_ratio = 1.577 def test_not_converted_layers(self): # self_attention.dense should not be converted @@ -350,16 +316,11 @@ class GPTQUtilsTest(unittest.TestCase): ] def test_get_seqlen(self): - from optimum.gptq.utils import get_seqlen - model = AutoModelForCausalLM.from_pretrained(self.model_name) seqlen = get_seqlen(model) self.assertEqual(seqlen, self.expected_seqlen) def test_get_block_name(self): - from optimum.gptq.utils import get_block_name_with_pattern - from optimum.utils import recurse_getattr - model = AutoModelForCausalLM.from_pretrained(self.model_name) block_name = get_block_name_with_pattern(model) self.assertEqual(block_name, self.expected_block_name) @@ -367,8 +328,6 @@ def test_get_block_name(self): self.assertEqual(block_class_name, self.expected_block_name_class) def test_get_preceding_modules(self): - from optimum.gptq.utils import get_preceding_modules - model = AutoModelForCausalLM.from_pretrained(self.model_name) modules_names = get_preceding_modules(model, self.expected_block_name) self.assertCountEqual(modules_names, self.expected_preceding_modules)