Skip to content

Commit

Permalink
[Pass][UX] Statement rewriter for DataflowBlock (tlc-pack#210)
Browse files Browse the repository at this point in the history
- Implements a few APIs to quickly perform statement-level mutation: `add`/`remove_unused`/`remove_all_unused`/`replace_all_uses`.
- Implemented `remove_all_unused` to remove dead statements inside `DataflowBlock` cc: @psrivas2
- Address minor issues (unnecessary headers and bad docstrings) in tlc-pack#163
  • Loading branch information
ganler authored and junrushao committed Jan 25, 2023
1 parent dd8502e commit 374dfc5
Show file tree
Hide file tree
Showing 11 changed files with 1,196 additions and 29 deletions.
40 changes: 33 additions & 7 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/function.h>

#include <utility>

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -111,34 +113,58 @@ TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);
/*!
* \brief Analyze var -> value mapping from VarBindings.
*
* \param m the IRModule to check.
* \param m The IRModule to check.
* \return Var -> Value (Expr)
*/
TVM_DLL runtime::Map<Var, Expr> AnalyzeVar2Value(const IRModule& m);
TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const IRModule& m);

/*!
* \brief Analyze var -> value mapping from VarBindings.
*
* \param expr the expression to check.
* \param expr The expression to check.
* \return Var -> Value (Expr)
*/
TVM_DLL runtime::Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);
TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr);

/*!
* \brief Analyze var -> value mapping from VarBindings.
*
* \param dfb the dataflow block to check.
* \param dfb The dataflow block to check.
* \return Var -> Value (Expr)
*/
TVM_DLL runtime::Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);
TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb);

/*!
* \brief Return a mapping from variable name to its Bindings.
*
* \param fn The function to be analyzed.
* \return A mapping from variable name to its Bindings.
*/
TVM_DLL Map<String, Array<Binding>> NameToBinding(const Function& fn);

/*!
* \brief Get the use-def chain of variables inside a dataflow block.
*
* \param dfb The dataflow block to be analyzed.
* \return A map mapping variable definitoins to a set of uses.
*/
TVM_DLL runtime::Map<Var, Array<Var>> UseDefChain(const DataflowBlock& dfb);
TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);

/*!
* \brief Get the use-def chain of variables inside a function.
*
* \param fn The function to be analyzed.
* \return A map from variable definitoins to a set of uses and variables needed by return value.
*/
std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);

/*!
* \brief Remove unused statements inside DataflowBlocks.
*
* \param fn The function to remove unused statements.
* \return The function that contains no unused statements in DataflowBlock.
*/
TVM_DLL Function RemoveAllUnused(const Function fn);

} // namespace relax
} // namespace tvm
Expand Down
115 changes: 115 additions & 0 deletions include/tvm/relax/binding_rewrite.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/binding_rewrite.h
* \brief An IR rewriter to easily add/remove/replace bindings (statements).
*/

#ifndef TVM_RELAX_BINDING_REWRITE_H_

#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/utils.h>

#include <map>
#include <set>
#include <type_traits>
#include <utility>
#include <vector>

namespace tvm {
namespace relax {

/*! \brief Statement rewriter for relax.DataflowBlock. */
class DataflowBlockRewriteNode : public Object {
public:
/*! \brief Replace all uses of old_var with new_var. */
void ReplaceAllUses(Var old_var, Var new_var);
/*! \brief Insert a Binding statement. */
void Add(Binding binding);
/*! \brief Insert an expression as VarBinding with variable name. */
void Add(String var_name, Expr expr, bool is_dfvar = false) {
auto var = is_dfvar ? DataflowVar(var_name, expr->shape(), expr->checked_type())
: Var(var_name, expr->shape(), expr->checked_type());
Add(VarBinding(std::move(var), std::move(expr)));
}
/*! \brief Insert an expression as VarBinding with automatic variable name. */
void Add(Expr expr, bool is_dfvar = false) {
Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar);
}
/*! \brief Remove the definition statement of an unused variable. */
void RemoveUnused(Var unused, bool allow_undef = false);
/*! \brief Remove the definition statements of all unused variables. */
void RemoveAllUnused();

/*! \brief The rewritten dataflow block. */
DataflowBlock MutatedDataflowBlock() { return dfb_.value(); }
/*! \brief The rewritten function. */
Function MutatedFunc() { return root_fn_.value(); }
/*! \brief The rewritten IRModule. */
IRModule MutateIRModule(IRModule irmod);

/*! \brief Visit attributes. */
void VisitAttrs(AttrVisitor* v) {
v->Visit("dfb", &dfb_);
v->Visit("root_fn", &root_fn_);
}

static constexpr const char* _type_key = "relax.DataflowBlockRewrite";
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object);

protected:
friend class DataflowBlockRewrite;

Optional<DataflowBlock> dfb_; //!< The rewritten dataflow block.
Optional<Function> root_fn_; //!< The rewritten function.
const FunctionNode* original_fn_ptr_; //!< Pointer to the original function.
Map<Var, Array<Var>> to_users_; //!< Map from variable to its users.
Array<Var> fn_outputs_; //!< Variables required by function outputs.

private:
NameTable name_table_; //!< Name table for tracking and generating unique names.
};

/*!
* \brief A statement rewriter for relax.DataflowBlock.
* \sa DataflowBlockRewriteNode
*/
class DataflowBlockRewrite : public ObjectRef {
public:
TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn);

/*!
* \brief mutable accessor.
* \return mutable access pointer.
*/
DataflowBlockRewriteNode* operator->() {
ICHECK(get() != nullptr);
return static_cast<DataflowBlockRewriteNode*>(get_mutable());
}

TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode);
};

} // namespace relax
} // namespace tvm

#define TVM_RELAX_BINDING_REWRITE_H_
#endif // TVM_RELAX_BINDING_REWRITE_H_
36 changes: 35 additions & 1 deletion include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
#ifndef TVM_RELAX_UTILS_H_
#define TVM_RELAX_UTILS_H_

#include <tvm/runtime/logging.h>

#include <algorithm>
#include <cctype>
#include <string>
#include <type_traits>
#include <unordered_map>

namespace tvm {
Expand Down Expand Up @@ -53,8 +57,38 @@ class NameTable {
return unique_prefix;
}

NameTable() = default;

template <typename Iter, typename Lambda>
explicit NameTable(Iter begin, Iter end, Lambda f) {
// static_assert is more reader-friendly than SFINAE when template specialization is not needed.
static_assert(std::is_convertible<decltype(f(*begin)), std::string>::value,
"Lambda f must has a signature of [?](*it) -> string {}");
for (auto it = begin; it != end; ++it) {
const std::string& name = f(*it);
const size_t idx_last_first_num = std::distance(
std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }),
name.rend());
// name = {O = others}{D = consecutive digits}
// let O -> prefix;
std::string prefix = name.substr(0, idx_last_first_num);
ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name;
if (0 == alloc_map_.count(prefix)) alloc_map_[prefix] = 0;
if (idx_last_first_num < name.size()) { // has some digits.
// let D's nearest natural number -> idx;
// note: stoul("000123") = 123;
alloc_map_[prefix] =
std::max(alloc_map_[prefix], std::stoi(name.substr(idx_last_first_num)));
}
}
}

template <typename Iter>
explicit NameTable(Iter begin, Iter end)
: NameTable(begin, end, [](const decltype(*begin)& v) { return v; }) {}

private:
std::unordered_map<std::string, uint32_t> alloc_map_;
std::unordered_map<std::string, int> alloc_map_;
};

/*!
Expand Down
23 changes: 22 additions & 1 deletion python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import Dict, List

import tvm
from tvm.relax.expr import DataflowBlock, Var, Expr, Function
from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Binding
from . import _ffi_api


Expand Down Expand Up @@ -92,3 +92,24 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]:
A mapping from variable definition to its uses.
"""
return _ffi_api.udchain(dfb)


def name_to_binding(func: Function) -> Dict[str, List[Binding]]:
"""Return a map from variable name to its bindings."""
return _ffi_api.name_to_binding(func)


def remove_all_unused(func: Function) -> Function:
"""Remove all unused variables from the function.
Parameters
----------
func : Function
The input function to be analyzed.
Returns
-------
Function
The function with unused variables removed.
"""
return _ffi_api.remove_all_unused(func)
Loading

0 comments on commit 374dfc5

Please sign in to comment.