How to go about utilizing MBART for conditional generation with beam search in ONNXRuntime with TensorRT/CUDA #15871
I was also adjusting the BART export to work for mBART in a translation setting. There you need to be able to set the decoder start token for the target language.
Hope this helps!
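For reference, this is roughly how the target-language/decoder start token is handled when generating with mBART in plain `transformers` (the checkpoint and language codes below are just examples, and this is the eager model rather than the exported graph):

```python
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

checkpoint = "facebook/mbart-large-50-many-to-one-mmt"
model = MBartForConditionalGeneration.from_pretrained(checkpoint)
tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)

# The many-to-one checkpoint always translates into English, so only the source
# language needs to be set; for many-to-many checkpoints you would also pass
# forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"] (or another target) to generate().
tokenizer.src_lang = "de_DE"
inputs = tokenizer("Hallo Welt", return_tensors="pt")

generated = model.generate(**inputs, num_beams=5, max_length=64)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```

Whatever wraps the exported ONNX decoder has to expose the same start-token choice as an input.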
Thanks for jumping in with the tip @HaukurPall! Did this approach allow you to not only export the mBART model to ONNX but also run it (with increased speed) on the CUDA/TensorRT execution providers? Moreover, I assume you made a custom version of the BART + beam search export?
Thanks for opening an issue @JeroendenBoef! Pinging @lewtun, @mfuntowicz, should this issue be moved to `optimum`?
Thanks for the ping! Yes, I think it would make sense to move this issue to the `optimum` repo. @JeroendenBoef, we currently have a PR open in `optimum`; once that is completed, I think it should address most of the points raised in this issue!
Hey @JeroendenBoef. No, I was not able to get the inference working efficiently on the CUDA execution providers. I even attempted to use the IOBindings (as suggested by the ONNX team) but was not successful. I have put this endeavour aside until there is better support for autoregressive inference. If you still want to try this, there is a different approach to exporting the models presented in https://github.com/Ki6an/fastT5/. This is for T5, but some of the code has been adjusted to work for mBART, see issue Ki6an/fastT5#7. I did not try running that model on CUDA as it would require some work to get the IOBindings correct/efficient.
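For anyone who still wants to try the IOBinding route, a minimal sketch with the `onnxruntime` Python API (the model path, tensor names, and shapes below are placeholders, not the actual exported mBART graph):

```python
import numpy as np
import onnxruntime as ort

# Placeholder model path; assumes a graph with an "input_ids" input and a "logits" output.
sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
binding = sess.io_binding()

# Put the input on the GPU once and bind it, so ORT does not copy it between
# host and device on every run.
input_ids = ort.OrtValue.ortvalue_from_numpy(np.ones((1, 16), dtype=np.int64), "cuda", 0)
binding.bind_ortvalue_input("input_ids", input_ids)

# Let ORT allocate the output on the GPU as well.
binding.bind_output("logits", "cuda", 0)

sess.run_with_iobinding(binding)
logits = binding.copy_outputs_to_cpu()[0]  # copy back to host only when needed
```

The difficulty for autoregressive decoding is that the growing past key/value tensors have to be re-bound on every step, which is where most of the extra bookkeeping comes from.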
Thanks for the reply and the pointer to the new PR on `optimum`.

Thanks for the detailed response @HaukurPall, this saves me some headaches and time :). I was already afraid there would be no performance improvement, but now I have confirmation that I should also postpone my efforts on this until there is a better approach in place for ORT seq2seq models.
Hi HuggingFace team,
Last December I looked into exporting
MBartForConditionalGeneration
from"facebook/mbart-large-50-many-to-one-mmt"
for the purpose of multilingual machine translation. Originally I followed the approach as described in this BART + beam search example, extending the example to support MBART and overriding the max 2GB model size. While this approach worked forCPUExecutionProvider
in ORT sessions, it did not actually improve runtime, nor did it work forTensorRT
orCUDA
execution providers (out of cuda memory and dynamic shape inference failure).Today I saw this issue and exported
MBartForConditionalGeneration
withpython -m transformers.onnx --model=facebook/mbart-large-50-many-to-one-mmt --feature seq2seq-lm-with-past --atol=5e-5 onnx/
While this worked for exporting to ONNX (passing all validation checks), I couldn't run an actual ORT session due to an input dimensionality mismatch: the encoder/decoder past key values are missing for `seq2seq-lm-with-past`, and `decoder_input_ids` and `decoder_attention_mask` are missing for `seq2seq-lm`.
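For reference, the missing inputs show up when listing the session's declared inputs; a minimal sketch (assuming the export above wrote `onnx/model.onnx`):

```python
import onnxruntime as ort

# Path assumed from the transformers.onnx export command above.
sess = ort.InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])

# The graph's declared inputs show exactly which tensors it expects, e.g. the
# decoder inputs and the past key/value blocks for the with-past feature.
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print(out.name, out.shape, out.type)
```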
I could use some clarification on whether this is the implementation I am looking for: does the latter ONNX export support `.generate()` with beam search, or should I refocus my attempts on the BART + beam search modification? If the newer command-line ONNX export is what I need, which feature would be the correct one for the many-to-one-mmt MBART conditional generation head (`seq2seq-lm` or `seq2seq-lm-with-past`), and where can I find the additional inputs that the model needs to run `.generate()` in an ORT session? The BART beam search implementation I mentioned earlier required `input_ids`, `attention_mask`, `num_beams`, `max_length` and `decoder_start_token_id`; the required inputs for the newer conversion are a bit more confusing to me.

I assume @lewtun would be the person to ask for help here, but I appreciate any pointers!
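In case it helps to make the comparison concrete, this is roughly how the BART + beam search style export would be driven from an ORT session once adapted to mBART (the file name is hypothetical, and the input names are taken from the example's required inputs listed above, so treat shapes and dtypes as assumptions):

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoConfig, MBart50TokenizerFast

checkpoint = "facebook/mbart-large-50-many-to-one-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
config = AutoConfig.from_pretrained(checkpoint)
tokenizer.src_lang = "de_DE"

# Hypothetical file name for the BART + beam search export adapted to mBART.
sess = ort.InferenceSession("onnx/mbart_beam_search.onnx", providers=["CPUExecutionProvider"])

enc = tokenizer("Hallo Welt", return_tensors="np")
ort_inputs = {
    "input_ids": enc["input_ids"].astype(np.int64),
    "attention_mask": enc["attention_mask"].astype(np.int64),
    # Generation hyperparameters passed as int64 tensors; exact shapes depend on the export.
    "num_beams": np.array(5, dtype=np.int64),
    "max_length": np.array(64, dtype=np.int64),
    "decoder_start_token_id": np.array(config.decoder_start_token_id, dtype=np.int64),
}
output_ids = sess.run(None, ort_inputs)[0]
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```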
Environment info

- `transformers` version: 4.16.2