
[RELAX][ONNX][FIX] add a parser to handle expression in the shape dim names #17505

Conversation

PatrikPerssonInceptron
Contributor

Problem

The shape dims in an ONNX model can be named with expressions, e.g. the shape int64[batch_size, past_sequence_length + sequence_length] of an LLM's attention mask. Here the second dimension is the expression past_sequence_length + sequence_length, where past_sequence_length and sequence_length should be individual variables added together. Currently, however, the graph translator instead creates a single new variable literally named "past_sequence_length + sequence_length".

Fix

I added a simple parser that creates an individual size variable for each variable name and generates the resulting prim expression. Note that, to keep the parser simple, it evaluates expressions left to right, without accounting for operator precedence.
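The left-to-right evaluation idea can be sketched as follows. This is a hypothetical reconstruction, not the actual TVM code: in the real parser each name would become a tvm.tir.SizeVar, whereas here a plain dict of name-to-size bindings stands in so the sketch is runnable on its own.

```python
import re

# Tokens: identifiers, integer literals, or a binary operator.
_TOKEN = re.compile(r"\s*([A-Za-z_][A-Za-z_0-9]*|\d+|[+\-*/])")

def parse_dim_expr(expr, bindings):
    """Evaluate a shape-dim expression strictly left to right.

    `bindings` maps variable names (hypothetical stand-ins for
    tvm.tir.SizeVar instances) to concrete sizes; numeric literals are
    used as-is. Operator precedence is deliberately ignored, matching
    the simplification described above.
    """
    tokens = _TOKEN.findall(expr)

    def value(tok):
        return int(tok) if tok.isdigit() else bindings[tok]

    result = value(tokens[0])
    # Consume (operator, operand) pairs in order of appearance.
    for op, tok in zip(tokens[1::2], tokens[2::2]):
        rhs = value(tok)
        if op == "+":
            result += rhs
        elif op == "-":
            result -= rhs
        elif op == "*":
            result *= rhs
        elif op == "/":
            result //= rhs
    return result
```

Because evaluation is strictly left to right, an expression like `2 + 3 * 4` yields `(2 + 3) * 4 = 20` rather than `14`; for the sum-of-two-names case in the PR this makes no difference.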

Test

I added regression tests to verify that ONNX shape dim expressions are evaluated correctly.

Additional small fixes

When PrimValues are encountered in BinaryBase, their values are not always fully extracted before being converted to numpy arrays. I added a check that extracts the value from IntImm and FloatImm types before the numpy conversion.
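The extraction fix can be illustrated with a minimal sketch. The IntImm and FloatImm classes below are hypothetical stand-ins for TVM's tir.IntImm / tir.FloatImm scalar wrappers, which expose their payload through a `.value` field; the point is only to show why unwrapping before the numpy conversion matters.

```python
import numpy as np

class IntImm:
    """Stand-in for tir.IntImm: an integer wrapped in an object."""
    def __init__(self, value):
        self.value = int(value)

class FloatImm:
    """Stand-in for tir.FloatImm: a float wrapped in an object."""
    def __init__(self, value):
        self.value = float(value)

def to_numpy(prim_value):
    """Unwrap scalar wrappers before handing them to numpy.

    Without the isinstance check, np.asarray would receive the wrapper
    object itself and produce an object-dtype array rather than a
    numeric array.
    """
    if isinstance(prim_value, (IntImm, FloatImm)):
        prim_value = prim_value.value
    return np.asarray(prim_value)
```

For example, `to_numpy(IntImm(3))` yields a numeric array holding 3, while `np.asarray(IntImm(3))` would yield a 0-d object array wrapping the IntImm instance.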

Commit summary:

- handle shape dim expressions such as past_sequence_length + sequence_length, where each variable becomes a tvm.tir.SizeVar
- updated binary base to completely unpack relax.PrimValue if it contains tir.IntImm or tir.FloatImm
- added regression tests
@Hzfengsy
Member

Thanks for the improvements

@Hzfengsy Hzfengsy merged commit d5b9f5c into apache:main Nov 10, 2024
19 checks passed
@PatrikPerssonInceptron PatrikPerssonInceptron deleted the feature/onnx-input-shape-computations branch November 18, 2024 08:01
2 participants