[Relax] Implement relax.transform.RemoveSymbolicExpressionsInSubroutine #17080
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This is a follow-up commit to
#16637, which updated
relax.transform.FuseOps
to provide additional parameters defining symbolic variables required by the fused functions. While this ensures thatrelax.transform.FuseOps
produces well-formed Relax functions, these additional arguments can break some kernel implementations.This commit implements a new transform
RemoveSymbolicExpressionsInSubroutine
to resolve this issue. This transform identifies function arguments whose sole purpose is to compute a symbolic expression, when that symbolic expression could be inferred from tensor shapes.For example, consider the following Relax function:
The
data
tensor may be used to inferhidden_size
, but cannot be used to inferbatch_size
orseq_len
. TheR.Shape
parameter exists solely to definebatch_size
andseq_len
, since all symbolic variables must be defined. However, neitherbatch_size
norseq_len
are ever used outside of the expressionbatch_size * seq_len
, and the value ofbatch_size * seq_len
could be inferred from the shape of thedata
tensor.This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a new argument. This makes the
dummy_arg: R.Shape
be entirely unused, so a later use ofrelax.transform.RemoveUnusedParameters()
can remove the parameter altogether.