Add ORT inference (#113)
* added gpu extras and added > transformers for token-classification pipeline issue

* added numpy and huggingface hub to required packages

* added modeling_* classes

* adding tests and pipelines

* remove vs code folder

* added test model and adjusted gitignore

* add readme for tests

* working tests

* added some documentation

* will ci run?

* added real model checkpoints

* test ci

* fix styling

* fix some documentation

* more doc fixes

* added some feedback and wording from michael and lewis

* renamed model class to ORTModelForXX

* moved from_transformers to from_pretrained

* applied ellas feedback

* make style

* first version of ORTModelForCausalLM without past-keys

* added first draft of new .optimize method

* added better quantize method

* fix import

* remove optimize and quantize

* added lewis feedback

* added style for test

* added >>> to code snippets

* style

* added condition for staging tests

* feedback morgan & michael

* added action

* forgot to install pytest

* forgot sentence piece

* made sure we won't have import conflicts

* make style happy
philschmid authored Apr 28, 2022
1 parent 7417202 commit a31e59e
Showing 20 changed files with 2,036 additions and 10 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/test_modeling_ort.yml
@@ -0,0 +1,32 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Onnxruntime Models (Inference) / Python - Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.8, 3.9]
        os: [ubuntu-20.04] #, windows-2019, macos-10.15]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install .[tests,onnxruntime]
      - name: Test with pytest
        shell: bash
        run: |
          pytest tests/onnxruntime/test_modeling_ort.py
4 changes: 4 additions & 0 deletions .gitignore
@@ -131,3 +131,7 @@ dmypy.json

# Models
*.onnx
# include small test model for tests
!tests/assets/onnx/model.onnx

.vscode
4 changes: 4 additions & 0 deletions docs/source/_toctree.yml
@@ -3,8 +3,12 @@
    title: 🤗 Optimum
  - local: quickstart
    title: Quickstart
  - local: pipelines
    title: Pipelines for inference
  title: Get started
- sections:
  - local: onnxruntime/modeling_ort
    title: Inference
  - local: onnxruntime/configuration
    title: Configuration
  - local: onnxruntime/optimization
103 changes: 103 additions & 0 deletions docs/source/onnxruntime/modeling_ort.mdx
@@ -0,0 +1,103 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Optimum Inference with ONNX Runtime

Optimum is a utility package for building and running inference with accelerated runtimes like ONNX Runtime.
Optimum can be used to load optimized models from the [Hugging Face Hub](https://hf.co/models) and create pipelines
to run accelerated inference without rewriting your APIs.

## Switching from Transformers to Optimum Inference

The Optimum Inference models are API compatible with Hugging Face Transformers models. This means you can just replace your `AutoModelForXxx` class with the corresponding `ORTModelForXxx` class in `optimum`. For example, this is how you can use a question answering model in `optimum`:

```diff
from transformers import AutoTokenizer, pipeline
-from transformers import AutoModelForQuestionAnswering
+from optimum.onnxruntime import ORTModelForQuestionAnswering

-model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2") # pytorch checkpoint
+model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2") # onnx checkpoint
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")

onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer)

question = "What's my name?"
context = "My name is Philipp and I live in Nuremberg."
pred = onnx_qa(question, context)
```

Optimum Inference also includes methods to convert vanilla Transformers models to optimized ones. Simply pass `from_transformers=True` to the `from_pretrained()` method, and your model will be loaded and converted to ONNX on-the-fly:

```python
>>> from transformers import AutoTokenizer, pipeline
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
# load model from hub and convert
>>> model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english",from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
# create pipeline
>>> onnx_classifier = pipeline("text-classification",model=model,tokenizer=tokenizer)
>>> result = onnx_classifier(text="This is a great model")
[{'label': 'POSITIVE', 'score': 0.9998838901519775}]
```
You can find a complete walkthrough of Optimum Inference for ONNX Runtime in this [notebook](xx).

### Working with the [Hugging Face Model Hub](https://hf.co/models)

The Optimum model classes like [`~ORTModelForSequenceClassification`] are integrated with the [Hugging Face Model Hub](https://hf.co/models), which means you can not only
load models from the Hub, but also push your models to the Hub with the `push_to_hub()` method. Below is an example which downloads a vanilla Transformers model
from the Hub, converts it to an Optimum ONNX Runtime model, and pushes it back to a new repository.
<!-- TODO: Add Quantizer into example when UX improved -->
```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
# load model from hub and convert
>>> model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english",from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
# save converted model
>>> model.save_pretrained("a_local_path_for_convert_onnx_model")
>>> tokenizer.save_pretrained("a_local_path_for_convert_onnx_model")
# push the converted onnx model to the HF Hub
>>> model.push_to_hub("a_local_path_for_convert_onnx_model",
repository_id="my-onnx-repo",
use_auth_token=True
)
```

## ORTModel

[[autodoc]] onnxruntime.modeling_ort.ORTModel

## ORTModelForFeatureExtraction

[[autodoc]] onnxruntime.modeling_ort.ORTModelForFeatureExtraction
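
For reference, a minimal usage sketch (the `sentence-transformers/all-MiniLM-L6-v2` checkpoint is only an illustrative choice, not part of this API reference):

```python
>>> from transformers import AutoTokenizer, pipeline
>>> from optimum.onnxruntime import ORTModelForFeatureExtraction

# hypothetical checkpoint chosen for illustration, converted to ONNX on the fly
>>> model = ORTModelForFeatureExtraction.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# token-level embeddings for the input text
>>> onnx_extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
>>> embeddings = onnx_extractor("My name is Philipp and I live in Nuremberg.")
```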

## ORTModelForQuestionAnswering

[[autodoc]] onnxruntime.modeling_ort.ORTModelForQuestionAnswering

## ORTModelForSequenceClassification

[[autodoc]] onnxruntime.modeling_ort.ORTModelForSequenceClassification

## ORTModelForTokenClassification

[[autodoc]] onnxruntime.modeling_ort.ORTModelForTokenClassification
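
Similarly, a hedged sketch for token classification (the `dslim/bert-base-NER` checkpoint is only an example choice):

```python
>>> from transformers import AutoTokenizer, pipeline
>>> from optimum.onnxruntime import ORTModelForTokenClassification

# hypothetical NER checkpoint, converted to ONNX on the fly
>>> model = ORTModelForTokenClassification.from_pretrained("dslim/bert-base-NER", from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

>>> onnx_ner = pipeline("token-classification", model=model, tokenizer=tokenizer)
>>> entities = onnx_ner("My name is Philipp and I live in Nuremberg.")
```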

218 changes: 218 additions & 0 deletions docs/source/pipelines.mdx
@@ -0,0 +1,218 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Optimum pipelines for inference

The [`pipeline`] makes it simple to use models from the [Model Hub](https://huggingface.co/models) for accelerated inference on a variety of tasks such as text classification.
Even if you don't have experience with a specific modality or aren't familiar with the code powering the models, you can still use them with the [`pipeline`]! This tutorial will teach you how to run accelerated inference with the Optimum [`pipeline`].

<Tip>

You can also use the `pipeline()` function from Transformers and provide your Optimum model.

</Tip>

Currently supported tasks are:

**ONNX Runtime**

* `feature-extraction`
* `text-classification`
* `token-classification`
* `question-answering`
* `zero-shot-classification`
* `text-generation`

## Optimum pipeline usage

While each task has an associated pipeline class, it is simpler to use the general [`pipeline`] abstraction, which contains all the task-specific pipelines.
The [`pipeline`] automatically loads a default model and tokenizer capable of performing inference for your task.

1. Start by creating a [`pipeline`] and specifying an inference task:

```python
>>> from optimum.pipelines import pipeline

>>> classifier = pipeline(task="text-classification", accelerator="ort")
```
2. Pass your input text to the [`pipeline`]:
```python
>>> classifier("I like you. I love you.")
[{'label': 'POSITIVE', 'score': 0.9998838901519775}]
```
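
The same pattern works for the other supported tasks. For example, here is a minimal sketch for `zero-shot-classification` (the input text and candidate labels are made up; the default model for the task is downloaded and used as-is):

```python
>>> from optimum.pipelines import pipeline

# loads the default model for the task with the ONNX Runtime accelerator
>>> onnx_zsc = pipeline(task="zero-shot-classification", accelerator="ort")
>>> onnx_zsc("I like the new ORT pipeline", candidate_labels=["positive", "negative"])
```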

_Note: The default models used in the [`pipeline`] are not optimized or quantized, so there won't be a performance improvement compared to their PyTorch counterparts._

### Using a vanilla Transformers model and converting to ONNX

The [`pipeline`] accepts any supported model from the [Model Hub](https://huggingface.co/models).
There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
Once you've picked an appropriate model, load it with the `from_pretrained("{model_id}", from_transformers=True)` method of the corresponding `ORTModelFor*` class and the
[`AutoTokenizer`] class. For example, here's how you can load the [`ORTModelForQuestionAnswering`] class for question answering:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
>>> from optimum.pipelines import pipeline

>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
>>> # loading the pytorch checkpoint and converting to ORT format by providing the from_transformers=True parameter
>>> model = ORTModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2",from_transformers=True)

>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
>>> question = "What's my name?"
>>> context = "My name is Philipp and I live in Nuremberg."

>>> pred = onnx_qa(question=question, context=context)
```

### Using Optimum models

The [`pipeline`] is tightly integrated with the [Model Hub](https://huggingface.co/models) and can load optimized models directly, e.g. those created with ONNX Runtime.
There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
Once you've picked an appropriate model, load it with the `from_pretrained()` method of the corresponding `ORTModelFor*` class
and the [`AutoTokenizer`] class. For example, here's how you can load an optimized model for question answering:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
>>> from optimum.pipelines import pipeline

>>> tokenizer = AutoTokenizer.from_pretrained("optimum/roberta-base-squad2")
>>> # loading already converted and optimized ORT checkpoint for inference
>>> model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")

>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
>>> question = "What's my name?"
>>> context = "My name is Philipp and I live in Nuremberg."

>>> pred = onnx_qa(question=question, context=context)
```


### Optimizing and Quantizing in Pipelines

The [`pipeline`] can not only run inference on vanilla ONNX Runtime checkpoints; you can also use checkpoints optimized with the `ORTQuantizer` and the `ORTOptimizer`.
Below you can find two examples of how you can use the [`ORTOptimizer`] and the [`ORTQuantizer`] to optimize/quantize your model and use it for inference afterwards.

### Quantizing with the [`ORTQuantizer`]

```python
>>> from pathlib import Path
>>> from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
>>> from optimum.onnxruntime.configuration import AutoQuantizationConfig
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

# define model_id and load tokenizer
>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> save_path = Path("optimum_model")
>>> save_path.mkdir(exist_ok=True)

# use ORTQuantizer to export the model and define quantization configuration
>>> quantizer = ORTQuantizer.from_pretrained(model_id, feature="sequence-classification")
>>> qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)

# apply the quantization configuration to the model
>>> quantizer.export(
onnx_model_path=save_path / "model.onnx",
onnx_quantized_model_output_path=save_path / "model-quantized.onnx",
quantization_config=qconfig,
)
>>> quantizer.model.config.save_pretrained(save_path) # saves config.json

# load optimized model from local path or repository
>>> model = ORTModelForSequenceClassification.from_pretrained(save_path,file_name="model-quantized.onnx")

# create transformers pipeline
>>> onnx_clx = pipeline("text-classification", model=model, tokenizer=tokenizer)
>>> text = "I like the new ORT pipeline"
>>> pred = onnx_clx(text)
>>> print(pred)

# save model & push model to the hub
>>> tokenizer.save_pretrained("new_path_for_directory")
>>> model.save_pretrained("new_path_for_directory")
>>> model.push_to_hub("new_path_for_directory",
repository_id="my-onnx-repo",
use_auth_token=True
)
```

### Optimizing with the [`ORTOptimizer`]

```python
>>> from pathlib import Path
>>> from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
>>> from optimum.onnxruntime.configuration import OptimizationConfig
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

# define model_id and load tokenizer
>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> save_path = Path("optimum_model")
>>> save_path.mkdir(exist_ok=True)

# use ORTOptimizer to export the model and define an optimization configuration
>>> optimizer = ORTOptimizer.from_pretrained(model_id, feature="sequence-classification")
>>> optimization_config = OptimizationConfig(optimization_level=2)

# apply the optimization configuration to the model
>>> optimizer.export(
onnx_model_path=save_path / "model.onnx",
onnx_optimized_model_output_path=save_path / "model-optimized.onnx",
optimization_config=optimization_config,
)
>>> optimizer.model.config.save_pretrained(save_path) # saves config.json

# load optimized model from local path or repository
>>> model = ORTModelForSequenceClassification.from_pretrained(save_path,file_name="model-optimized.onnx")

# create transformers pipeline
>>> onnx_clx = pipeline("text-classification", model=model, tokenizer=tokenizer)
>>> text = "I like the new ORT pipeline"
>>> pred = onnx_clx(text)
>>> print(pred)

# save model & push model to the hub
>>> tokenizer.save_pretrained("new_path_for_directory")
>>> model.save_pretrained("new_path_for_directory")
>>> model.push_to_hub("new_path_for_directory",
repository_id="my-onnx-repo",
use_auth_token=True)
```

## Transformers pipeline usage

The [`pipeline`] is just a light wrapper around the `transformers.pipeline` function to enable checks for supported tasks and additional features,
like quantization and optimization. That being said, you can also use the `transformers.pipeline` function directly and just replace your `AutoModelFor*`
class with the corresponding Optimum `ORTModelFor*` class.

```diff
from transformers import AutoTokenizer, pipeline
-from transformers import AutoModelForQuestionAnswering
+from optimum.onnxruntime import ORTModelForQuestionAnswering

-model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
+model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")

onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer)

question = "What's my name?"
context = "My name is Philipp and I live in Nuremberg."
pred = onnx_qa(question, context)
```