[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat #16596
Merged: Lunderberg merged 2 commits into apache:main from Lunderberg:relax_reorder_permute_dims_after_concat on Feb 23, 2024.
src/relax/transform/reorder_permute_dims_after_concat.cc (187 additions, 0 deletions)
@@ -0,0 +1,187 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relax/transform/reorder_permute_dims_after_concat.cc | ||
* \brief Reorder concat(permute_dims(A), permute_dims(B)) into permute_dims(concat(A,B)) | ||
*/ | ||
|
||
#include <tvm/relax/analysis.h> | ||
#include <tvm/relax/dataflow_matcher.h> | ||
#include <tvm/relax/expr.h> | ||
#include <tvm/relax/expr_functor.h> | ||
#include <tvm/relax/transform.h> | ||
|
||
#include <optional> | ||
#include <unordered_set> | ||
#include <vector> | ||
|
||
#include "../op/tensor/index.h" | ||
#include "../op/tensor/linear_algebra.h" | ||
#include "../op/tensor/manipulate.h" | ||
|
||
namespace tvm { | ||
namespace relax { | ||
|
||
namespace { | ||
std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() { | ||
// TODO(Lunderberg): Allow pattern-matching to handle a flexible | ||
// number of arguments, each of which matches the same type of | ||
// pattern. | ||
// | ||
// Because we instantiate one DFPattern for each value in | ||
// `min_concat <= i <= max_concat`, we don't want to set | ||
// `max_concat` to an extremely high value. The current value of 12 | ||
// was chosen to be significantly higher than the highest value | ||
// required so far (3, for query/key/value in attention layers), but | ||
// not so high that it requires an excessive number of `DFPattern`. | ||
// | ||
// This value is deliberately *NOT* exposed, as `max_concat` may be | ||
// increased at any point that it is required, and other use cases | ||
// should not depend on its value. If there is a use case that | ||
// requires more matmuls to be handled, and pattern-matching does | ||
// not yet support a flexible number of `Tuple` elements, | ||
// `max_concat` should be increased. | ||
size_t min_concat = 2; | ||
size_t max_concat = 12; | ||
|
||
std::vector<DFPattern> pat_args; | ||
std::vector<DFPattern> pat_permute_dims; | ||
for (size_t i = 0; i < max_concat; i++) { | ||
auto arg = WildcardPattern(); | ||
pat_args.push_back(arg); | ||
pat_permute_dims.push_back(IsOp("relax.permute_dims")(arg)); | ||
} | ||
|
||
auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern { | ||
ICHECK_LT(num_concat, pat_permute_dims.size()); | ||
auto concat_tuple = TuplePattern( | ||
Array<DFPattern>(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat)); | ||
return IsOp("relax.concat")(concat_tuple); | ||
}; | ||
|
||
DFPattern pat_concat = make_pattern_with_num_concat(min_concat); | ||
for (size_t i = min_concat + 1; i < max_concat; i++) { | ||
pat_concat = pat_concat | make_pattern_with_num_concat(i); | ||
} | ||
|
||
auto get_permute_dims_optional_axes = [](const Expr& expr) -> Optional<Array<Integer>> { | ||
auto call = expr.as<CallNode>(); | ||
ICHECK(call); | ||
auto attrs = call->attrs.as<PermuteDimsAttrs>(); | ||
ICHECK(attrs); | ||
|
||
return attrs->axes; | ||
}; | ||
|
||
auto get_permute_dims_axes = | ||
[get_permute_dims_optional_axes](const Expr& expr) -> Array<Integer> { | ||
if (auto opt_axes = get_permute_dims_optional_axes(expr)) { | ||
return opt_axes.value(); | ||
} else { | ||
auto call = Downcast<Call>(expr); | ||
Array<Integer> permutation; | ||
auto arg_sinfo = call->args[0]->struct_info_.as<TensorStructInfoNode>(); | ||
CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, " | ||
<< "but argument " << call->args[0] << " has struct info " | ||
<< call->args[0]->struct_info_; | ||
CHECK_GE(arg_sinfo->ndim, 0); | ||
size_t ndim = arg_sinfo->ndim; | ||
for (size_t i = 0; i < ndim; i++) { | ||
permutation.push_back(Integer(ndim - i - 1)); | ||
} | ||
return permutation; | ||
} | ||
}; | ||
|
||
auto permute_dims_axes_are_compatible = [&](const Array<Expr>& permute_dims) -> bool { | ||
auto first_axes = get_permute_dims_axes(permute_dims[0]); | ||
for (size_t i_arg = 1; i_arg < permute_dims.size(); i_arg++) { | ||
auto i_axes = get_permute_dims_axes(permute_dims[i_arg]); | ||
if (i_axes.size() != first_axes.size()) { | ||
return false; | ||
} | ||
for (size_t i_axis = 0; i_axis < first_axes.size(); i_axis++) { | ||
if (i_axes[i_axis]->value != first_axes[i_axis]->value) { | ||
return false; | ||
} | ||
} | ||
} | ||
return true; | ||
}; | ||
|
||
auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr { | ||
Array<Expr> args; | ||
Array<Expr> all_permute_dims; | ||
for (size_t i = 0; i < max_concat; i++) { | ||
if (auto permute_dim_expr = matches.Get(pat_permute_dims[i])) { | ||
all_permute_dims.push_back(permute_dim_expr.value()); | ||
args.push_back(matches[pat_args[i]]); | ||
} | ||
} | ||
|
||
ICHECK_GE(all_permute_dims.size(), min_concat) | ||
<< "InternalError: " | ||
<< "Pattern match should return at least " << min_concat << " items, but only found " | ||
<< all_permute_dims.size() << ": " << all_permute_dims; | ||
|
||
if (!permute_dims_axes_are_compatible(all_permute_dims)) { | ||
return expr; | ||
} | ||
Optional<Array<Integer>> permute_axes = get_permute_dims_optional_axes(all_permute_dims[0]); | ||
|
||
Call concat_call = Downcast<Call>(matches[pat_concat]); | ||
auto concat_attrs = concat_call->attrs.as<ConcatAttrs>(); | ||
ICHECK(concat_attrs); | ||
|
||
auto old_concat_axis = [&]() -> size_t { | ||
if (concat_attrs->axis.defined()) { | ||
return concat_attrs->axis.value()->value; | ||
} else { | ||
return 0; | ||
} | ||
}(); | ||
Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis]; | ||
|
||
auto new_concat = concat(Tuple(args), new_concat_axis); | ||
auto new_permute_dims = permute_dims(new_concat, permute_axes); | ||
|
||
return new_permute_dims; | ||
}; | ||
|
||
return {pat_concat, rewriter}; | ||
} | ||
|
||
} // namespace | ||
|
||
namespace transform { | ||
Pass ReorderPermuteDimsAfterConcat() { | ||
auto pass_func = [=](Function func, IRModule mod, PassContext pc) { | ||
auto [pattern, rewriter] = CreatePatterns(); | ||
return RewriteCall(pattern, rewriter, func); | ||
}; | ||
return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat") | ||
.set_body_typed(ReorderPermuteDimsAfterConcat); | ||
|
||
} // namespace transform | ||
} // namespace relax | ||
} // namespace tvm |
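The rewrite is valid because transposition distributes over concatenation once the concat axis is remapped through the permutation: `concat([permute_dims(t, axes) for t in tensors], axis=k)` equals `permute_dims(concat(tensors, axis=axes[k]), axes)`. A standalone NumPy sketch of that identity and of the axis remapping performed by the rewriter above (NumPy stands in for the Relax ops; this is an illustration, not the pass itself):

```python
import numpy as np

def reorder_permute_dims_after_concat(tensors, axes, concat_axis):
    """Check the identity the pass relies on, using NumPy as a stand-in.

    Returns (original, rewritten); the two must be element-wise equal.
    """
    # Original form: permute each input, then concatenate.
    original = np.concatenate(
        [np.transpose(t, axes) for t in tensors], axis=concat_axis)
    # Rewritten form: concatenate first along the remapped axis, then
    # apply a single permutation.  `axes[concat_axis]` is the same
    # remapping computed by `new_concat_axis` in the C++ rewriter.
    new_concat_axis = axes[concat_axis]
    rewritten = np.transpose(
        np.concatenate(tensors, axis=new_concat_axis), axes)
    return original, rewritten

# In Relax, permute_dims with no explicit axes reverses all dimensions,
# i.e. axes == tuple(reversed(range(ndim))); here the axes are explicit.
a = np.arange(24).reshape(2, 3, 4)
b = np.arange(24, 72).reshape(2, 3, 8)  # differs only on the dim that lands on the concat axis
original, rewritten = reorder_permute_dims_after_concat(
    [a, b], axes=(2, 0, 1), concat_axis=0)
assert np.array_equal(original, rewritten)
```

Within TVM itself, the pass is applied through the registered transform, e.g. `relax.transform.ReorderPermuteDimsAfterConcat()(mod)` on an `IRModule`.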
Is this an arbitrary choice or do we enforce this somewhere? If it's arbitrary, we should probably have an enum or global constant somewhere to document that. (I'm sure it's a large enough number of args but who knows what automatically generated code might do.)
Entirely arbitrary at the moment. The choice of 12 as the max was to make it significantly larger than any actual case I've seen so far (3 for query/key/value), and because I'd been in the habit of using this Rust utility in hobby projects.
Because the long-term plan is to remove the restriction altogether, I'd actually lean away from having a global constant to document it. Having the global constant would allow other portions of the code to depend on it, which would make it harder to migrate to the long-term solution.
For now, I've added a comment indicating why the value is not being exposed, with a recommendation to increase `max_concat` if required, at least until pattern matching of arbitrary tuple sizes is supported.
Yeah, I've noticed that this kind of issue comes up in other languages too. Thanks for documenting the issue. I wish there were a way to "loudly" give an error if an unsupported number of concat params is detected, but I don't see a clear one.
Yeah. If the number of concat parameters is exceeded, then the match silently fails. And unlike `CombineParallelMatmul`, which could be run repeatedly if the number of concatenations is exceeded, the validity checking for reordering requires all concatenated tensors before performing any change.

Long-term, I think the flexible tuple index will be the way to go. I think it will end up needing explicit tuple extent variables. That way, a pattern containing two flexible tuple extents could distinguish between a single shared extent and two independent extents. I'm picturing something like the following:
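A purely hypothetical sketch of such a pattern API, in the style of the Python dataflow-pattern helpers (every name here — `tuple_extent_var`, `wildcard_tuple`, `for_each` — is invented for illustration; nothing like this exists in TVM today):

```
# Hypothetical pattern API; not part of TVM's dataflow matcher.
n = tuple_extent_var("n")              # shared symbolic tuple length
args = wildcard_tuple(extent=n)        # matches a tuple of n wildcards
permuted = for_each(args, lambda a: is_op("relax.permute_dims")(a))
pat = is_op("relax.concat")(tuple_pattern(permuted, extent=n))
# Two sub-patterns sharing `n` would be forced to have equal extents;
# two distinct extent variables would let the lengths vary independently.
```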
But that's for brainstorming an implementation that would be much later down the road.