[Relax][Bugfix] Infer TIR values from shapes inside a tuple #17312
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
If a Relax function contains an
R.match_cast
that defines a symbolic shape, and the value provided to theR.match_cast
has a known static shape, therelax.transform.CanoncalizeBindings()
pass can in-line the known static shape. However, while these known TIR values were only collected if the expression used inR.match_cast
was aR.Tensor
,R.Shape
, andR.Prim
(Relax types which may contain symbolic TIR values), they were not collected if theR.match_cast
expression was aR.Tuple
.For example, while using
R.match_cast
to convert fromR.Tensor([16])
toR.Tensor([batch_size])
would identify thatbatch_size
must be16
, usingR.match_cast
to convert fromR.Tuple(R.Tensor([16]))
toR.Tuple(R.Tensor([batch_size]))
would not.This commit updates the
InferSymbolicVarMap
to collect all symbolic shapes, even if they occur within aR.Tuple
.