Skip to content

Commit

Permalink
[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat (#…
Browse files Browse the repository at this point in the history
…16596)

* [Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat

This commit implements an optional optimization pass
`relax.transform.ReorderPermuteDimsAfterConcat`, which reorder
expressions of the form `R.concat(R.permute_dims(A),
R.permute_dims(B))` into `R.permute_dims(R.concat(A,B))`.

This pass is intended to be used alongside `CombineParallelMatmul`.
After parallel matmuls are combined, to be lifted out, and optimized
`nn.Linear` kernels to find the `R.matmul(x, R.permute_dims(weights))`
patterns they are looking for.

```python
@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query)
    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key)
    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `CombineParallelMatmul`

    There's now only a single matmul to be performed, which is
    generally better than performing three small matmuls.  However,
    the optimized kernels for `nn.Linear` can no longer be applied,
    because the `R.concat` isn't part of the expected pattern.
    """
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)

    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
    fused_qkv = R.matmul(x, fused_weight_T)

    query, key, value = R.split(fused_qkv)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `ReorderPermuteDimsAfterConcat`

    There's still only a single matmul, and the optimized kernels for
    `nn.Linear` can be applied again.
    """
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)

    fused_weight_T = R.permute_dims(fused_weight)
    fused_qkv = R.matmul(x, fused_weight_T)

    query, key, value = R.split(fused_qkv)
```

* Expand description of `max_concat` variable as a temporary solution
  • Loading branch information
Lunderberg authored Feb 23, 2024
1 parent 84b3f69 commit b581575
Show file tree
Hide file tree
Showing 4 changed files with 472 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
RemovePurityChecking,
RemoveUnusedParameters,
RemoveUnusedOutputs,
ReorderPermuteDimsAfterConcat,
ReorderTakeAfterMatmul,
RewriteCUDAGraph,
RewriteDataflowReshape,
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,26 @@ def ExpandMatmulOfSum():
return _ffi_api.ExpandMatmulOfSum() # type: ignore


def ReorderPermuteDimsAfterConcat():
"""Reorder `concat(permute_dims(A), permute_dims(B))` into `permute_dims(concat(A,B))`
Useful for optimizing computations after `CombineParallelMatmul`.
The patterns for optimized `nn.Linear` implementations look for
`matmul(activations, permute_dims(weights))`. After
`CombineParallelMatmul`, the `matmul(activations,
concat(permute_dims(A), permute_dims(B)))` no longer matches this
pattern. Rearranging into `matmul(activations,
permute_dims(concat(A,B)))` restores the pattern match.
Returns
-------
ret : tvm.transform.Pass
The corresponding pass.
"""

return _ffi_api.ReorderPermuteDimsAfterConcat() # type: ignore


def ReorderTakeAfterMatmul():
"""Reorder `matmul(x, take(weights, indices))` to `take(matmul(x,weights),indices)`
Expand Down
187 changes: 187 additions & 0 deletions src/relax/transform/reorder_permute_dims_after_concat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/transform/reorder_permute_dims_after_concat.cc
* \brief Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B))
*/

#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <optional>
#include <unordered_set>
#include <vector>

#include "../op/tensor/index.h"
#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"

namespace tvm {
namespace relax {

namespace {
std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
// TODO(Lunderberg): Allow pattern-matching to handle a flexible
// number of arguments, each of which matches the same type of
// pattern.
//
// Because we instantiate one DFPattern for each value in
// `min_concat <= i <= max_concat`, we don't want to set
// `max_concat` to an extremely high value. The current value of 12
// was chosen to be significantly higher than the highest value
// required so far (3, for query/key/value in attention layers), but
// not so high that it requires an excessive number of `DFPattern`.
//
// This value is deliberately *NOT* exposed, as `max_concat` may be
// increased at any point that it is required, and other use cases
// should not depend on its value. If there is a use case that
// requires more matmuls to be handled, and pattern-matching does
// not yet support a flexible number of `Tuple` elements,
// `max_concat` should be increased.
size_t min_concat = 2;
size_t max_concat = 12;

std::vector<DFPattern> pat_args;
std::vector<DFPattern> pat_permute_dims;
for (size_t i = 0; i < max_concat; i++) {
auto arg = WildcardPattern();
pat_args.push_back(arg);
pat_permute_dims.push_back(IsOp("relax.permute_dims")(arg));
}

auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern {
ICHECK_LT(num_concat, pat_permute_dims.size());
auto concat_tuple = TuplePattern(
Array<DFPattern>(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat));
return IsOp("relax.concat")(concat_tuple);
};

DFPattern pat_concat = make_pattern_with_num_concat(min_concat);
for (size_t i = min_concat + 1; i < max_concat; i++) {
pat_concat = pat_concat | make_pattern_with_num_concat(i);
}

auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional<Array<Integer>> {
auto call = expr.as<CallNode>();
ICHECK(call);
auto attrs = call->attrs.as<PermuteDimsAttrs>();
ICHECK(attrs);

return attrs->axes;
};

auto get_permute_dims_axes =
[get_permute_dims_optional_axes](const Expr& expr) -> Array<Integer> {
if (auto opt_axes = get_permute_dims_optional_axes(expr)) {
return opt_axes.value();
} else {
auto call = Downcast<Call>(expr);
Array<Integer> permutation;
auto arg_sinfo = call->args[0]->struct_info_.as<TensorStructInfoNode>();
CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, "
<< "but argument " << call->args[0] << " has struct info "
<< call->args[0]->struct_info_;
CHECK_GE(arg_sinfo->ndim, 0);
size_t ndim = arg_sinfo->ndim;
for (size_t i = 0; i < ndim; i++) {
permutation.push_back(Integer(ndim - i - 1));
}
return permutation;
}
};

auto permute_dims_axes_are_compatible = [&](const Array<Expr>& permute_dims) -> bool {
auto first_axes = get_permute_dims_axes(permute_dims[0]);
for (size_t i_arg = 1; i_arg < permute_dims.size(); i_arg++) {
auto i_axes = get_permute_dims_axes(permute_dims[i_arg]);
if (i_axes.size() != first_axes.size()) {
return false;
}
for (size_t i_axis = 0; i_axis < first_axes.size(); i_axis++) {
if (i_axes[i_axis]->value != first_axes[i_axis]->value) {
return false;
}
}
}
return true;
};

auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
Array<Expr> args;
Array<Expr> all_permute_dims;
for (size_t i = 0; i < max_concat; i++) {
if (auto permute_dim_expr = matches.Get(pat_permute_dims[i])) {
all_permute_dims.push_back(permute_dim_expr.value());
args.push_back(matches[pat_args[i]]);
}
}

ICHECK_GE(all_permute_dims.size(), min_concat)
<< "InternalError: "
<< "Pattern match should return at least " << min_concat << " items, but only found "
<< all_permute_dims.size() << ": " << all_permute_dims;

if (!permute_dims_axes_are_compatible(all_permute_dims)) {
return expr;
}
Optional<Array<Integer>> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]);

Call concat_call = Downcast<Call>(matches[pat_concat]);
auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
ICHECK(concat_attrs);

auto old_concat_axis = [&]() -> size_t {
if (concat_attrs->axis.defined()) {
return concat_attrs->axis.value()->value;
} else {
return 0;
}
}();
Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis];

auto new_concat = concat(Tuple(args), new_concat_axis);
auto new_permute_dims = permute_dims(new_concat, permute_axes);

return new_permute_dims;
};

return {pat_concat, rewriter};
}

} // namespace

namespace transform {
Pass ReorderPermuteDimsAfterConcat() {
auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
auto [pattern, rewriter] = CreatePatterns();
return RewriteCall(pattern, rewriter, func);
};
return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat")
.set_body_typed(ReorderPermuteDimsAfterConcat);

} // namespace transform
} // namespace relax
} // namespace tvm
Loading

0 comments on commit b581575

Please sign in to comment.