
onnx use more memory than pytorch for some model #16264

Open
feng-1985 opened this issue Jun 7, 2023 · 3 comments
Labels
ep:CUDA issues related to the CUDA execution provider model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.

Comments

@feng-1985

feng-1985 commented Jun 7, 2023

Describe the issue

cuda 10.2
onnx=1.8
onnxruntime-gpu=1.6

For the sequence labeling task (input: token IDs; output: start_pos, end_pos), PyTorch uses 1.8 GB but ONNX uses 1.9 GB (although ONNX inference is faster). --- torch 1.10, BERT-base fine-tuning
For the text classification task, PyTorch uses 2.2 GB while ONNX uses just 0.8 GB. -- torch 1.9.0, roberta_base fine-tuning

To reproduce

I used this script and the sequence labeling dataset, training for just five epochs.
Then I converted the torch model to an ONNX model.

Urgency

No response

Platform

Linux

OS Version

ubuntu 18

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.6

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

@github-actions github-actions bot added ep:CUDA issues related to the CUDA execution provider model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. labels Jun 7, 2023
@tianleiwu
Contributor

For bert model, try optimizing it like the following:

pip install onnxruntime-gpu==1.15
python -m onnxruntime.transformers.optimizer --input bert.onnx --output bert_fp16.onnx --float16 --use_gpu

If everything goes well, it will use a fused attention kernel (like flash attention etc.), which can save memory for long sequences.

Note that 1.6 does not have fused attention, so you will need to upgrade onnxruntime-gpu to the latest version.

@feng-1985
Author

feng-1985 commented Jun 17, 2023

Thanks for the response. In our production environment only CUDA 10.2 is available, so I use onnxruntime-gpu 1.6.
Two related questions:

  1. If I convert the model to float16, does CUDA 10.2 support it?
  2. python -m onnxruntime.transformers.optimizer --input bert.onnx --output bert_fp16.onnx --float16 --use_gpu
    use_gpu defaults to false; does passing it just rename the output model and set the execution provider?

@tianleiwu
Contributor

@feng-1985,

  1. float16 is supported by CUDA 10.2 and onnxruntime-gpu 1.6.
  2. Run python -m onnxruntime.transformers.optimizer --help to see the usage. The tool applies graph optimizations to convert the model graph into a new one. You can also try adding --use_mask_index, which is not enabled by default in onnxruntime-gpu 1.6.
