[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat #16596

Merged

Conversation

Lunderberg (Contributor)


This commit implements an optional optimization pass
`relax.transform.ReorderPermuteDimsAfterConcat`, which reorders
expressions of the form `R.concat(R.permute_dims(A),
R.permute_dims(B))` into `R.permute_dims(R.concat(A,B))`.

This pass is intended to be used alongside `CombineParallelMatmul`.
After parallel matmuls are combined, this pass allows the
`R.permute_dims` to be lifted out of the `R.concat`, so that optimized
`nn.Linear` kernels can find the `R.matmul(x, R.permute_dims(weights))`
patterns they are looking for.

```python
@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query_T)
    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key_T)
    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value_T)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `CombineParallelMatmul`

    There's now only a single matmul to be performed, which is
    generally better than performing three small matmuls.  However,
    the optimized kernels for `nn.Linear` can no longer be applied,
    because the `R.concat` isn't part of the expected pattern.
    """
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)

    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
    fused_qkv = R.matmul(x, fused_weight_T)

    query, key, value = R.split(fused_qkv)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `ReorderPermuteDimsAfterConcat`

    There's still only a single matmul, and the optimized kernels for
    `nn.Linear` can be applied again.
    """
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)

    fused_weight_T = R.permute_dims(fused_weight)
    fused_qkv = R.matmul(x, fused_weight_T)

    query, key, value = R.split(fused_qkv)
```
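For reference, a minimal usage sketch (not taken from the PR itself) of composing the two passes; it assumes `mod` is an existing Relax `IRModule`, such as the first example above:

```python
import tvm
from tvm import relax

# Combine the parallel matmuls first, then lift the permute_dims out of
# the resulting concat so that nn.Linear-style kernels can match again.
seq = tvm.transform.Sequential(
    [
        relax.transform.CombineParallelMatmul(),
        relax.transform.ReorderPermuteDimsAfterConcat(),
    ]
)
mod = seq(mod)  # `mod` is assumed to be defined elsewhere
```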
@slyubomirsky (Contributor) left a comment
Seems like a solid change. I agree with the comment about having variable numbers of args for the pattern matcher; it would be good to have a robust solution for that rather than generating patterns for concrete numbers (an issue that many languages' standard libraries suffer from too).

```c++
// number of arguments, each of which matches the same type of
// pattern.
size_t min_concat = 2;
size_t max_concat = 12;
```
@slyubomirsky (Contributor)

Is this an arbitrary choice or do we enforce this somewhere? If it's arbitrary, we should probably have an enum or global constant somewhere to document that. (I'm sure it's a large enough number of args but who knows what automatically generated code might do.)

@Lunderberg (Contributor, Author)

Entirely arbitrary at the moment. The choice of 12 as the max was to make it significantly larger than any actual case I've seen so far (3 for query/key/value), and because I'd been in the habit of using this Rust utility in hobby projects.

Because the long-term plan is to remove the restriction altogether, I'd actually lean away from having a global constant to document it. Having the global constant would allow other portions of the code to depend on it, which would make it harder to migrate to the long-term solution.
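(Illustrative sketch, not part of the PR: enumerating one pattern per concrete tuple size is what the `min_concat`/`max_concat` bounds above imply. This sketch uses the `wildcard`, `is_op`, and `is_tuple` helpers from `tvm.relax.dpl`; the function name is hypothetical, and the actual pass is implemented in C++.)

```python
from tvm.relax.dpl import is_op, is_tuple, wildcard

def make_concat_of_permute_dims_pattern(num_concat: int):
    """Match R.concat of `num_concat` R.permute_dims calls."""
    args = [wildcard() for _ in range(num_concat)]
    transposed = [is_op("relax.permute_dims")(arg) for arg in args]
    concat = is_op("relax.concat")(is_tuple(transposed))
    return concat, args

# One pattern per supported tuple size, since the matcher cannot yet
# express "a tuple of arbitrary length".
patterns = [make_concat_of_permute_dims_pattern(n) for n in range(2, 12 + 1)]
```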

@Lunderberg (Contributor, Author)

For now, I've added a comment indicating why the value is not being exposed, along with a recommendation to increase `max_concat` if required, at least until pattern matching of arbitrary tuple sizes is supported.

@slyubomirsky (Contributor)

Yeah, I've noticed that this kind of issue comes up in other languages too. Thanks for documenting the issue. I wish there were a way to "loudly" give an error if an unsupported number of concat params is detected, but I don't see a clear one.

@Lunderberg (Contributor, Author)

Yeah. If the number of concat parameters is exceeded, then the match silently fails. And unlike `CombineParallelMatmul`, which could be run repeatedly if the number of concatenations is exceeded, the validity checking for this reordering requires access to all concatenated tensors before performing any change.

Long-term, I think the flexible tuple index will be the way to go. I think it will end up needing explicit tuple extent variables. That way, a pattern containing two flexible tuple extents could distinguish between a single shared extent and two independent extents. I'm picturing something like the following:

```c++
// A normal wildcard argument
DFWildcard arg;

// Represents the extent of the tuple containing a flexible pattern.
DFArbitraryExtent tuple_extent;

// A pattern that may have a different value within each element of the tuple.
DFRepeatedPattern args(arg, tuple_extent);

// The DFRepeatedPattern could be used to build up additional patterns.
DFCallPattern transposed_args(OpPattern("relax.permute_dims"), args);

// Eventually, a pattern match for a tuple of unknown length gets instantiated.
DFArbitraryTuplePattern concat_tuple(transposed_args, tuple_extent);

// Which could then be used to match operators whose arguments are a
// tuple of arbitrary length.
DFCallPattern concat_pat(OpPattern("relax.concat"), concat_tuple);
```

But that's for brainstorming an implementation that would be much later down the road.

@Lunderberg merged commit b581575 into apache:main on Feb 23, 2024
20 checks passed
@Lunderberg deleted the relax_reorder_permute_dims_after_concat branch on February 23, 2024 at 14:28