[Relax] Implement relax.transform.RemoveSymbolicExpressionsInSubroutine
This is a follow-up commit to
#16637, which updated
`relax.transform.FuseOps` to provide additional parameters defining
symbolic variables required by the fused functions.  While this
ensures that `relax.transform.FuseOps` produces well-formed Relax
functions, these additional arguments can break some kernel
implementations.

This commit implements a new transform,
`relax.transform.RemoveSymbolicExpressionInSubroutine`, to resolve
this issue.  The transform identifies function parameters whose sole
purpose is to define symbolic variables, when those variables appear
only within a single symbolic expression that could instead be
inferred from tensor shapes.

For example, consider the following Relax function:

```python
@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):

    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```

The `data` tensor may be used to infer `hidden_size`, but cannot be
used to infer `batch_size` or `seq_len`.  The `R.Shape` parameter
exists solely to define `batch_size` and `seq_len`, since all symbolic
variables must be defined.  However, neither `batch_size` nor
`seq_len` is ever used outside of the expression `batch_size *
seq_len`, and the value of `batch_size * seq_len` can be inferred
from the shape of the `data` tensor.
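To make the inference rule concrete, here is a small illustrative sketch in plain Python (not the TVM API; the tuple encoding of expressions is invented for this example): a dimension that is a lone symbolic variable can be bound by matching it against the runtime shape, while a compound dimension such as `batch_size * seq_len` only fixes the value of the product.

```python
def inferable_vars(symbolic_shape, runtime_shape):
    """Match symbolic dims against concrete extents, returning the variable
    bindings that can be recovered.  A dim is either a bare variable (a str)
    or a compound expression encoded as a tuple, e.g. ("*", "a", "b")."""
    bindings = {}
    for dim, extent in zip(symbolic_shape, runtime_shape):
        if isinstance(dim, str):
            # A bare symbolic variable is directly inferable from the shape.
            bindings[dim] = extent
        # A compound expression fixes only the value of the whole product,
        # not the values of its individual operands.
    return bindings

# Shape of `data` above: ["batch_size * seq_len", "hidden_size"]
shape = [("*", "batch_size", "seq_len"), "hidden_size"]
print(inferable_vars(shape, (512, 4096)))  # only hidden_size is recovered
```

Running this against a concrete shape shows that only `hidden_size` is bound, matching the situation described above.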

This new transform identifies cases where a parameter is otherwise
unnecessary, and replaces the symbolic expression with a single fresh
variable.  This leaves the `dummy_arg: R.Shape` parameter entirely
unused, so a later application of
`relax.transform.RemoveUnusedParameters()` can remove the parameter
altogether.

```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ):

    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
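The rewrite itself can be pictured as a whole-expression substitution over every shape annotation; a minimal plain-Python sketch (not TVM code, using an invented tuple encoding for the product expression):

```python
def replace_dims(shape, target, fresh_var):
    """Replace every dim equal to `target` with the fresh variable."""
    return [fresh_var if dim == target else dim for dim in shape]

prod = ("*", "batch_size", "seq_len")
data_shape = replace_dims([prod, "hidden_size"], prod, "data_dim0")
output_shape = replace_dims([prod, "intermediate_size"], prod, "data_dim0")

# After the rewrite, neither batch_size nor seq_len appears in any shape
# annotation, so the R.Shape parameter that defined them is dead code.
remaining_dims = set(data_shape) | set(output_shape)
print(remaining_dims)
```

After the substitution, only `data_dim0`, `hidden_size`, and `intermediate_size` remain, which is why the `R.Shape` parameter becomes removable.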
Lunderberg committed Sep 11, 2024
1 parent 72b75fe commit 27a6820
Showing 7 changed files with 721 additions and 7 deletions.
17 changes: 17 additions & 0 deletions include/tvm/relax/transform.h
@@ -307,6 +307,23 @@ TVM_DLL Pass RemoveUnusedParameters();
*/
TVM_DLL Pass RemoveUnusedOutputs();

/*! \brief Remove unnecessary symbolic expressions in subroutines
*
* If all occurrences of a symbolic variable within a subroutine
* occur within the same symbolic expression, then the subroutine
* could be simplified to be in terms of that expression.
*
* For example, if a subroutine accepts symbolic shape parameters `N`
* and `M`, and the variables `N` and `M` are only ever used to
* compute `N*M`, then the subroutine could instead accept a symbolic
* shape parameter `new_var = N*M`. This can allow shape parameters
* to be inferred from tensor shapes, rather than requiring additional
* arguments.
*
* \return The pass
*/
TVM_DLL Pass RemoveSymbolicExpressionInSubroutine();

/*!
* \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
* \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be
3 changes: 2 additions & 1 deletion python/tvm/relax/transform/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relax transformations. """
"""Relax transformations."""

from .transform import (
AdjustMatmulOrder,
@@ -65,6 +65,7 @@
PatternCheckContext,
RealizeVDevice,
RemovePurityChecking,
RemoveSymbolicExpressionInSubroutine,
RemoveUnusedOutputs,
RemoveUnusedParameters,
ReorderPermuteDimsAfterConcat,
52 changes: 52 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -764,6 +764,58 @@ def RemoveUnusedOutputs() -> tvm.ir.transform.Pass:
return _ffi_api.RemoveUnusedOutputs() # type: ignore


def RemoveSymbolicExpressionInSubroutine() -> tvm.ir.transform.Pass:
"""Remove unnecessary symbolic expressions in subroutines

If all occurrences of a symbolic variable within a subroutine
occur within the same symbolic expression, then the subroutine
can be simplified to be in terms of that expression.

For example, consider an elementwise operation that takes input of
shape `arg: R.Tensor([m * n])`, producing output of shape
`R.Tensor([m * n])`.  The symbolic variables `m` and `n` cannot be
inferred from the shape of `arg`, as only their product `m*n` can
be determined from the tensor's shape.  In order to be
well-formed, Relax requires one of the following three
workarounds.

1. Remove the symbolic variables, producing `arg:
   R.Tensor(ndim=1)`.  This no longer provides the symbolic
   variables, and is well-formed.  However, this also causes the
   output shape to be `R.Tensor(ndim=1)`.  The calling scope can
   no longer determine that the input and output shapes are
   identical.

   This is the default behavior of the `relax::BlockBuilder`.

2. Provide an additional argument to define the symbolic variables.
   If the elementwise operation takes an additional argument
   `R.Shape([m, n])`, then that additional argument defines
   the symbolic variables.

   This is the output produced by `relax.transform.FuseOps`, and
   while it is well-formed, the additional non-tensor argument can
   be unexpected by downstream transforms.

3. Update the shape of `arg` to `R.Tensor([arg_size])`.  This
   allows the symbolic variable `arg_size` to be inferred from the
   tensor's shape, and propagates to the output shape of
   `R.Tensor([arg_size])`.  Within the calling scope, an
   argument of `R.Tensor([m * n])` can then be inferred to produce
   an output of `R.Tensor([m * n])`, without requiring an
   additional parameter to provide the shape.

This transform updates internal functions that use option (2) to
instead use option (3).

Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.RemoveSymbolicExpressionInSubroutine() # type: ignore


def InlinePrivateFunctions() -> tvm.ir.transform.Pass:
"""Inline all private relax functions
247 changes: 247 additions & 0 deletions src/relax/transform/remove_symbolic_expression_in_subroutine.cc
@@ -0,0 +1,247 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/transform/remove_symbolic_expression_in_subroutine.cc
*
* \brief Replace symbolic expressions with single variables, when possible.
*
* For example, if a subroutine accepts symbolic shape parameters `N`
* and `M`, and the variables `N` and `M` are only ever used to
* compute `N*M`, then the subroutine could instead accept a symbolic
* shape parameter `new_var = N*M`. This can allow shape parameters
* to be inferred from tensor shapes, rather than requiring additional
* arguments.
*/

#include <tvm/node/object_path.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr_functor.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relax {

namespace {

// Utility templates for unordered map/set that use structural hash/equal.

template <typename Key, typename Value>
using StructMap = std::unordered_map<Key, Value, StructuralHash, StructuralEqual>;

template <typename Key>
using StructSet = std::unordered_set<Key, StructuralHash, StructuralEqual>;

/*! \brief Collect symbolic expressions that may be inferred from a function signature
*
* \param func The function whose signature should be inspected
*
* \return A map from PrimExpr to the location where it occurs in the signature
*/
StructMap<PrimExpr, std::string> CollectInferableExpressions(const Function& func) {
StructMap<PrimExpr, std::string> output;

auto mark = [&](const PrimExpr& expr, const ObjectPath& path) {
if (!output.count(expr)) {
std::stringstream ss;
ss << path;
output[expr] = ss.str();
}
};

std::function<void(const StructInfo&, const ObjectPath&)> visit = [&](const StructInfo& sinfo,
const ObjectPath& path) {
if (auto tensor = sinfo.as<TensorStructInfoNode>()) {
if (auto opt_shape = tensor->GetShape()) {
auto shape_path = path->Attr("shape");
auto shape = opt_shape.value();
for (size_t i = 0; i < shape.size(); i++) {
mark(shape[i], shape_path->ArrayIndex(i));
}
}
} else if (auto tuple = sinfo.as<TupleStructInfoNode>()) {
for (size_t i = 0; i < tuple->fields.size(); i++) {
visit(tuple->fields[i], path->ArrayIndex(i));
}
}
};

for (const auto& param : func->params) {
visit(GetStructInfo(param), ObjectPath::Root(param->name_hint()));
}

return output;
}
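The shape-walk above can be mirrored in a few lines of plain Python (illustrative only; struct info is modeled here as nested `("tensor", dims)` / `("tuple", fields)` tuples, not TVM objects), recording the first signature location at which each shape expression occurs:

```python
def collect_inferable(params):
    """Map each shape expression to the first signature location where it
    appears, e.g. {'k': 'data.shape[1]'}.  Mirrors the role of
    CollectInferableExpressions in the C++ code above."""
    out = {}

    def visit(sinfo, path):
        kind, contents = sinfo
        if kind == "tensor":
            for i, dim in enumerate(contents):
                # setdefault keeps the first occurrence, like the `mark` lambda
                out.setdefault(dim, f"{path}.shape[{i}]")
        elif kind == "tuple":
            for i, field in enumerate(contents):
                visit(field, f"{path}[{i}]")

    for name, sinfo in params:
        visit(sinfo, name)
    return out

params = [("data", ("tensor", [("*", "m", "n"), "k"]))]
print(collect_inferable(params))
# {('*', 'm', 'n'): 'data.shape[0]', 'k': 'data.shape[1]'}
```

As in the C++ version, both bare variables and compound expressions are collected; the recorded path string is later reused to name fresh variables.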

/*! \brief Collect expressions that are required in a function body
*
* This recurses into StructInfo and sub-expressions, but does not
* recurse beyond any expression in `inferable_expressions`. This
* allows the transform to determine whether a `tir::Var` ever occurs
* outside of an expression that can be inferred.
*/
class RequiredExpressionCollector : private StructInfoVisitor,
private ExprVisitor,
private tir::ExprVisitor {
public:
static StructSet<PrimExpr> Collect(
const Function& func, const StructMap<PrimExpr, std::string>& inferable_expressions) {
RequiredExpressionCollector visitor(inferable_expressions);
visitor.VisitExpr(func->body);
return visitor.required_expressions_;
}

private:
explicit RequiredExpressionCollector(
const StructMap<PrimExpr, std::string>& inferable_expressions)
: inferable_expressions_(inferable_expressions) {}

using relax::ExprVisitor::VisitExpr;
using tir::ExprVisitor::VisitExpr;

// Required in order to recurse from `TensorStructInfo` into its
// `ShapeExpr`. This hands control from `StructInfoVisitor` into
// `ExprVisitor`.
void VisitStructInfoExprField(const Expr& expr) override { VisitExpr(expr); }

// Required in order to recurse into `ShapeStructInfo`. This hands
// control from `ExprVisitor` back to `StructInfoVisitor`.
void VisitExprDepStructInfoField(const StructInfo& struct_info) override {
VisitStructInfo(struct_info);
}

void VisitPrimExpr(const PrimExpr& expr) override {
required_expressions_.insert(expr);
if (!inferable_expressions_.count(expr)) {
tir::ExprVisitor::VisitExpr(expr);
}
}

void VisitStructInfoExprField(const PrimExpr& expr) override { VisitPrimExpr(expr); }

const StructMap<PrimExpr, std::string>& inferable_expressions_;
StructSet<PrimExpr> required_expressions_;
};
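The key subtlety in this visitor is that it records an inferable expression but does not descend into it, so a variable counts as required only if it appears somewhere outside every inferable expression. A plain-Python sketch of that stopping rule (illustrative, with compound expressions encoded as tuples rather than TVM objects):

```python
def required_exprs(body_exprs, inferable):
    """Collect every expression the body mentions, stopping the descent at
    expressions the signature can infer (mirrors VisitPrimExpr above)."""
    required = set()

    def visit(expr):
        required.add(expr)
        if expr in inferable:
            return  # inferable: its operands stay hidden
        if isinstance(expr, tuple):
            for operand in expr[1:]:  # ("op", lhs, rhs)
                visit(operand)

    for expr in body_exprs:
        visit(expr)
    return required

prod = ("*", "m", "n")
# Body only ever uses m*n, and m*n is inferable: m and n are NOT required.
print("m" in required_exprs([prod], {prod}))       # False
# If the body also uses m on its own, m becomes required.
print("m" in required_exprs([prod, "m"], {prod}))  # True
```

This is exactly the distinction the pass needs: `batch_size * seq_len` can be replaced only when neither `batch_size` nor `seq_len` is required on its own.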

/*! \brief Replace occurrences of a PrimExpr in the symbolic variables
*
* In most cases, the `tvm::relax::Bind` utility should be used
* instead. Here, though, we are replacing a `PrimExpr` with a
* `tir::Var`, whereas `tvm::relax::Bind` supports the more standard
* case of replacing a `tir::Var` with a `PrimExpr`.
*/
class SymbolicSubexprReplacer : public relax::ExprMutator,
public StructInfoMutator,
public tir::ExprMutator {
public:
using relax::ExprMutator::operator();
using relax::ExprMutator::VisitExpr;
using tir::ExprMutator::operator();
using tir::ExprMutator::VisitExpr;

explicit SymbolicSubexprReplacer(StructMap<PrimExpr, tir::Var> replacements)
: replacements_(replacements) {}

StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
return VisitStructInfo(struct_info);
}
Expr VisitStructInfoExprField(const Expr& expr) override { return VisitExpr(expr); }
PrimExpr VisitStructInfoExprField(const PrimExpr& expr) override { return VisitExpr(expr); }
PrimExpr VisitPrimExpr(const PrimExpr& expr) override { return VisitExpr(expr); }

PrimExpr VisitExpr(const PrimExpr& expr) override {
if (auto it = replacements_.find(expr); it != replacements_.end()) {
return it->second;
} else {
return tir::ExprMutator::VisitExpr(expr);
}
}

StructMap<PrimExpr, tir::Var> replacements_;
};
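The mutator's one essential property is that it tests the whole expression against the replacement table before recursing into operands, so a compound match wins over its parts; sketched in plain Python (tuple-encoded expressions, not TVM code):

```python
def substitute(expr, replacements):
    """Top-down substitution: try the whole expression first, then descend.
    This is the reverse of the usual var -> expr binding: here a compound
    expression is collapsed into a single variable."""
    if expr in replacements:
        return replacements[expr]
    if isinstance(expr, tuple):
        op, *operands = expr
        return (op, *(substitute(e, replacements) for e in operands))
    return expr

repl = {("*", "m", "n"): "p"}
print(substitute(("*", "m", "n"), repl))              # 'p'
print(substitute(("+", ("*", "m", "n"), "k"), repl))  # ('+', 'p', 'k')
```

A bottom-up rewrite would visit `m` and `n` individually and never see the `m*n` node intact, which is why the top-down order matters here.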

} // namespace

Function RemoveSymbolicExpressionInSubroutine(Function func) {
bool is_exposed_externally = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
if (is_exposed_externally) return func;

auto inferable_expressions = CollectInferableExpressions(func);

auto required_expressions = RequiredExpressionCollector::Collect(func, inferable_expressions);

StructMap<PrimExpr, tir::Var> replacements;
for (const auto& [expr, name] : inferable_expressions) {
bool is_tir_var = expr->IsInstance<tir::VarNode>();

auto expr_depends_on = tir::UndefinedVars(expr);
bool internal_variable_is_required =
std::any_of(expr_depends_on.begin(), expr_depends_on.end(),
[&](const tir::Var& subvar) { return required_expressions.count(subvar); });

if (!is_tir_var && !internal_variable_is_required) {
// For human-readability, use the location used to infer the
// shape to name the variable. (e.g. `A_dim0` for a parameter
// inferred from parameter `A->shape[0]`.)
replacements[expr] = tir::Var(name, expr->dtype);
}
}

if (replacements.empty()) {
return func;
}

SymbolicSubexprReplacer mutator(replacements);
return Downcast<Function>(mutator(func));
}
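The replacement test in this function can be restated in a few lines of plain Python (illustrative, with string variables and tuple-encoded expressions rather than TVM objects): replace an inferable expression only when it is not already a single variable and none of its operand variables is needed elsewhere, naming the fresh variable after the shape slot it came from (e.g. `data_dim0` for `data.shape[0]`).

```python
def free_vars(expr):
    """All variables occurring in a tuple-encoded expression."""
    if isinstance(expr, str):
        return {expr}
    return set().union(*(free_vars(e) for e in expr[1:]))

def choose_replacements(inferable, required):
    """inferable: expr -> signature location; required: exprs used in the body."""
    repl = {}
    for expr, location in inferable.items():
        if isinstance(expr, str):
            continue  # already a lone variable: nothing to simplify
        if any(v in required for v in free_vars(expr)):
            continue  # some operand variable is needed on its own elsewhere
        # e.g. "data.shape[0]" -> "data_dim0", matching the commit's naming
        repl[expr] = location.replace(".shape[", "_dim").rstrip("]")
    return repl

inferable = {("*", "m", "n"): "data.shape[0]", "k": "data.shape[1]"}
print(choose_replacements(inferable, required={("*", "m", "n")}))
# {('*', 'm', 'n'): 'data_dim0'}
```

The name mangling shown here is only a plausible reading of the C++ comment about human-readable names; the exact string produced by `ObjectPath` streaming may differ.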

namespace transform {
Pass RemoveSymbolicExpressionInSubroutine() {
auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule {
IRModule updates;

for (const auto& [gvar, base_func] : mod->functions) {
if (auto func = base_func.as<Function>()) {
auto mutated = RemoveSymbolicExpressionInSubroutine(func.value());
if (!mutated.same_as(base_func)) {
updates->Add(gvar, mutated);
}
}
}

if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "RemoveSymbolicExpressionInSubroutine", {});
}

TVM_REGISTER_GLOBAL("relax.transform.RemoveSymbolicExpressionInSubroutine")
.set_body_typed(RemoveSymbolicExpressionInSubroutine);

} // namespace transform
} // namespace relax
} // namespace tvm
16 changes: 12 additions & 4 deletions src/relax/transform/remove_unused_parameters.cc
@@ -84,17 +84,25 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
// symbolic variables. We still want to remove the relax variable
// to reduce computational steps in the parent, but we need to
// provide the symbolic variables through other means.
auto defined_tir_params = [&]() -> PSet<tir::Var> {
auto required_tir_vars = [&]() -> PSet<tir::Var> {
auto arr = FreeSymbolicVars(func->body);
return {arr.begin(), arr.end()};
}();

auto inferable_tir_params = [&]() -> PSet<tir::Var> {
auto param_sinfo =
TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); }));
auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
return {arr.begin(), arr.end()};
}();

// Use an array to define the order of the symbolic variables
// Collect any additional TIR variables that should be provided.
// The `DefinableTIRVarsInStructInfo` function returns the TIR
// variables in order of their occurrence, so the output is
// deterministic.
Array<tir::Var> free_tir_vars;
for (const auto& tir_var : FreeSymbolicVars(func->body)) {
if (!defined_tir_params.count(tir_var)) {
for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func))) {
if (required_tir_vars.count(tir_var) && !inferable_tir_params.count(tir_var)) {
free_tir_vars.push_back(tir_var);
}
}
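The updated condition keeps exactly the variables that the body requires but that the surviving tensor parameters cannot define, preserving first-occurrence order; as a small plain-Python sketch (illustrative, not TVM code):

```python
def extra_shape_params(vars_in_order, required, inferable):
    """Variables that must be passed explicitly: needed by the body, yet not
    recoverable from the remaining parameters' struct info.  Order is
    preserved so the resulting signature is deterministic."""
    return [v for v in vars_in_order if v in required and v not in inferable]

# m is needed but not inferable; n is unused; k is inferable from a tensor shape.
print(extra_shape_params(["m", "n", "k"], required={"m", "k"}, inferable={"k"}))
# ['m']
```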
