From 7d649c0df682d3bac3b497a2a24c9c009342e007 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 3 Jul 2024 13:27:05 +0200
Subject: [PATCH 1/4] deprecate and create alternative

---
 optimum/onnxruntime/model.py | 5 +++
 optimum/onnxruntime/utils.py | 62 +++++++++++++++++++++++++++++++++++-
 2 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/optimum/onnxruntime/model.py b/optimum/onnxruntime/model.py
index 23ca6e5e6a..caa662f382 100644
--- a/optimum/onnxruntime/model.py
+++ b/optimum/onnxruntime/model.py
@@ -49,6 +49,11 @@ def __init__(
             label_names (`List[str]`, `optional`):
                 The list of keys in your dictionary of inputs that correspond to the labels.
         """
+
+        logger.warning(
+            "The class `optimum.onnxruntime.model.ORTModel` is deprecated and will be removed in the next release. Please use the standalone `evaluation_loop` function in `optimum.onnxruntime.utils` instead."
+        )
+
         self.compute_metrics = compute_metrics
         self.label_names = ["labels"] if label_names is None else label_names
         self.session = InferenceSession(str(model_path), providers=[execution_provider])
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 37d0feefcc..ad40af92b9 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -17,11 +17,15 @@ import re
 from enum import Enum
 from inspect import signature
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from packaging import version
+from tqdm import tqdm
+from transformers import EvalPrediction
+from transformers.trainer_pt_utils import nested_concat
+from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import logging
 import onnxruntime as ort
@@ -30,6 +34,12 @@ from ..utils.import_utils import _is_package_available
+if TYPE_CHECKING:
+    from datasets import Dataset
+
+    from .modeling_ort import ORTModel
+
+
 logger = logging.get_logger(__name__)
 ONNX_WEIGHTS_NAME = "model.onnx"
@@ -341,3 +351,53 @@ class ORTQuantizableOperator(Enum):
     Resize = "Resize"
     AveragePool = "AveragePool"
     Concat = "Concat"
+
+
+def evaluation_loop(
+    model: "ORTModel",
+    dataset: "Dataset",
+    label_names: Optional[List[str]] = None,
+    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+):
+    """
+    Runs evaluation and returns the metrics and predictions.
+
+    Args:
+        model (`ORTModel`):
+            The ONNXRuntime model to use for the evaluation step.
+        dataset (`datasets.Dataset`):
+            Dataset to use for the evaluation step.
+        label_names (`List[str]`, `optional`):
+            The list of keys in your dictionary of inputs that correspond to the labels.
+        compute_metrics (`Callable[[EvalPrediction], Dict]`, `optional`):
+            The function that will be used to compute metrics at evaluation. Must take an `EvalPrediction` and
+            return a dictionary mapping metric names to metric values.
+    """
+
+    # Default to the conventional "labels" key, matching the deprecated `ORTModel` behavior
+    label_names = ["labels"] if label_names is None else label_names
+
+    all_preds = None
+    all_labels = None
+
+    for inputs in tqdm(dataset, desc="Evaluation"):
+        has_labels = all(inputs.get(k) is not None for k in label_names)
+        if has_labels:
+            labels = tuple(np.array([inputs.get(name)]) for name in label_names)
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
+        inputs = {key: np.array([inputs[key]]) for key in model.input_names if key in inputs}
+        preds = model(**inputs)
+
+        if len(preds) == 1:
+            preds = preds[0]
+
+        all_preds = preds if all_preds is None else nested_concat(all_preds, preds, padding_index=-100)
+        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+    if compute_metrics is not None and all_preds is not None and all_labels is not None:
+        metrics = compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+    else:
+        metrics = {}
+
+    return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=len(dataset))

From bdf1d4874e4ba47c2f4dbaed7720818160911021 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 4 Jul 2024 10:20:20 +0200
Subject: [PATCH 2/4] update optimization examples

---
 .../optimization/multiple-choice/run_swag.py | 19 ++++++++------
 .../optimization/question-answering/run_qa.py | 24 +++++++++---------
 .../text-classification/README.md | 6 ++---
 .../text-classification/run_glue.py | 25 +++++++++++--------
 .../token-classification/run_ner.py | 17 ++++++-------
 5 files changed, 47 insertions(+), 44 deletions(-)

diff --git a/examples/onnxruntime/optimization/multiple-choice/run_swag.py b/examples/onnxruntime/optimization/multiple-choice/run_swag.py
index 3c43846b9a..b2a9398d94 100644
--- a/examples/onnxruntime/optimization/multiple-choice/run_swag.py
+++ b/examples/onnxruntime/optimization/multiple-choice/run_swag.py
@@ -37,7 +37,7 @@ from optimum.onnxruntime import ORTModelForMultipleChoice, ORTOptimizer
 from optimum.onnxruntime.configuration import OptimizationConfig
-from optimum.onnxruntime.model import ORTModel
+from optimum.onnxruntime.utils import evaluation_loop
 # Will error if the minimal version of Transformers is not installed.
The version of transformers must be >= 4.19.0 @@ -236,7 +236,6 @@ def main(): ) os.makedirs(training_args.output_dir, exist_ok=True) - optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx") tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path) @@ -254,13 +253,18 @@ def main(): optimizer = ORTOptimizer.from_pretrained(model) # Optimize the model - optimizer.optimize( + optimized_model_path = optimizer.optimize( optimization_config=optimization_config, save_dir=training_args.output_dir, use_external_data_format=onnx_export_args.use_external_data_format, one_external_file=onnx_export_args.one_external_file, ) + model = ORTModelForMultipleChoice.from_pretrained( + optimized_model_path, + provider=optim_args.execution_provider, + ) + if training_args.do_eval: # Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the # prediction step(s) @@ -339,13 +343,12 @@ def compute_metrics(eval_predictions): # Evaluation logger.info("*** Evaluate ***") - ort_model = ORTModel( - optimized_model_path, - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, + label_names=["labels"], compute_metrics=compute_metrics, - label_names=["label"], ) - outputs = ort_model.evaluation_loop(eval_dataset) # Save evaluation metrics with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: diff --git a/examples/onnxruntime/optimization/question-answering/run_qa.py b/examples/onnxruntime/optimization/question-answering/run_qa.py index 04a9bd34f3..407714cb01 100644 --- a/examples/onnxruntime/optimization/question-answering/run_qa.py +++ b/examples/onnxruntime/optimization/question-answering/run_qa.py @@ -37,7 +37,7 @@ from optimum.onnxruntime import ORTModelForQuestionAnswering, ORTOptimizer from optimum.onnxruntime.configuration import OptimizationConfig -from optimum.onnxruntime.model import ORTModel +from optimum.onnxruntime.utils import evaluation_loop # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
@@ -305,7 +305,6 @@ def main():
     )
     os.makedirs(training_args.output_dir, exist_ok=True)
-    optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path)
@@ -323,13 +322,15 @@ def main():
     optimizer = ORTOptimizer.from_pretrained(model)
     # Optimize the model
-    optimizer.optimize(
+    optimized_model_path = optimizer.optimize(
         optimization_config=optimization_config,
         save_dir=training_args.output_dir,
         use_external_data_format=onnx_export_args.use_external_data_format,
         one_external_file=onnx_export_args.one_external_file,
     )
+    model = ORTModelForQuestionAnswering.from_pretrained(optimized_model_path, provider=optim_args.execution_provider)
+
     # Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the
     # prediction step(s)
     if training_args.do_eval or training_args.do_predict:
@@ -478,13 +479,12 @@ def compute_metrics(p: EvalPrediction):
             # During Feature creation dataset samples might increase, we will select required samples again
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
-        ort_model = ORTModel(
-            optimized_model_path,
-            execution_provider=optim_args.execution_provider,
-            compute_metrics=compute_metrics,
+        outputs = evaluation_loop(
+            model=model,
+            dataset=eval_dataset,
             label_names=["start_positions", "end_positions"],
+            compute_metrics=compute_metrics,
         )
-        outputs = ort_model.evaluation_loop(eval_dataset)
         predictions = post_processing_function(eval_examples, eval_dataset, outputs.predictions)
         metrics = compute_metrics(predictions)
@@ -514,12 +514,12 @@ def compute_metrics(p: EvalPrediction):
             # During Feature creation dataset samples might increase, we will select required samples again
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
-        ort_model = ORTModel(
-            optimized_model_path,
-            execution_provider=optim_args.execution_provider,
+        outputs = evaluation_loop(
+            model=model,
+            dataset=predict_dataset,
             label_names=["start_positions", "end_positions"],
+            compute_metrics=compute_metrics,
         )
-        outputs = ort_model.evaluation_loop(predict_dataset)
         predictions = post_processing_function(predict_examples, predict_dataset, outputs.predictions)
         metrics = compute_metrics(predictions)
diff --git a/examples/onnxruntime/optimization/text-classification/README.md b/examples/onnxruntime/optimization/text-classification/README.md
index 42a99cc73d..3a7dce2b59 100644
--- a/examples/onnxruntime/optimization/text-classification/README.md
+++ b/examples/onnxruntime/optimization/text-classification/README.md
@@ -14,13 +14,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 -->
-# Text classification 
+# Text classification
 ## GLUE tasks
-The script [`run_glue.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/optimization/text-classification/run_glue.py)
-allows us to apply graph optimizations and fusion using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for sequence classification tasks such as
-the ones from the [GLUE benchmark](https://gluebenchmark.com/).
+The script [`run_glue.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/optimization/text-classification/run_glue.py) allows us to apply graph optimizations and fusion using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for sequence classification tasks such as the ones from the [GLUE benchmark](https://gluebenchmark.com/).
 The following example applies graph optimization on a DistilBERT fine-tuned on the sst-2 task. Here the optimization level is selected to be 1, enabling basic optimizations such as redundant node eliminations and constant folding. Higher optimization level will result in hardware dependent optimized graph.
diff --git a/examples/onnxruntime/optimization/text-classification/run_glue.py b/examples/onnxruntime/optimization/text-classification/run_glue.py
index a07193915b..222dda1507 100644
--- a/examples/onnxruntime/optimization/text-classification/run_glue.py
+++ b/examples/onnxruntime/optimization/text-classification/run_glue.py
@@ -42,7 +42,7 @@ from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
 from optimum.onnxruntime.configuration import OptimizationConfig
-from optimum.onnxruntime.model import ORTModel
+from optimum.onnxruntime.utils import evaluation_loop
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -250,7 +250,6 @@ def main():
     )
     os.makedirs(training_args.output_dir, exist_ok=True)
-    optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
@@ -268,13 +267,17 @@ def main():
     optimizer = ORTOptimizer.from_pretrained(model)
     # Optimize the model
-    optimizer.optimize(
+    optimized_model_path = optimizer.optimize(
         optimization_config=optimization_config,
         save_dir=training_args.output_dir,
         use_external_data_format=onnx_export_args.use_external_data_format,
         one_external_file=onnx_export_args.one_external_file,
     )
+    model = ORTModelForSequenceClassification.from_pretrained(
+        optimized_model_path, provider=optim_args.execution_provider
+    )
+
     # Prepare the dataset downloading, preprocessing and metric creation to perform the evaluation and / or the
     # prediction step(s)
     if training_args.do_eval or training_args.do_predict:
@@ -408,13 +411,13 @@ def compute_metrics(p: EvalPrediction):
             desc="Running tokenizer on the evaluation dataset",
         )
-        ort_model = ORTModel(
-            optimized_model_path,
-            execution_provider=optim_args.execution_provider,
+        outputs = evaluation_loop(
+            model=model,
+            dataset=eval_dataset,
             compute_metrics=compute_metrics,
             label_names=["label"],
         )
-        outputs = ort_model.evaluation_loop(eval_dataset)
+
         # Save metrics
         with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f:
             json.dump(outputs.metrics, f, indent=4, sort_keys=True)
@@ -436,10 +439,12 @@ def compute_metrics(p: EvalPrediction):
             desc="Running tokenizer on the test dataset",
         )
-        ort_model = ORTModel(
-            optimized_model_path, execution_provider=optim_args.execution_provider, label_names=["label"]
+        outputs = evaluation_loop(
+            model=model,
+            dataset=predict_dataset,
+            compute_metrics=compute_metrics,
+            label_names=["label"],
         )
-        outputs = ort_model.evaluation_loop(predict_dataset)
         predictions = np.squeeze(outputs.predictions) if is_regression else np.argmax(outputs.predictions, axis=1)
         # Save predictions
diff --git a/examples/onnxruntime/optimization/token-classification/run_ner.py b/examples/onnxruntime/optimization/token-classification/run_ner.py
index 
73db3671d2..2e7b63792c 100644 --- a/examples/onnxruntime/optimization/token-classification/run_ner.py +++ b/examples/onnxruntime/optimization/token-classification/run_ner.py @@ -38,7 +38,7 @@ from optimum.onnxruntime import ORTModelForTokenClassification, ORTOptimizer from optimum.onnxruntime.configuration import OptimizationConfig -from optimum.onnxruntime.model import ORTModel +from optimum.onnxruntime.utils import evaluation_loop # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -276,7 +276,6 @@ def main(): ) os.makedirs(training_args.output_dir, exist_ok=True) - optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx") tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path) @@ -480,12 +479,11 @@ def compute_metrics(p): desc="Running tokenizer on the validation dataset", ) - ort_model = ORTModel( - optimized_model_path, - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, compute_metrics=compute_metrics, ) - outputs = ort_model.evaluation_loop(eval_dataset) # Save evaluation metrics with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: @@ -509,12 +507,11 @@ def compute_metrics(p): desc="Running tokenizer on the prediction dataset", ) - ort_model = ORTModel( - optimized_model_path, - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=predict_dataset, compute_metrics=compute_metrics, ) - outputs = ort_model.evaluation_loop(predict_dataset) predictions = np.argmax(outputs.predictions, axis=2) # Remove ignored index (special tokens) From 9699f2313a834e45354602eed52a8f46e71fe41d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 4 Jul 2024 10:32:50 +0200 Subject: [PATCH 3/4] update quant examples --- .../run_image_classification.py | 17 +++++++------- .../quantization/multiple-choice/run_swag.py | 14 +++++------ .../quantization/question-answering/README.md | 8 ++----- .../quantization/question-answering/run_qa.py | 21 ++++++++--------- .../text-classification/README.md | 5 +--- .../text-classification/run_glue.py | 23 ++++++++++--------- .../token-classification/README.md | 5 +--- .../token-classification/run_ner.py | 22 +++++++++--------- 8 files changed, 53 insertions(+), 62 deletions(-) diff --git a/examples/onnxruntime/quantization/image-classification/run_image_classification.py b/examples/onnxruntime/quantization/image-classification/run_image_classification.py index 3d0fa72882..6feaaef4f3 100644 --- a/examples/onnxruntime/quantization/image-classification/run_image_classification.py +++ b/examples/onnxruntime/quantization/image-classification/run_image_classification.py @@ -22,7 +22,6 @@ import sys from dataclasses import dataclass, field from functools import partial -from pathlib import Path from typing import Optional import datasets @@ -38,7 +37,6 @@ from optimum.onnxruntime import ORTQuantizer from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig -from optimum.onnxruntime.model import ORTModel from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification from optimum.onnxruntime.preprocessors import QuantizationPreprocessor from optimum.onnxruntime.preprocessors.passes import ( @@ -47,6 +45,7 @@ ExcludeNodeAfter, ExcludeNodeFollowedBy, ) +from optimum.onnxruntime.utils import evaluation_loop logger = logging.getLogger(__name__) @@ -378,13 +377,16 @@ def 
compute_metrics(p: EvalPrediction): quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax")) # Apply quantization on the model - quantizer.quantize( + quantized_model_path = quantizer.quantize( save_dir=training_args.output_dir, calibration_tensors_range=ranges, quantization_config=qconfig, preprocessor=quantization_preprocessor, use_external_data_format=onnx_export_args.use_external_data_format, ) + model = ORTModelForImageClassification.from_pretrained( + quantized_model_path, provider=optim_args.execution_provider + ) # Evaluation if training_args.do_eval: @@ -409,13 +411,12 @@ def compute_metrics(p: EvalPrediction): # Set the validation transforms eval_dataset = eval_dataset.with_transform(preprocess_function) - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, - compute_metrics=compute_metrics, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, label_names=[labels_column], + compute_metrics=compute_metrics, ) - outputs = ort_model.evaluation_loop(eval_dataset) # Save metrics with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: json.dump(outputs.metrics, f, indent=4, sort_keys=True) diff --git a/examples/onnxruntime/quantization/multiple-choice/run_swag.py b/examples/onnxruntime/quantization/multiple-choice/run_swag.py index 9d9642c12d..9a8423f836 100644 --- a/examples/onnxruntime/quantization/multiple-choice/run_swag.py +++ b/examples/onnxruntime/quantization/multiple-choice/run_swag.py @@ -38,7 +38,6 @@ from optimum.onnxruntime import ORTModelForMultipleChoice, ORTQuantizer from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig -from optimum.onnxruntime.model import ORTModel from optimum.onnxruntime.preprocessors import QuantizationPreprocessor from optimum.onnxruntime.preprocessors.passes import ( ExcludeGeLUNodes, @@ -46,6 +45,7 @@ ExcludeNodeAfter, ExcludeNodeFollowedBy, ) +from optimum.onnxruntime.utils import evaluation_loop # Will error if the minimal version of Transformers is not installed. 
The version of transformers must be >= 4.19.0 @@ -409,13 +409,14 @@ def compute_metrics(eval_predictions): quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax")) # Apply quantization on the model - quantizer.quantize( + quantized_model_path = quantizer.quantize( save_dir=training_args.output_dir, calibration_tensors_range=ranges, quantization_config=qconfig, preprocessor=quantization_preprocessor, use_external_data_format=onnx_export_args.use_external_data_format, ) + model = ORTModelForMultipleChoice.from_pretrained(quantized_model_path, provider=optim_args.execution_provider) # Evaluation if training_args.do_eval: @@ -436,13 +437,12 @@ def compute_metrics(eval_predictions): load_from_cache_file=not data_args.overwrite_cache, ) - ort_model = ORTModel( - os.path.join(training_args.output_dir, "model_quantized.onnx"), - execution_provider=optim_args.execution_provider, - compute_metrics=compute_metrics, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, label_names=["label"], + compute_metrics=compute_metrics, ) - outputs = ort_model.evaluation_loop(eval_dataset) # Save evaluation metrics with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: diff --git a/examples/onnxruntime/quantization/question-answering/README.md b/examples/onnxruntime/quantization/question-answering/README.md index 380afff8ca..8345ca8e4d 100644 --- a/examples/onnxruntime/quantization/question-answering/README.md +++ b/examples/onnxruntime/quantization/question-answering/README.md @@ -16,13 +16,9 @@ limitations under the License. # Question answering +The script [`run_qa.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/quantization/question-answering/run_qa.py) allows us to apply different quantization approaches (such as dynamic and static quantization) as well as graph optimizations using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for question answering tasks. -The script [`run_qa.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/quantization/question-answering/run_qa.py) -allows us to apply different quantization approaches (such as dynamic and static quantization) as well as graph -optimizations using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for question answering tasks. - -Note that if your dataset contains samples with no possible answers (like SQuAD version 2), you need to pass along -the flag `--version_2_with_negative`. +Note that if your dataset contains samples with no possible answers (like SQuAD version 2), you need to pass along the flag `--version_2_with_negative`. The following example applies post-training dynamic quantization on a DistilBERT fine-tuned on the SQuAD1.0 dataset. 
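For context, a minimal sketch of how the dynamic-quantization flow described in the README above pairs with the `evaluation_loop` helper introduced in PATCH 1/4. The checkpoint name, save directory, and dataset slice below are assumptions for illustration only, and the SQuAD-style post-processing needed to turn logits into answers and metrics (as done in `run_qa.py`) is left out.

```python
# Hedged sketch, not part of the patch: quantize an assumed QA checkpoint, then run
# the new evaluation_loop helper over a few tokenized SQuAD validation examples.
from datasets import load_dataset
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForQuestionAnswering, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from optimum.onnxruntime.utils import evaluation_loop

model_id = "distilbert-base-uncased-distilled-squad"  # assumed fine-tuned checkpoint

# Export to ONNX and apply post-training dynamic quantization.
model = ORTModelForQuestionAnswering.from_pretrained(model_id, export=True)
quantizer = ORTQuantizer.from_pretrained(model)
qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=False)
quantized_model_path = quantizer.quantize(save_dir="distilbert_squad_dynamic", quantization_config=qconfig)

# Reload the quantized model and collect start/end logits with the new utility.
model = ORTModelForQuestionAnswering.from_pretrained(quantized_model_path, provider="CPUExecutionProvider")
tokenizer = AutoTokenizer.from_pretrained(model_id)

eval_features = load_dataset("squad", split="validation[:16]").map(
    lambda ex: tokenizer(ex["question"], ex["context"], truncation="only_second", max_length=384),
    remove_columns=["id", "title", "context", "question", "answers"],
)
outputs = evaluation_loop(model=model, dataset=eval_features)
# outputs.predictions holds the concatenated start/end logits; SQuAD post-processing
# (as in run_qa.py) is still required to turn them into answers and metrics.
```

The same pattern replaces every deprecated `ORTModel(...)` / `ort_model.evaluation_loop(...)` pair touched in this series.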
diff --git a/examples/onnxruntime/quantization/question-answering/run_qa.py b/examples/onnxruntime/quantization/question-answering/run_qa.py index 4a6a854fd9..4b5648d70d 100644 --- a/examples/onnxruntime/quantization/question-answering/run_qa.py +++ b/examples/onnxruntime/quantization/question-answering/run_qa.py @@ -24,7 +24,6 @@ import sys from dataclasses import dataclass, field from functools import partial -from pathlib import Path from typing import Optional import datasets @@ -39,7 +38,6 @@ from optimum.onnxruntime import ORTQuantizer from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig -from optimum.onnxruntime.model import ORTModel from optimum.onnxruntime.modeling_ort import ORTModelForQuestionAnswering from optimum.onnxruntime.preprocessors import QuantizationPreprocessor from optimum.onnxruntime.preprocessors.passes import ( @@ -48,6 +46,7 @@ ExcludeNodeAfter, ExcludeNodeFollowedBy, ) +from optimum.onnxruntime.utils import evaluation_loop # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -651,25 +650,25 @@ def compute_metrics(p: EvalPrediction): quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax")) # Apply quantization on the model - quantizer.quantize( + quantized_model_path = quantizer.quantize( save_dir=training_args.output_dir, calibration_tensors_range=ranges, quantization_config=qconfig, preprocessor=quantization_preprocessor, use_external_data_format=onnx_export_args.use_external_data_format, ) + model = ORTModelForQuestionAnswering.from_pretrained(quantized_model_path, provider=optim_args.execution_provider) # Evaluation if training_args.do_eval: logger.info("*** Evaluate ***") - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, compute_metrics=compute_metrics, label_names=["start_positions", "end_positions"], ) - outputs = ort_model.evaluation_loop(eval_dataset) predictions = post_processing_function(eval_examples, eval_dataset, outputs.predictions) metrics = compute_metrics(predictions) @@ -681,12 +680,12 @@ def compute_metrics(p: EvalPrediction): if training_args.do_predict: logger.info("*** Predict ***") - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=predict_dataset, + compute_metrics=compute_metrics, label_names=["start_positions", "end_positions"], ) - outputs = ort_model.evaluation_loop(predict_dataset) predictions = post_processing_function(predict_examples, predict_dataset, outputs.predictions) metrics = compute_metrics(predictions) diff --git a/examples/onnxruntime/quantization/text-classification/README.md b/examples/onnxruntime/quantization/text-classification/README.md index 460bb56fba..95fd333517 100644 --- a/examples/onnxruntime/quantization/text-classification/README.md +++ b/examples/onnxruntime/quantization/text-classification/README.md @@ -18,10 +18,7 @@ limitations under the License. 
## GLUE tasks -The script [`run_glue.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/quantization/text-classification/run_glue.py) -allows us to apply different quantization approaches (such as dynamic and static quantization) as well as graph -optimizations using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for sequence classification tasks such as -the ones from the [GLUE benchmark](https://gluebenchmark.com/). +The script [`run_glue.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/quantization/text-classification/run_glue.py) allows us to apply different quantization approaches (such as dynamic and static quantization) as well as graph optimizations using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for sequence classification tasks such as the ones from the [GLUE benchmark](https://gluebenchmark.com/). The following example applies post-training dynamic quantization on a DistilBERT fine-tuned on the sst-2 task. diff --git a/examples/onnxruntime/quantization/text-classification/run_glue.py b/examples/onnxruntime/quantization/text-classification/run_glue.py index bc141b2194..4b9ee0403c 100644 --- a/examples/onnxruntime/quantization/text-classification/run_glue.py +++ b/examples/onnxruntime/quantization/text-classification/run_glue.py @@ -23,7 +23,6 @@ import sys from dataclasses import dataclass, field from functools import partial -from pathlib import Path from typing import Optional import datasets @@ -44,7 +43,6 @@ from optimum.onnxruntime import ORTQuantizer from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig -from optimum.onnxruntime.model import ORTModel from optimum.onnxruntime.modeling_ort import ORTModelForSequenceClassification from optimum.onnxruntime.preprocessors import QuantizationPreprocessor from optimum.onnxruntime.preprocessors.passes import ( @@ -53,6 +51,7 @@ ExcludeNodeAfter, ExcludeNodeFollowedBy, ) +from optimum.onnxruntime.utils import evaluation_loop # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -476,13 +475,16 @@ def compute_metrics(p: EvalPrediction): quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax")) # Apply quantization on the model - quantizer.quantize( + quantized_model_path = quantizer.quantize( save_dir=training_args.output_dir, calibration_tensors_range=ranges, quantization_config=qconfig, preprocessor=quantization_preprocessor, use_external_data_format=onnx_export_args.use_external_data_format, ) + model = ORTModelForSequenceClassification.from_pretrained( + quantized_model_path, provider=optim_args.execution_provider + ) # Evaluation if training_args.do_eval: @@ -504,13 +506,13 @@ def compute_metrics(p: EvalPrediction): f" Evaluation results may suffer from a wrong matching." 
) - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, compute_metrics=compute_metrics, label_names=["label"], ) - outputs = ort_model.evaluation_loop(eval_dataset) + # Save metrics with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: json.dump(outputs.metrics, f, indent=4, sort_keys=True) @@ -525,12 +527,11 @@ def compute_metrics(p: EvalPrediction): if data_args.max_predict_samples is not None: predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=predict_dataset, label_names=["label"], ) - outputs = ort_model.evaluation_loop(predict_dataset) predictions = np.squeeze(outputs.predictions) if is_regression else np.argmax(outputs.predictions, axis=1) # Save predictions diff --git a/examples/onnxruntime/quantization/token-classification/README.md b/examples/onnxruntime/quantization/token-classification/README.md index f56388ed3c..540b3cbe2d 100644 --- a/examples/onnxruntime/quantization/token-classification/README.md +++ b/examples/onnxruntime/quantization/token-classification/README.md @@ -16,10 +16,7 @@ limitations under the License. # Token classification - -The script [`run_ner.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/quantization/token-classification/run_ner.py) -allows us to apply different quantization approaches (such as dynamic and static quantization) as well as graph -optimizations using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for token classification tasks. +The script [`run_ner.py`](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/quantization/token-classification/run_ner.py) allows us to apply different quantization approaches (such as dynamic and static quantization) as well as graph optimizations using [ONNX Runtime](https://github.com/microsoft/onnxruntime) for token classification tasks. The following example applies post-training dynamic quantization on a DistilBERT fine-tuned on the CoNLL-2003 task diff --git a/examples/onnxruntime/quantization/token-classification/run_ner.py b/examples/onnxruntime/quantization/token-classification/run_ner.py index 1cc12d3fbc..3a5798c57a 100644 --- a/examples/onnxruntime/quantization/token-classification/run_ner.py +++ b/examples/onnxruntime/quantization/token-classification/run_ner.py @@ -25,7 +25,6 @@ import sys from dataclasses import dataclass, field from functools import partial -from pathlib import Path from typing import Optional import datasets @@ -40,7 +39,6 @@ from optimum.onnxruntime import ORTQuantizer from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig -from optimum.onnxruntime.model import ORTModel from optimum.onnxruntime.modeling_ort import ORTModelForTokenClassification from optimum.onnxruntime.preprocessors import QuantizationPreprocessor from optimum.onnxruntime.preprocessors.passes import ( @@ -49,6 +47,7 @@ ExcludeNodeAfter, ExcludeNodeFollowedBy, ) +from optimum.onnxruntime.utils import evaluation_loop # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
@@ -551,13 +550,16 @@ def compute_metrics(p): quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax")) # Apply quantization on the model - quantizer.quantize( + quantized_model_path = quantizer.quantize( save_dir=training_args.output_dir, calibration_tensors_range=ranges, quantization_config=qconfig, preprocessor=quantization_preprocessor, use_external_data_format=onnx_export_args.use_external_data_format, ) + model = ORTModelForTokenClassification.from_pretrained( + quantized_model_path, provider=optim_args.execution_provider + ) # Evaluation if training_args.do_eval: @@ -572,12 +574,11 @@ def compute_metrics(p): desc="Running tokenizer on the validation dataset", ) - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=eval_dataset, compute_metrics=compute_metrics, ) - outputs = ort_model.evaluation_loop(eval_dataset) # Save evaluation metrics with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: @@ -602,12 +603,11 @@ def compute_metrics(p): desc="Running tokenizer on the prediction dataset", ) - ort_model = ORTModel( - Path(training_args.output_dir) / "model_quantized.onnx", - execution_provider=optim_args.execution_provider, + outputs = evaluation_loop( + model=model, + dataset=predict_dataset, compute_metrics=compute_metrics, ) - outputs = ort_model.evaluation_loop(predict_dataset) predictions = np.argmax(outputs.predictions, axis=2) # Remove ignored index (special tokens) From 4d8a7663d3a2fb82f88a59cc3489f1ccac9675b3 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 4 Jul 2024 10:56:39 +0200 Subject: [PATCH 4/4] fix --- examples/onnxruntime/optimization/multiple-choice/run_swag.py | 2 +- examples/onnxruntime/quantization/question-answering/run_qa.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/onnxruntime/optimization/multiple-choice/run_swag.py b/examples/onnxruntime/optimization/multiple-choice/run_swag.py index b2a9398d94..bcddc97590 100644 --- a/examples/onnxruntime/optimization/multiple-choice/run_swag.py +++ b/examples/onnxruntime/optimization/multiple-choice/run_swag.py @@ -346,7 +346,7 @@ def compute_metrics(eval_predictions): outputs = evaluation_loop( model=model, dataset=eval_dataset, - label_names=["labels"], + label_names=["label"], compute_metrics=compute_metrics, ) diff --git a/examples/onnxruntime/quantization/question-answering/run_qa.py b/examples/onnxruntime/quantization/question-answering/run_qa.py index 4b5648d70d..50661b7b42 100644 --- a/examples/onnxruntime/quantization/question-answering/run_qa.py +++ b/examples/onnxruntime/quantization/question-answering/run_qa.py @@ -683,7 +683,6 @@ def compute_metrics(p: EvalPrediction): outputs = evaluation_loop( model=model, dataset=predict_dataset, - compute_metrics=compute_metrics, label_names=["start_positions", "end_positions"], ) predictions = post_processing_function(predict_examples, predict_dataset, outputs.predictions)
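Taken together, the series deprecates the `optimum.onnxruntime.model.ORTModel` evaluation helper and moves every example onto the standalone `evaluation_loop` utility. Below is a minimal before/after sketch of that migration; the SST-2 checkpoint, the GLUE validation slice, and the accuracy metric are assumptions chosen purely for illustration.

```python
# Hedged migration sketch, not part of the patch series; names below are assumed.
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, EvalPrediction

from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime.utils import evaluation_loop

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # assumed fine-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)

# Keep the "label" column so evaluation_loop can pick the labels up via label_names=["label"].
dataset = load_dataset("glue", "sst2", split="validation[:64]")
dataset = dataset.map(lambda ex: tokenizer(ex["sentence"], truncation=True), remove_columns=["sentence", "idx"])

def compute_metrics(p: EvalPrediction):
    predictions = np.asarray(p.predictions).argmax(axis=1)
    return {"accuracy": float((predictions == np.asarray(p.label_ids)).mean())}

# Before this series (now deprecated):
#   ort_model = ORTModel(onnx_path, execution_provider="CPUExecutionProvider",
#                        compute_metrics=compute_metrics, label_names=["label"])
#   outputs = ort_model.evaluation_loop(dataset)
# After this series:
outputs = evaluation_loop(model=model, dataset=dataset, label_names=["label"], compute_metrics=compute_metrics)
print(outputs.metrics, outputs.num_samples)
```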