Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix][Transform] Keep private non-primitive functions in FuseTIR #16565

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,12 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
*
* Any binding blocks that are left empty will be removed by the normalizer.
*
* \param entry_functions Names of functions that should be considered
* as entry points, in addition to any externally exposed functions.
*
* \return The Pass.
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions = {});

/*!
* \brief Pass that changes calls to operators that can be done in-place
Expand Down
235 changes: 123 additions & 112 deletions src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,57 +961,73 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
*/
class TIRFuseMutator : public ExprMutator {
public:
static IRModule Transform(const IRModule& mod) {
Map<GlobalVar, BaseFunc> funcs_to_keep;
for (const auto& [gv, func] : mod->functions) {
// 1. If a TIR function has global symbol, we keep the function.
// 2. Always keep ExternFunc.
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
if (prim_func->GetAttr<String>("global_symbol").defined()) {
funcs_to_keep.Set(gv, func);
static IRModule Transform(IRModule mod) {
// Collect all primitive relax functions
Map<GlobalVar, Function> primitive_relax;
for (const auto& [gvar, base_func] : mod->functions) {
// Only fuse primitive relax functions
if (base_func->HasNonzeroAttr(attr::kPrimitive)) {
if (auto func = base_func.as<relax::Function>()) {
primitive_relax.Set(gvar, func.value());
}
} else if (func->IsInstance<ExternFuncNode>()) {
funcs_to_keep.Set(gv, func);
}
}

if (primitive_relax.empty()) {
return mod;
}

mod.CopyOnWrite();

IRModule updates;
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements;

// Since TIRFuseMutator will delete a bunch of PrimFuncs, we create an empty block builder.
TIRFuseMutator mutator(mod);

// Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
for (const auto& [gv, func] : mod->functions) {
// Only fuse primitive relax functions
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv);
mutator.fused_tir_funcs_.Set(gv, prim_func);
if (!indices.empty()) {
mutator.inplace_indices_.Set(gv, indices);
}
}
for (const auto& [old_gvar, func] : primitive_relax) {
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar);

GlobalVar new_gvar(old_gvar->name_hint);
UpdateStructInfo(new_gvar,
FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)));

mod->Remove(old_gvar);
updates->Add(new_gvar, prim_func);
replacements[old_gvar] = Replacement{new_gvar, func, indices};
}

TIRFuseMutator mutator(replacements);

// Step 2. Update all non-primitive relax functions and add them, with the dependent functions,
// into the new IRModule

for (const auto& [gv, func] : mod->functions) {
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
if (func->IsInstance<relax::FunctionNode>()) {
ICHECK(!func->HasNonzeroAttr(attr::kPrimitive))
<< "Module should not contain any primitive relax functions at this point";
relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
mutator.builder_->AddFunction(update_func, gv->name_hint);
}
}

// Step 3. Add all functions that need to be kept.
auto modified_mod = mutator.builder_->GetContextIRModule();
for (const auto& [gv, func] : funcs_to_keep) {
if (!modified_mod->ContainGlobalVar(gv->name_hint)) {
modified_mod->Add(gv, func);
if (!update_func.same_as(func)) {
updates->Add(gv, update_func);
}
}
}

// Step 4. Copy over module attributes and return.
if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
return modified_mod;
// Step 4. Copy over updated functions and return.
mod->Update(updates);
return mod;
}

private:
explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}
struct Replacement {
GlobalVar fused_tir_gvar;
Function original_function;
Array<Integer> inplace_indices;
};

explicit TIRFuseMutator(
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements)
: replacements_(replacements) {}

using ExprMutator::VisitExpr_;

Expand All @@ -1035,92 +1051,86 @@ class TIRFuseMutator : public ExprMutator {

Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));

if (call->op->IsInstance<GlobalVarNode>()) {
// Case 1. It is a relax cross function call
GlobalVar old_gv = Downcast<GlobalVar>(call->op);
auto relax_func = Downcast<Function>(mod_->Lookup(old_gv));
auto it = fused_tir_funcs_.find(old_gv);
if (it != fused_tir_funcs_.end()) {
const tir::PrimFunc& fused_tir = (*it).second;
// Case 1.1. It calls a primitive relax function, update the call into a call_tir
GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
// Step a. Flatten all args since call_tir does not support Tuple value.
Array<Expr> arg_list;
Array<PrimExpr> tir_vars;
for (size_t i = 0; i < call->args.size(); ++i) {
auto arg = call->args[i];
auto sinfo = GetStructInfo(arg);

ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
!sinfo.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, argument " << arg << " with struct info " << arg->struct_info_
<< " is passed as argument " << i << " to Primitive Relax function " << old_gv
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
<< relax_func->params[i]->struct_info_;

if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined())
<< "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
"arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

} else {
arg_list.push_back(arg);
}
}
// Step b. Create call_tir or call_tir_inplace
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
if (!tir_vars.empty()) {
call_args.push_back(ShapeExpr(tir_vars));
}
Op call_op = call_tir_op_;
Attrs call_attrs = call->attrs;
if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) {
call_op = call_tir_inplace_op_;
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
inplace_attrs->inplace_indices = (*it).second;
call_attrs = Attrs(inplace_attrs);
auto opt_gvar = call->op.as<GlobalVar>();
if (!opt_gvar) {
// Case 1. The Call isn't a relax-to-relax function call, no need to update.
return call;
}
GlobalVar old_gvar = opt_gvar.value();

auto it = replacements_.find(old_gvar);
if (it == replacements_.end()) {
// Case 2. The callee function is not a primitive relax
// function, no need to update.
return call;
}
const Replacement& replacement = it->second;
const GlobalVar& fused_tir_gv = replacement.fused_tir_gvar;
const Function& relax_func = replacement.original_function;

// Case 3. It calls a primitive relax function, update the call
// into a call_tir or call_tir_inplace.

// Step a. Collect all relax/symbolic arguments. Tuple arguments
// are not supported by PrimFunc, so this step verifies that
// ExpandTupleArguments has already removed them.
Array<Expr> arg_list;
Array<PrimExpr> tir_vars;
for (size_t i = 0; i < call->args.size(); ++i) {
auto arg = call->args[i];
auto sinfo = GetStructInfo(arg);

ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
!sinfo.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, argument " << arg << " with struct info " << arg->struct_info_
<< " is passed as argument " << i << " to Primitive Relax function " << old_gvar
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
<< relax_func->params[i]->struct_info_;

if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined()) << "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
"arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

} else {
// Case 1.2. The callee function is not primitive, nothing to do.
return call;
}
} else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
// Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
Array<Expr> new_args = call->args;
new_args.Set(0, new_gv);
return Call(call->op, new_args, call->attrs, call->sinfo_args, call->span);
arg_list.push_back(arg);
}
}

// Case 3. CallNode in other types. Leave it as it is.
return call;
// Step b. Create call_tir or call_tir_inplace
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
if (!tir_vars.empty()) {
call_args.push_back(ShapeExpr(tir_vars));
}
Op call_op = call_tir_op_;
Attrs call_attrs = call->attrs;
if (replacement.inplace_indices.size()) {
call_op = call_tir_inplace_op_;
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
inplace_attrs->inplace_indices = replacement.inplace_indices;
call_attrs = Attrs(inplace_attrs);
}
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
}

private:
/*! \brief The IRModule */
const IRModule& mod_;
/*! \brief The map from global var of primitive relax function to generated prim func. */
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
/*! \brief The map from global var of primitive relax function to in-place indices
* (if there are any). */
Map<GlobalVar, Array<Integer>> inplace_indices_;
/*! \brief The map from global var to how it should be replaced
*
* Has one entry for each primitive relax function in the IRModule.
*/
std::unordered_map<GlobalVar, Replacement, ObjectPtrHash, ObjectPtrEqual> replacements_;
};

IRModule FuseTIR(IRModule mod) {
Expand All @@ -1142,6 +1152,7 @@ Pass FuseTIR() {
ExpandTupleArguments(),
RemoveUnusedParameters(),
inner_pass,
DeadCodeElimination(),
},
"FuseTIR");
}
Expand Down
60 changes: 60 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2254,5 +2254,65 @@ def main(
_check(Module, Expected)


def test_private_nonprimitive_func():
"""Input IRModule may contain calls to non-primitive functions

This is a regression test. Prior implementations did not preserve
relax-to-relax function calls.

The module below has a public `main` that calls the private relax
function `fused_func`. `fused_func` carries no "Primitive" attribute
and reaches the private PrimFuncs `add` and `take` through
`R.call_tir`. The same module is handed to `_check` as both
arguments, so the test expects it to come through unchanged.
"""

@I.ir_module
class Before:
# Public entry point: its only binding is a relax-to-relax call
# into `fused_func`.
@R.function
def main(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
gv = cls.fused_func(input_ids, input_embeds)
R.output(gv)
return gv

# Private relax function with no "Primitive" attribute. It is the
# only caller of the `add` and `take` PrimFuncs in this module.
@R.function(private=True)
def fused_func(
input_ids: R.Tensor((1,), dtype="int32"),
input_embeds: R.Tensor((4096, 4096), dtype="float16"),
) -> R.Tensor((1, 4096), dtype="float16"):
cls = Before
with R.dataflow():
lv = R.call_tir(
cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16")
)
gv = R.call_tir(
cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16")
)
R.output(gv)
return gv

# Private PrimFunc: writes A + 1.0 element-wise into Out.
@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
):
for i, j in T.grid(T.int64(4096), T.int64(4096)):
with T.block("add"):
vi, vj = T.axis.remap("SS", [i, j])
Out[vi, vj] = A[vi, vj] + T.float16(1.0)

# Private PrimFunc: copies the row of A selected by each entry of B
# into T_take.
@T.prim_func(private=True)
def take(
A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"),
B: T.Buffer((T.int64(1),), "int32"),
T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"),
):
for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]

# Input and expected output are the same module.
# NOTE(review): presumably `_check` runs FuseTIR on the first argument
# and compares structurally against the second -- confirm against the
# helper defined earlier in this test file.
_check(Before, Before)


if __name__ == "__main__":
tvm.testing.main()
Loading