Deprecate ORTModel class #1939

Merged: 4 commits, Jul 8, 2024
Changes from 1 commit
5 changes: 5 additions & 0 deletions optimum/onnxruntime/model.py
@@ -49,6 +49,11 @@ def __init__(
            label_names (`List[str]`, `optional`):
                The list of keys in your dictionary of inputs that correspond to the labels.
        """

        logger.warning(
            "The class `optimum.onnxruntime.model.ORTModel` is deprecated and will be removed in the next release."
        )

        self.compute_metrics = compute_metrics
        self.label_names = ["labels"] if label_names is None else label_names
        self.session = InferenceSession(str(model_path), providers=[execution_provider])
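The notice is emitted from `__init__`, so existing scripts keep running but log the warning each time the class is instantiated. A minimal sketch of what a caller observes, assuming the constructor still takes a model path and an execution provider as the surrounding lines suggest (the file name and provider below are placeholders):

from optimum.onnxruntime.model import ORTModel  # deprecated module

# Placeholder ONNX file and provider; instantiating the deprecated class logs the notice.
ort_model = ORTModel("model.onnx", execution_provider="CPUExecutionProvider")
# Logged: "The class `optimum.onnxruntime.model.ORTModel` is deprecated and
# will be removed in the next release."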
62 changes: 61 additions & 1 deletion optimum/onnxruntime/utils.py
@@ -17,11 +17,15 @@
import re
from enum import Enum
from inspect import signature
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from tqdm import tqdm
from transformers import EvalPrediction
from transformers.trainer_pt_utils import nested_concat
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import logging

import onnxruntime as ort
@@ -30,6 +34,12 @@
from ..utils.import_utils import _is_package_available


if TYPE_CHECKING:
    from datasets import Dataset

    from .modeling_ort import ORTModel
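The new `Dataset` and `ORTModel` imports sit under `TYPE_CHECKING`, so they are only evaluated by static type checkers; at runtime the annotations below stay plain strings, which keeps `utils.py` from importing `datasets` or the heavier `modeling_ort` module eagerly and avoids a circular import. A minimal sketch of the pattern, with a hypothetical module and class:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; never executed at runtime.
    from heavy_package import HeavyClass  # hypothetical import

def run(obj: "HeavyClass") -> None:  # string annotation, resolved lazily by type checkers
    obj.start()  # hypothetical method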


logger = logging.get_logger(__name__)

ONNX_WEIGHTS_NAME = "model.onnx"
@@ -341,3 +351,53 @@ class ORTQuantizableOperator(Enum):
    Resize = "Resize"
    AveragePool = "AveragePool"
    Concat = "Concat"


def evaluation_loop(
    model: "ORTModel",
    dataset: "Dataset",
    label_names: Optional[List[str]] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
):
"""
Run evaluation and returns metrics and predictions.

Args:
model (`ORTModel`):
The ONNXRuntime model to use for the evaluation step.
dataset (`datasets.Dataset`):
Dataset to use for the evaluation step.
label_names (`List[str]`, `optional`):
The list of keys in your dictionary of inputs that correspond to the labels.
compute_metrics (`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take an `EvalPrediction` and
return a dictionary string to metric values.
"""

    all_preds = None
    all_labels = None

    for inputs in tqdm(dataset, desc="Evaluation"):
        has_labels = all(inputs.get(k) is not None for k in label_names)
        if has_labels:
            labels = tuple(np.array([inputs.get(name)]) for name in label_names)
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        inputs = {key: np.array([inputs[key]]) for key in model.input_names if key in inputs}
        preds = model(**inputs)

        if len(preds) == 1:
            preds = preds[0]

        all_preds = preds if all_preds is None else nested_concat(all_preds, preds, padding_index=-100)
        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

    if compute_metrics is not None and all_preds is not None and all_labels is not None:
        metrics = compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
    else:
        metrics = {}

    return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=len(dataset))
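For context, a minimal usage sketch (not part of the diff) pairing the new helper with an accuracy metric. It assumes an ONNX sequence-classification model loaded through `ORTModelForSequenceClassification` and a tokenized dataset whose columns include the model's `input_names` plus a `label` column; the checkpoint, dataset slice, and column names are illustrative only:

import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, EvalPrediction

from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime.utils import evaluation_loop

# Export a small classifier to ONNX (placeholder checkpoint).
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
model = ORTModelForSequenceClassification.from_pretrained(checkpoint, export=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Tokenize a small validation slice so each example carries the model inputs and a label.
dataset = load_dataset("glue", "sst2", split="validation[:64]")
dataset = dataset.map(lambda ex: tokenizer(ex["sentence"], truncation=True))

accuracy = evaluate.load("accuracy")

def compute_metrics(p: EvalPrediction):
    # Predictions arrive as logits; reduce them to class ids before scoring.
    return accuracy.compute(predictions=p.predictions.argmax(-1), references=p.label_ids)

output = evaluation_loop(model, dataset, label_names=["label"], compute_metrics=compute_metrics)
print(output.metrics)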