[Unity][Transform] Extract partial-tuple-usage from FuseTIR (#16120)
* [Unity][Transform] Extract partial-tuple-usage from FuseTIR

Prior to this commit, the `FuseTIR` pass explicitly tracked usage of
tuple arguments, to minimize the set of arguments provided to each
kernel.  The additional tracking and handling of partially-used
tuples makes it difficult to follow the primary changes being made by
`FuseTIR`.

This commit implements the same functionality in terms of the
`ExpandTupleArguments` and `RemoveUnusedParameters` transforms,
introduced in #16115 and
#16116 respectively.  By using these
passes before the main `FuseTIR` changes, partial tuple usage is
already handled at that point.
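
As a concrete illustration (a minimal sketch, not part of this commit,
mirroring the pattern already used in
tests/python/relax/test_transform_fuse_tir.py): a primitive function that
receives a two-element tuple but reads only field 0.  Running
`ExpandTupleArguments` followed by `RemoveUnusedParameters` flattens the
tuple parameter and drops the unread field, so the later `FuseTIR` logic
never sees a partially-used tuple.  The pass constructors are assumed to be
exposed under `relax.transform` as introduced in #16115/#16116; the exact
printed output depends on the TVM build.

```python
# Sketch only: assumes a TVM build that includes ExpandTupleArguments (#16115)
# and RemoveUnusedParameters (#16116).
from tvm import relax, topi
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
# A fused primitive function that reads only field 0 of its tuple parameter.
with bb.function("fused_exp", [x], attrs={"Primitive": True}, private=True):
    with bb.dataflow():
        lv0 = bb.emit(relax.TupleGetItem(x, 0))  # field 1 is never accessed
        lv1 = bb.emit_te(topi.exp, lv0)
        gv = bb.emit_output(lv1)
    bb.emit_func_output(gv)
mod = bb.get()

# After these two passes, "fused_exp" should take only the tensor it actually
# uses, so FuseTIR needs no partial-tuple-usage bookkeeping of its own.
mod = relax.transform.ExpandTupleArguments()(mod)
mod = relax.transform.RemoveUnusedParameters()(mod)
mod.show()
```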

This commit is intended to minimize any changes to user-facing
behavior, and so these pre-process passes are currently used
internally by `FuseTIR`.  This may be avoided in the future by pulling
this internal delegation out into a lowering pipeline.
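
For reference, the updated C++ `FuseTIR()` factory (see the change near the
bottom of src/relax/transform/fuse_tir.cc below) is roughly equivalent to the
following Python-level composition.  This is a sketch only: `fuse_tir_inner`
is a placeholder for the internal "FuseTIRInner" module pass, which is not
registered as a standalone Python API.

```python
import tvm
from tvm import relax

# Rough sketch of the composition performed inside the C++ FuseTIR() factory.
def fuse_tir_pipeline(fuse_tir_inner):
    return tvm.transform.Sequential(
        [
            relax.transform.ExpandTupleArguments(),
            relax.transform.RemoveUnusedParameters(),
            fuse_tir_inner,  # placeholder for the internal FuseTIRInner pass
        ],
        name="FuseTIR",
    )
```

Calling `relax.transform.FuseTIR()` from Python returns this composite pass,
so existing user code keeps working unchanged.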

* Updated based on review comments

* ci bump
Lunderberg authored Jan 3, 2024
1 parent 1af82ad commit ec542da
Showing 2 changed files with 113 additions and 162 deletions.
261 changes: 106 additions & 155 deletions src/relax/transform/fuse_tir.cc
@@ -385,58 +385,47 @@ class FusedTIRConstructor : public ExprVisitor {
: mod_(mod), func_name_(func_name) {}

void VisitExpr_(const FunctionNode* func) final {
// Step 1. Create buffers for function params

// Record which fields in a tuple passed as a parameter are actually accessed by the function.
std::unordered_set<const Object*> tuple_param;
for (auto param : func->params) {
if (GetStructInfo(param)->IsInstance<TupleStructInfoNode>()) {
tuple_param.insert(param.get());
}
}

PostOrderVisit(func->body, [=, &tuple_param](Expr e) {
if (auto tup_get = e.as<TupleGetItemNode>();
tup_get && tuple_param.count(tup_get->tuple.get())) {
func_info_.used_tuple_field_indices[tup_get->tuple.get()].insert(tup_get->index);
}
});

std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
for (const Var& relax_param : func->params) {
auto sinfo = GetStructInfo(relax_param);
if (sinfo->IsInstance<ShapeStructInfoNode>()) {
// It's a symbolic shape var, no need to alloc Buffers.
continue;
}

auto [params, buffers] = [=]() {
if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
// Add only those tuple fields which are actually used by the function body into the
// function parameters.
int index = 0;
Array<tir::Var> params;
Array<tir::Buffer> buffers;
for (auto i : func_info_.used_tuple_field_indices[relax_param.get()]) {
auto [ret_params, ret_buffers] =
CreateParamsAndBuffers(tuple->fields[i], relax_param->name_hint(), index);
ICHECK_EQ(ret_params.size(), ret_buffers.size());
// Adding tuple field results to the end of params and buffers.
params.insert(params.end(), ret_params.begin(), ret_params.end());
buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
index += ret_params.size();
size_t size_before = prim_func_params.size();
CollectPrimFuncParams(relax_param, &prim_func_params);

auto param_buffers = [&]() -> Array<tir::Buffer> {
Array<tir::Buffer> out;
for (size_t i = size_before; i < prim_func_params.size(); i++) {
if (auto buf = prim_func_params[i].as<tir::Buffer>()) {
out.push_back(buf.value());
}
return std::make_pair(params, buffers);
} else {
return CreateParamsAndBuffers(sinfo, relax_param->name_hint());
}
return out;
}();

ICHECK_EQ(params.size(), buffers.size());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.buffer_map.Set(params[i], buffers[i]);
func_info_.params.push_back(params[i]);
func_info_.expr2buffers.Set(relax_param, param_buffers);
}

// Move all scalar params after buffer params. To ensure that the
// order is deterministic and predictable for testing purposes,
// std::stable_sort is used instead of std::sort.
std::stable_sort(prim_func_params.begin(), prim_func_params.end(),
[](const auto& a, const auto& b) {
bool a_is_var = a.template as<tir::VarNode>();
bool b_is_var = b.template as<tir::VarNode>();
return a_is_var < b_is_var;
});

for (const auto& param : prim_func_params) {
if (auto opt = param.as<tir::Buffer>()) {
auto buffer = opt.value();
// Differentiate buffer name and param name by adding prefix
// `p_` to the buffer name. Every symbol should be unique in
// TVMScript, and while they can be de-duplicated when
// printed, it's more readable when done explicitly. Since
// Buffer is used more than param it gets the name with better
// readability.
tir::Var param = tir::Var("p_" + buffer->name, PrimType(DataType::Handle()));
func_info_.params.push_back(param);
func_info_.buffer_map.Set(param, buffer);
}
func_info_.expr2buffers.Set(relax_param, buffers);
}

// Step 2. Visit Function body and create intermediate buffers
@@ -458,13 +447,9 @@ class FusedTIRConstructor : public ExprVisitor {
}

// Step 4. Append symbolic vars
const relax::Var& last_relax_param = func->params.back();
if (GetStructInfo(last_relax_param)->IsInstance<ShapeStructInfoNode>()) {
auto [params, buffers] =
CreateParamsAndBuffers(GetStructInfo(last_relax_param), last_relax_param->name_hint());
ICHECK(buffers.empty());
for (size_t i = 0; i < params.size(); ++i) {
func_info_.params.push_back(params[i]);
for (const auto& param : prim_func_params) {
if (auto var = param.as<tir::Var>()) {
func_info_.params.push_back(var.value());
}
}

@@ -548,12 +533,7 @@ class FusedTIRConstructor : public ExprVisitor {
int end_buf_idx = 0;
const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
for (int i = 0; i < tuple_get_item->index; ++i) {
auto it = func_info_.used_tuple_field_indices.find(tuple_get_item->tuple.get());
// If this tuple is not passed as a parameter, or if the field at the index i is actually
// used, the corresponding buffer needs to be taken into account by this function.
if (it == func_info_.used_tuple_field_indices.end() || it->second.count(i)) {
begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
}
begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
}
end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
func_info_.expr2buffers.Set(
@@ -719,64 +699,47 @@ class FusedTIRConstructor : public ExprVisitor {
}

/*!
* \brief Create an TIR func params and buffers with specified relax type and shape
* \brief Collect TIR func params and buffers with specified relax type and shape
* \param struct_info The struct info
* \param name_hint The name hint for params and buffers
* \param index The index used for unique name_hint if type is Tuple.
* -1 means no need to add postfix since the relax param is not a Tuple.
* \return The created TIR func params and buffers
* \param out The vector into which to collect the params/buffers
*/
static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
StructInfo struct_info, const String& name_hint, int index = -1) {
Array<tir::Var> params;
Array<tir::Buffer> buffers;
// The symbolic shape params must be defined at the end of the param list.
bool symbolic_shape_param_started = false;
static void CollectPrimFuncParams(const Var& relax_param,
std::vector<Variant<tir::Var, tir::Buffer>>* out) {
auto struct_info = GetStructInfo(relax_param);

CHECK(!struct_info.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, parameter " << relax_param << " has struct info " << struct_info;

auto name_hint = relax_param->name_hint();

if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
// Case 1. the relax param is a Tensor, we directly create a tir var and buffer
// Case 1. The relax param is a Tensor, we directly create a tir var and buffer
const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
CHECK(!symbolic_shape_param_started)
<< "The symbolic shape params must be defined at the end of the param "
"list.";
String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index);
ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
DataType dtype = tensor->dtype;
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
// Differentiate buffer name and param name by adding prefix `v_` to param
// Every symbol should be unique in TVMScript, and Buffer is used more than param
// So we decide to make sure buffer names have better readability.
tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
params.push_back(std::move(param));
buffers.push_back(std::move(buffer));
} else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
// Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
// Enable postfix
CHECK(!symbolic_shape_param_started)
<< "The symbolic shape params must be defined at the end of the param "
"list.";
if (index == -1) index = 0;
for (size_t i = 0; i < tuple->fields.size(); ++i) {
auto [ret_params, ret_buffers] = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
ICHECK_EQ(ret_params.size(), ret_buffers.size());
// Adding tuple field results to the end of params and buffers.
params.insert(params.end(), ret_params.begin(), ret_params.end());
buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
index += ret_params.size();
}
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
out->push_back(std::move(buffer));

} else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
// Case 2. The relax param is a scalar, we directly create a tir var
ICHECK(prim_value->value->IsInstance<tir::VarNode>());
out->push_back(Downcast<tir::Var>(prim_value->value));

} else if (const auto* shape_expr = struct_info.as<ShapeStructInfoNode>()) {
// Case 3. the relax param is a scalar, we directly create a tir var
symbolic_shape_param_started = true;
ICHECK(index == -1) << "TypeError: The ShapeExprNode should not be in a Tuple field.";
// Case 3. The relax param is a tuple of scalars, each represented as a tir var
for (const auto& var : shape_expr->values.value()) {
ICHECK(var->IsInstance<tir::VarNode>());
params.push_back(Downcast<tir::Var>(var));
out->push_back(Downcast<tir::Var>(var));
}
} else {
ICHECK(false) << "TypeError: The param type of PrimFunc is expected to be Tensor, Tuple or "
"ShapeExpr, but got "
<< struct_info->GetTypeKey();
LOG(FATAL) << "TypeError: "
<< "The param type of PrimFunc is expected to be "
<< "Tensor, PrimValue, or ShapeExpr, "
<< "but got " << struct_info->GetTypeKey();
}
return std::make_pair(params, buffers);
}

/*!
@@ -870,9 +833,6 @@ class FusedTIRConstructor : public ExprVisitor {
/*! \brief The map from symbolic var to its corresponding var in the fused function */
tir::SymbolicMatcher symbolic_var_matcher =
tir::SymbolicMatcher(&analyzer, &symbolic_var_remap);

/*! \brief Record indices of tuple fields that are actually accessed. */
std::unordered_map<const Object*, std::unordered_set<size_t>> used_tuple_field_indices;
};

/*! \brief The IRModule */
@@ -987,34 +947,35 @@ class TIRFuseMutator : public ExprMutator {
Array<PrimExpr> tir_vars;
for (size_t i = 0; i < call->args.size(); ++i) {
auto arg = call->args[i];
Array<Expr> flattened;
if (GetStructInfo(relax_func->params[i])->IsInstance<TupleStructInfoNode>()) {
// Add only those tuple fields which are actually used by the function body
auto tup_get_indices = GetTupleAccessedIndices(relax_func.get(), relax_func->params[i]);
for (size_t tup_get_ind : tup_get_indices) {
auto flattened_inner = FlattenArg(builder_->Emit(TupleGetItem(arg, tup_get_ind)));
flattened.insert(flattened.end(), flattened_inner.begin(), flattened_inner.end());
auto sinfo = GetStructInfo(arg);

ICHECK(!relax_func->params[i]->struct_info_->IsInstance<TupleStructInfoNode>() &&
!sinfo.as<TupleStructInfoNode>())
<< "InternalError: "
<< "All tuple parameters should be expanded before this point in FuseTIR. "
<< "However, argument " << arg << " with struct info " << arg->struct_info_
<< " is passed as argument " << i << " to Primitive Relax function " << old_gv
<< ", which expects parameter " << relax_func->params[i] << " to have struct info "
<< relax_func->params[i]->struct_info_;

if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined())
<< "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
} else {
flattened.push_back(arg);
}
} else if (const auto* prim_value = sinfo.as<PrimStructInfoNode>()) {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>())
<< "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

for (const Expr& e : flattened) {
StructInfo sinfo = GetStructInfo(e);
if (sinfo->IsInstance<TensorStructInfoNode>()) {
arg_list.push_back(e);
} else if (const auto* shape = sinfo.as<ShapeStructInfoNode>()) {
CHECK(shape->values.defined())
<< "FuseTIR requires all shape input has struct_info value.";
for (const PrimExpr& prim_value : shape->values.value()) {
CHECK(prim_value->IsInstance<tir::VarNode>())
<< "All shape inputs are expected to be single tir var.";
tir_vars.push_back(prim_value);
}
} else {
LOG(FATAL) << "The flattened arg is expected to be either tensor or shape, but got "
<< sinfo->GetTypeKey();
}
} else {
arg_list.push_back(arg);
}
}
// Step b. Create call_tir
@@ -1042,23 +1003,6 @@
return call;
}

/********** Helper Functions **********/

/*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
Array<Expr> FlattenArg(const Expr& arg) {
if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
Array<Expr> arg_list;
for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
Array<Expr> flattened = FlattenArg(new_arg);
arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
}
return arg_list;
} else {
return {arg};
}
}

private:
/*! \brief The IRModule */
const IRModule& mod_;
Expand All @@ -1076,10 +1020,17 @@ namespace transform {
Pass FuseTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseTIR", //
/*required=*/{});
auto inner_pass = CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseTIRInner", //
/*required=*/{});
return tvm::transform::Sequential(
{
ExpandTupleArguments(),
RemoveUnusedParameters(),
inner_pass,
},
"FuseTIR");
}

TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
14 changes: 7 additions & 7 deletions tests/python/relax/test_transform_fuse_tir.py
@@ -205,7 +205,7 @@ def fused_exp_squeeze(x):
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_squeeze, x)
lv2 = bb.emit_te(fused_exp_squeeze, lv)
lv2 = bb.call_te(fused_exp_squeeze, lv)
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
@@ -245,7 +245,7 @@ def fused_exp_exp_squeeze(x):
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_exp_squeeze, x)
lv = bb.call_te(fused_exp_exp_squeeze, x)
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
@@ -257,7 +257,7 @@ def test_fuse_with_tuple_as_param():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
with bb.function("fused_exp_add", [x], attrs={"Primitive": True}):
with bb.function("fused_exp_add", [x], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv1 = bb.emit(relax.TupleGetItem(x, 1))
@@ -300,7 +300,7 @@ def test_fuse_with_nested_tuple_as_param():
def before():
bb = relax.BlockBuilder()
x = relax.Var("x", tuple_struct_info)
with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}):
with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}, private=True):
with bb.dataflow():
lv0 = bb.emit(relax.TupleGetItem(x, 0))
lv0_exp = bb.emit_te(topi.exp, lv0)
@@ -373,7 +373,7 @@ def fused_exp_squeeze(x):
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_exp_squeeze, x)
lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32"))
lv2 = bb.call_te(topi.add, lv, relax.const(1, "float32"))
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
return bb.get()
@@ -414,7 +414,7 @@ def fused_add_exp_squeeze(x, y):
x = relax.Var("x", R.Tensor([10, 20], "float32"))
with bb.function("main", [x]):
with bb.dataflow():
lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
lv = bb.call_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
gv = bb.emit_output(lv)
bb.emit_func_output(gv)
return bb.get()
@@ -1268,7 +1268,7 @@ def reshape(
(v_ax2 * T.int64(64) + v_ax3) % T.int64(2048),
]

@R.function
@R.function(private=True)
def fused_reshape(
lv: R.Tuple(
R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")
