[RELAX][ONNX][FIX] add a parser to handle expression in the shape dim names #17505
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Problem
The shape names in an ONNX model can contain expressions such as the shape
int64[batch_size,past_sequence_length + sequence_length]
of the attention mask of an LLM. In this case, the second dimension contains an expressionpast_sequence_length + sequence_length
wherepast_sequence_length
andsequence_length
should be individual variables added together. However, currently, a new variable named"past_sequence_length + sequence_length"
is instead created when translating the graph.Fix
I added a simple parser that creates individual size variables for the variable names and generates the resulting prim expression. Note, in order to keep the parser simple, it evaluates expressions left to right. Not accounting for operator precedence.
Test
I added regression tests to verify that the onnx shape dim expression are evaluated correctly.
Additional small fixes
In the case when PrimValues are encountered in the BinaryBase, they are not always fully extracted before turning them into numpy arrays. I added an additional check that extracts the value from IntImm and FloatImm types before converting them to numpy arrays.