[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat #16596
Conversation
This commit implements an optional optimization pass `relax.transform.ReorderPermuteDimsAfterConcat`, which reorders expressions of the form `R.concat(R.permute_dims(A), R.permute_dims(B))` into `R.permute_dims(R.concat(A, B))`. This pass is intended to be used alongside `CombineParallelMatmul`: after parallel matmuls are combined, it allows the `R.permute_dims` to be lifted out of the `R.concat`, letting optimized `nn.Linear` kernels find the `R.matmul(x, R.permute_dims(weights))` patterns they are looking for.

```python
@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query_T)
    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key_T)
    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value_T)


@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `CombineParallelMatmul`

    There's now only a single matmul to be performed, which is
    generally better than performing three small matmuls.  However,
    the optimized kernels for `nn.Linear` can no longer be applied,
    because the `R.concat` isn't part of the expected pattern.
    """
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)
    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)


@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `ReorderPermuteDimsAfterConcat`

    There's still only a single matmul, and the optimized kernels for
    `nn.Linear` can be applied again.
    """
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)
    fused_weight_T = R.permute_dims(fused_weight)
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)
```
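The equivalence of the three forms above can be checked numerically. The following is a NumPy sketch (not TVM code) with hypothetical shapes, using `.T` and `np.concatenate` in place of `R.permute_dims` and `R.concat`:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))    # [batch, in_features]
w_q = rng.standard_normal((5, 8))  # [out_features, in_features]
w_k = rng.standard_normal((6, 8))
w_v = rng.standard_normal((7, 8))

# Initial form: three independent nn.Linear-style matmuls.
q0, k0, v0 = (x @ w.T for w in (w_q, w_k, w_v))

# After CombineParallelMatmul: concat the transposed weights,
# perform one matmul, then split the result.
fused_T = np.concatenate([w_q.T, w_k.T, w_v.T], axis=1)
q1, k1, v1 = np.split(x @ fused_T, [5, 11], axis=1)

# After ReorderPermuteDimsAfterConcat: concat the original weights
# along axis 0, then transpose once.
fused = np.concatenate([w_q, w_k, w_v], axis=0)
q2, k2, v2 = np.split(x @ fused.T, [5, 11], axis=1)

assert np.allclose(q0, q1) and np.allclose(q0, q2)
assert np.allclose(k0, k1) and np.allclose(k0, k2)
assert np.allclose(v0, v1) and np.allclose(v0, v2)
```

Note how the concat axis flips from `axis=1` (on the transposed weights) to `axis=0` (on the originals) when the transpose is lifted out.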
Seems like a solid change. I agree with the comment about having variable numbers of args for the pattern matcher, it would be good to have a robust solution for that rather than generating patterns for concrete numbers (an issue that many languages' standard libraries suffer from too).
// number of arguments, each of which matches the same type of
// pattern.
size_t min_concat = 2;
size_t max_concat = 12;
Is this an arbitrary choice or do we enforce this somewhere? If it's arbitrary, we should probably have an enum or global constant somewhere to document that. (I'm sure it's a large enough number of args but who knows what automatically generated code might do.)
Entirely arbitrary at the moment. The choice of 12 as the max was to make it significantly larger than any actual case I've seen so far (3 for query/key/value), and because I'd been in the habit of using this Rust utility in hobby projects.
Because the long-term plan is to remove the restriction altogether, I'd actually lean away from having a global constant to document it. Having the global constant would allow other portions of the code to depend on it, which would make it harder to migrate to the long-term solution.
For now, I've added a comment indicating why the value is not being exposed, with a recommendation to increase `max_concat` if required, at least until pattern matching of arbitrary tuple sizes is supported.
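To illustrate the approach being discussed, here is a hypothetical plain-Python sketch (not the actual C++ pattern API) of generating one fixed-arity pattern per tuple size in `[min_concat, max_concat]`; the pattern strings and `argN` names are illustrative only:

```python
MIN_CONCAT, MAX_CONCAT = 2, 12

def make_concat_pattern(arity):
    # One "permute_dims(wildcard)" operand pattern per tuple element.
    operands = [f"permute_dims(arg{i})" for i in range(arity)]
    return f"concat(({', '.join(operands)}))"

patterns = [make_concat_pattern(n) for n in range(MIN_CONCAT, MAX_CONCAT + 1)]
# An R.concat with more operands than MAX_CONCAT matches none of these
# generated patterns, so the rewrite silently does not fire.
```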
Yeah, I've noticed that this kind of issue comes up in other languages too. Thanks for documenting the issue. I wish there were a way to "loudly" give an error if an unsupported number of concat params is detected, but I don't see a clear one.
Yeah. If the number of concat parameters is exceeded, then the match silently fails. And unlike `CombineParallelMatmul`, which could be run repeatedly if the number of concatenations is exceeded, the validity checking for reordering requires inspecting all concatenated tensors before performing any change.
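A sketch of that validity check in plain Python (the requirement that all operands share one permutation is my reading of the constraint, not a quote from the implementation): the reorder is only legal when every operand of the concat is a `permute_dims`, and all of them apply the same axis permutation, so all operands must be inspected before any rewrite happens.

```python
def can_reorder(concat_operands):
    """concat_operands: list of (op_name, axes) pairs, a toy stand-in
    for the relax expressions feeding the R.concat."""
    perms = []
    for op_name, axes in concat_operands:
        if op_name != "permute_dims":
            return False  # a single non-transpose operand blocks the rewrite
        perms.append(tuple(axes))
    # All operands must share one permutation for the lift to be valid.
    return len(set(perms)) == 1

assert can_reorder([("permute_dims", [1, 0])] * 3)
assert not can_reorder([("permute_dims", [1, 0]), ("add", [])])
assert not can_reorder([("permute_dims", [1, 0]), ("permute_dims", [0, 1])])
```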
Long-term, I think the flexible tuple index will be the way to go. I think it will end up needing explicit tuple extent variables. That way, a pattern containing two flexible tuple extents could distinguish between a single shared extent and two independent extents. I'm picturing something like the following:
```c++
// A normal wildcard argument
DFWildcard arg;
// Represents the extent of the tuple containing a flexible pattern.
DFArbitraryExtent tuple_extent;
// A pattern that may have a different value within each element of the tuple.
DFRepeatedPattern args(arg, tuple_extent);
// The DFRepeatedPattern could be used to build up additional patterns
DFCallPattern transposed_args(OpPattern("relax.permute_dims"), args);
// Eventually, a pattern match for a tuple of unknown length gets instantiated.
DFArbitraryTuplePattern concat_tuple(transposed_args, tuple_extent);
// Which could then be used to match operators whose arguments are an
// arbitrary tuple length.
DFCallPattern concat_pat(OpPattern("relax.concat"), concat_tuple);
```
But that's for brainstorming an implementation that would be much later down the road.
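The shared extent-variable idea can be sketched in plain Python (a toy model of the brainstormed semantics, not a proposed TVM API): an extent variable binds to a concrete tuple length on first use, so two repeated patterns can be forced to have the same length.

```python
class Extent:
    """An extent variable, bound to a concrete tuple length on first match."""
    def __init__(self):
        self.value = None

    def unify(self, n):
        if self.value is None:
            self.value = n
        return self.value == n

def match_repeated(items, element_predicate, extent):
    # All elements must satisfy the element sub-pattern, and the tuple
    # length must unify with the shared extent variable.
    return extent.unify(len(items)) and all(element_predicate(e) for e in items)

shared = Extent()
is_transpose = lambda e: isinstance(e, tuple) and e[0] == "permute_dims"

qkv = [("permute_dims", "q"), ("permute_dims", "k"), ("permute_dims", "v")]
assert match_repeated(qkv, is_transpose, shared)  # binds the extent to 3
assert shared.value == 3
# A second repeated pattern sharing the same extent must also have length 3.
assert not match_repeated(qkv[:2], is_transpose, shared)
```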