
Add ORT inference #113

Merged: 39 commits merged into main from add-ort-inference on Apr 28, 2022

Conversation

philschmid (Contributor) commented Mar 24, 2022

PoC: Accelerated Inference

Our goal with Optimum is to offer the open-source reference toolkit for transformers acceleration work, and inference is one part of it.

Over the last couple of weeks and months, I created a PoC that introduces OptimizedModelFor* classes mimicking the transformers API, enabling support for transformers.pipelines and other features.

I have opened this PR to discuss the current PoC. I have also already put quite some work into documentation and tests.
You can access the documentation here:
The two main new pages are "Inference" under onnxruntime and "Pipelines" under getting started.

The current PR supports:

  • 5 pipelines along with 9 architectures for CPU & GPU
  • implements a from_transformers method to automatically convert models with transformers.onnx
  • integrates with the Hub to pull and push (ONNX) models
  • integrated support for quantization and optimization (very experimental atm)

Below is a snippet showing how you can load a vanilla transformers model, convert it, and then use it in a pipeline:

>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import OnnxForQuestionAnswering
>>> from optimum import optimum_pipeline

>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
>>> model = OnnxForQuestionAnswering.from_transformers("deepset/roberta-base-squad2")

>>> onnx_qa = optimum_pipeline("question-answering", model=model, tokenizer=tokenizer)
>>> question = "Whats my name?"
>>> context = "My Name is Philipp and I live in Nuremberg."
>>> pred = onnx_qa(question=question, context=context)

More examples can be found here:

Todos

  • use ORTOptimizer and ORTQuantizer in .quantize and .optimize
  • add OnnxForCausalLM model
  • rename model classes to ORTModelForXX
  • move from_transformers to from_pretrained(model_id, from_transformers=True)

After another discussion with the Optimum team:

  • Remove .optimize and .quantize for now to think about a save approach for offline optimizations

michaelbenayoun (Member) left a comment:

Huge work @philschmid !!!
I left a few comments and questions to initiate the discussion on a few details about this really cool feature.

)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
A Member commented:

I think we should also accept numpy arrays directly, and return tensors of the same nature as the input (numpy arrays if numpy array, torch tensors if torch.Tensor), but I might be missing some details with pipelines.
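For illustration, a rough sketch of how such type-preserving handling could look (the helper names are hypothetical and not part of this PR):

import numpy as np
import torch

def _to_numpy(value):
    # Accept either numpy arrays or torch tensors as input.
    if isinstance(value, torch.Tensor):
        return value.cpu().detach().numpy()
    return value

def _match_input_type(output: np.ndarray, reference):
    # Return a torch tensor if the reference input was a torch tensor,
    # otherwise keep the numpy array as-is.
    if isinstance(reference, torch.Tensor):
        return torch.from_numpy(output)
    return output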

onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
# run inference
outputs = self.model.run(None, onnx_inputs)
# converts output to namedtuple for pipelines post-processing
A Member commented:

Could we have a return_dict parameter? Just a suggestion, I don't know how useful it would be.
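As a purely illustrative sketch of that suggestion (assuming a question-answering head; not part of this PR):

from transformers.modeling_outputs import QuestionAnsweringModelOutput

def _wrap_outputs(start_logits, end_logits, return_dict: bool = True):
    # Mirror the transformers convention: return a dict-like ModelOutput
    # when return_dict=True, otherwise a plain tuple.
    if return_dict:
        return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)
    return (start_logits, end_logits)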

lewtun (Member) left a comment:

Thank you for this awesome feature 🔥 ! I left a few nits about the docs and some general questions about the naming of the OnnxForXxx classes.

I also agree with @michaelbenayoun that we need to discuss whether these classes should also be able to optimise / quantize ONNX models (I can see pros and cons with each approach).

Otherwise, great stuff!

@@ -131,3 +131,7 @@ dmypy.json

# Models
*.onnx
# include small test model for tests
!tests/assets/onnx/model.onnx
A Member commented:

Would it make sense to have a small model on the Hub, e.g. under the optimum org?

The [`optimum_pipeline`] makes it simple to use models from the [Model Hub](https://huggingface.co/models) for accelerated inference on a variety of tasks such as text classification.
Even if you don't have experience with a specific modality or understand the code powering the models, you can still use them with the [`optimum_pipeline`]! This tutorial will teach you to:

<Tip>
A Member commented:

Should this tip be moved below, i.e. just after the example in Optimum pipeline usage?

echarlaix (Collaborator) left a comment:

Great work @philschmid, this new feature is 🔥. I left a couple of minor comments, and I'm also happy to discuss the optimize and quantize methods as @michaelbenayoun and @lewtun pointed out.

philschmid (Contributor, Author) commented Apr 13, 2022:

I reworked the optimize and quantize methods to reuse parts of ORTOptimizer and ORTQuantizer to create the best possible DX with the fewest configuration steps. Currently, quantize only works with is_static=False, which is okay IMO; for static quantization you can/should use the ORTQuantizer directly. Below is a nice end-to-end snippet:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSequenceClassification
from optimum.onnxruntime.configuration import OptimizationConfig, AutoQuantizationConfig
from optimum.pipelines import optimum_pipeline

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", from_transformers=True)

# optimization_level=99 enables all available graph optimizations
optimization_config = OptimizationConfig(optimization_level=99)
model.optimize(optimization_config=optimization_config)

# dynamic quantization configuration
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)
model.quantize(quantization_config=qconfig)

# create transformers pipeline
onnx_clx = optimum_pipeline("text-classification", model=model, tokenizer=tokenizer)
text = "I like the new ORT pipeline"
pred = onnx_clx(text)
print(pred)

regisss (Contributor) left a comment:

Huge work @philschmid 🔥
I just left a couple of minor comments.

AlekseyKorshuk commented Apr 21, 2022:

Hi All,

You are doing something similar to what I did here: https://github.com/AlekseyKorshuk/optimum-transformers. Btw, I built that project to showcase myself for joining the Hugging Face team 🤗

This may seem a little cheeky on my part, but I have now applied for an internship at Hugging Face and would be very happy to work fully on this update on behalf of the company.
Let me know if this is possible, so I can join the work immediately and start contributing to this cool open-source project as part of the team.

Please don’t hesitate to contact me should you need further information.

Best regards,
Aleksey Korshuk

lewtun (Member) left a comment:

Thanks a lot for iterating on this great feature @philschmid 🚀 ! Most of my comments are nits and I think the only "major" thing involves reworking the docs about quantization / optimization.

Apart from that, LGTM!


The Optimum model classes, e.g. [`~ORTModelForSequenceClassification`], are directly integrated with the [Hugging Face Model Hub](https://hf.co/models), meaning you can not only
load models from the Hub but also push your models to the Hub with the `push_to_hub()` method. Below you find an example which pulls a vanilla transformers model
from the Hub, converts it to an optimum model, uses the [`~ORTQuantizer`] to quantize the model, and pushes it back into a new repository.
A Member commented:

We don't seem to do quantization in this example right?

philschmid (Author) replied:

True, removed it for now and added a TODO.

philschmid changed the title from "[WIP] Add ORT inference" to "Add ORT inference" on Apr 26, 2022
michaelbenayoun (Member) left a comment:

Seems great, thanks a lot @philschmid, it will be super useful!

pred = onnx_qa(question, context)
```

Optimum Inference also includes methods to convert vanilla Transformers models to optimized ones. Simply pass `from_transformers=True` to the `from_pretrained()` method, and your model will be loaded and converted to ONNX on-the-fly:
A Member commented:

Is it really an "optimized" one?

philschmid (Author) replied:

I wouldn't be that critical, or do you have a suggestion for what we could use instead of "optimized"?

@staticmethod
def load_model(path: Union[str, Path], provider=None):
    """
    loads ONNX Inference session with Provider. Default Provider is if GPU available else `CPUExecutionProvider`
A Member commented:

Uppercase, and missing `CUDAExecutionProvider` after "Default", I think.

Another Member commented:

Not sure we should default to GPU, it will pick the GPU:0 which can be problematic if the user already has something running on that GPU.

IMO: Default to CPU and use CUDAExecutionProvider if:

  1. available
  2. provided by the user (i.e. they know what they're doing, nothing hidden, and it follows what ORT does by default too)

philschmid (Author) replied:

Let's stick with it for now and adjust it afterward once we have feedback or issues; that way it is the easiest UX. And you can always provide a provider explicitly, e.g. Nils does the same for sentence-transformers.
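For reference, a usage sketch of overriding the default, based on the load_model(path, provider=...) signature shown in this diff (the exact class exposing the method is assumed here):

from optimum.onnxruntime import ORTModelForSequenceClassification

# Explicitly pick the CPU execution provider instead of relying on the
# GPU-if-available default discussed above.
session = ORTModelForSequenceClassification.load_model(
    "model.onnx", provider="CPUExecutionProvider"
)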

mfuntowicz (Member) left a comment:

Looks great overall, congrats 💪🏻.

A few performance/usability comments that I think would be interesting to address here or in a follow-up PR; no strong blocking point for me, a very good starting point.

Comment on lines +176 to +177
if len(str(model_id).split("@")) == 2:
    model_id, revision = model_id.split("@")
A Member commented:

What about unifying the way to grab a specific revision of a model with what transformers provides, using a dedicated revision parameter:

revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
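For comparison, a sketch of both styles (the revision keyword is the suggestion, not what this PR currently implements):

from optimum.onnxruntime import ORTModelForSequenceClassification

# Current approach in this PR: encode the revision in the model id itself.
model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english@main")

# Suggested approach: mirror transformers and accept an explicit keyword argument.
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", revision="main"
)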

philschmid (Author) replied:

Great idea and suggestion, but let's try to get this PR in first. I will create separate small issues that we can work on after the PR is merged.

@staticmethod
def load_model(path: Union[str, Path], provider=None):
    """
    loads ONNX Inference session with Provider. Default Provider is if GPU
    available else `CPUExecutionProvider`
    """
    if provider is None:
        provider = "CUDAExecutionProvider" if _is_gpu_available() else "CPUExecutionProvider"
A Member commented:

What about wrapping those names in an enum to make them more user-friendly?

from enum import Enum

class ORTDevice(Enum):
    CPU = "CPUExecutionProvider"
    CUDA = "CUDAExecutionProvider"
    TENSORRT = "TensorrtExecutionProvider"
    ...

I find those names not very friendly for someone coming from PyTorch devices.

Or even something à la PyTorch:

from typing import Any, Dict, Optional, Tuple

class device:
    def __init__(self, specs: str):
        # e.g. "cpu:0-16" or "cuda:0"; the parser is left as pseudocode here.
        self.specs = parse_specs(specs)

    @property
    def is_cpu(self) -> bool:
        pass

    @property
    def is_cuda(self) -> bool:
        pass

    @property
    def num_cores(self) -> Optional[int]:
        return self.specs["core"]  # for simplicity

    @property
    def gpu_id(self) -> Optional[int]:
        return self.specs["gpu_id"]

    def to_execution_provider_specs(self) -> Tuple[str, Dict[str, Any]]:
        return "CPUExecutionProvider", {"num_intraops_threads": self.num_cores}


cpu_bound_device = device("cpu:0-16")
cpu_cores_device = device("cpu:16")
gpu_device = device("cuda:0")

philschmid (Author) replied:

Great idea and suggestion, but let's try to get this PR in first. I will create separate small issues that we can work on after the PR is merged.

Comment on lines +680 to +683
onnx_inputs = {
    "input_ids": input_ids.cpu().detach().numpy(),
    "attention_mask": attention_mask.cpu().detach().numpy(),
}
A Member commented:

We should use OrtValue in the future to reduce the round-trip of the data Torch -> Numpy -> ORT

Another Member added:

This is especially true if the device property binds to cuda, then we are allocating Torch(GPU) -> Torch(CPU) -> Numpy -> ORT
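A rough sketch of the OrtValue/IO-binding direction being suggested (standalone onnxruntime example assuming a CUDA-enabled build; the input/output names are placeholders, and this is not part of the PR):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

# Place the input on the GPU as an OrtValue, avoiding the
# Torch(GPU) -> Torch(CPU) -> Numpy -> ORT copy chain described above.
input_ids = np.ones((1, 128), dtype=np.int64)
ort_input = ort.OrtValue.ortvalue_from_numpy(input_ids, "cuda", 0)

io_binding = session.io_binding()
io_binding.bind_ortvalue_input("input_ids", ort_input)  # placeholder input name
io_binding.bind_output("logits")                        # placeholder output name; ORT allocates it
session.run_with_iobinding(io_binding)
logits = io_binding.copy_outputs_to_cpu()[0]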

philschmid (Author) replied:

Great idea and suggestion, but let's try to get this PR in first. I will create separate small issues that we can work on after the PR is merged.

feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
use_fast: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
accelerator: Optional[str] = "ort",
A Member commented:

"onnxruntime"? Again, for non-expert, "ort" might not ring a bell

philschmid (Author) replied:

In transformers we have pt and tf as well, so I kept it similar; plus, we have OrtModel* classes. I would rather stick with ort.

The Member replied:

OK, let's go this way; we can still have "onnxruntime" as an alias for "ort" if we want both.
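If both spellings end up being supported, the mapping could be as simple as the following (hypothetical names, not in this PR):

# Hypothetical alias table: both spellings resolve to the same backend key.
ACCELERATOR_ALIASES = {"ort": "ort", "onnxruntime": "ort"}

def resolve_accelerator(name: str) -> str:
    return ACCELERATOR_ALIASES.get(name, name)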

echarlaix (Collaborator) left a comment:

It looks great, thanks a lot @philschmid! Ready to merge after confirmation about the tests workflow.

@@ -0,0 +1,467 @@
import os
A Collaborator commented:

I didn't see any YAML file defining a workflow to handle the tests from test_modeling_ort.py. It seems that test_onnxruntime.yml only runs the tests from test_onnxruntime.py and not all of the ones in the onnxruntime directory. The tests from test_modeling_base.py would be covered by test_optimum_common.yml.

philschmid merged commit a31e59e into main on Apr 28, 2022
philschmid deleted the add-ort-inference branch on April 28, 2022 at 13:02