Skip to content

Commit

Permalink
[Unity][Transform] LiftTransformParams handling multiple functions (#…
Browse files Browse the repository at this point in the history
…14192)

Previously, the LiftTransformParams pass only worked on the function
named `"main"`. This is restrictive: in our recent work on Stable
Diffusion, there are cases where multiple Relax functions inside an
IRModule all need to be transformed.

Therefore, this PR enhances the LiftTransformParams pass, so that it
will now transform **all** functions **with attribute `num_input`**. For
functions without this attribute, the pass will simply skip them.
  • Loading branch information
MasterJH5574 authored and tqchen committed Apr 1, 2023
1 parent 279317d commit 3f66edc
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 18 deletions.
37 changes: 22 additions & 15 deletions src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,35 +207,41 @@ class TransformParamsLifter : public ExprMutator {

IRModule Lift() {
auto mod = builder_->GetContextIRModule();
GlobalVar gv_main = mod->GetGlobalVar("main");
Function func = Downcast<Function>(mod->Lookup(gv_main));
func = RewriteFunc(func);
builder_->UpdateFunction(gv_main, func);
for (const auto& [gv, base_func] : mod->functions) {
// Skip non-Relax functions.
const auto* func_ = base_func.as<FunctionNode>();
if (func_ == nullptr) {
continue;
}
// Skip functions that do not have the `num_input` attribute.
Optional<Integer> opt_num_input = func_->attrs.GetAttr<Integer>(attr_num_input_);
if (!opt_num_input.defined()) {
continue;
}
Function func = RewriteFunc(GetRef<Function>(func_), opt_num_input.value()->value,
gv->name_hint + "_transform_params");
builder_->UpdateFunction(gv, func);
}

return builder_->GetContextIRModule();
}

private:
Function RewriteFunc(const Function& func) {
const std::string attr_num_input = "num_input";
auto opt_num_input = func->attrs.GetAttr<Integer>(attr_num_input);
if (!opt_num_input.defined()) {
return func;
}
Function RewriteFunc(const Function& func, int num_input, String new_func_name) {
LiftTransformParamsPlanner planner;
int64_t params_begin = opt_num_input.value()->value;

// Step 1: Create the plan of lifting transform params
lift_plan_ = planner.Plan(func, params_begin);
lift_plan_ = planner.Plan(func, num_input);

// Step 2: Add the lifted function to the module
builder_->AddFunction(lift_plan_.f_transform_params, "transform_params");
builder_->AddFunction(lift_plan_.f_transform_params, new_func_name);

// Step 3: Update the current function.

// Step 3.1: Update the function signature
Var params("params", lift_plan_.f_transform_params->ret_struct_info);
Array<Var> new_params;
for (int i = 0; i < params_begin; ++i) {
for (int i = 0; i < num_input; ++i) {
new_params.push_back(func->params[i]);
}
new_params.push_back(params);
Expand All @@ -249,7 +255,7 @@ class TransformParamsLifter : public ExprMutator {
// Step 3.3: Remove function attributes that are not needed
auto new_attrs = func->attrs;
auto* new_attrs_node = new_attrs.CopyOnWrite();
new_attrs_node->dict.erase(attr_num_input);
new_attrs_node->dict.erase(attr_num_input_);
if (new_attrs->dict.empty()) {
new_attrs = NullValue<DictAttrs>();
}
Expand Down Expand Up @@ -277,6 +283,7 @@ class TransformParamsLifter : public ExprMutator {
return VisitExpr_(static_cast<const VarNode*>(var));
}

const char* attr_num_input_ = "num_input";
// Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters.
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
// The plan of lifting the transform params
Expand Down
105 changes: 102 additions & 3 deletions tests/python/relax/test_transform_lift_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def transform_layout_IOHW_to_OIHW(
out[o, i, h, w] = w1[i, o, h, w]

@R.function
def transform_params(
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
)
Expand Down Expand Up @@ -193,7 +193,7 @@ def main(
return conv2

@R.function
def transform_params(
def main_transform_params(
params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32"))
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
Expand Down Expand Up @@ -242,7 +242,7 @@ def main(
@tvm.script.ir_module
class Expected:
@R.function
def transform_params(
def main_transform_params(
params: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
Expand Down Expand Up @@ -291,5 +291,104 @@ def main(
tvm.ir.assert_structural_equal(after, Expected)


def test_multiple_functions():
    """LiftTransformParams must lift weight-transformation code out of every
    Relax function that carries the ``num_input`` attribute, and must leave
    functions without that attribute untouched."""

    @tvm.script.ir_module
    class Before:
        # `num_input == 1`: only `x` is a runtime input; `w1` is a parameter
        # whose transformation (the transpose) is eligible for lifting.
        @R.function
        def func1(
            x: R.Tensor((256, 256), "float32"),
            w1: R.Tensor((256, 256), "float32"),
        ) -> R.Tensor((256, 256), "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                w1_t = R.permute_dims(w1, [1, 0])
                y = R.matmul(x, w1_t)
                R.output(y)
            return y

        # Same shape of computation as func1 but with a different weight
        # shape, so its lifted transform function must differ accordingly.
        @R.function
        def func2(
            x: R.Tensor((256, 256), "float32"),
            w1: R.Tensor((128, 256), "float32"),
        ) -> R.Tensor((256, 128), "float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                w1_t = R.permute_dims(w1, [1, 0])
                y = R.matmul(x, w1_t)
                R.output(y)
            return y

        # No `num_input` attribute: the pass is expected to skip this
        # function entirely.
        @R.function
        def func3(
            x: R.Tensor((256, 256), "float32"),
            w1: R.Tensor((256, 256), "float32"),
        ) -> R.Tensor((256, 256), "float32"):
            with R.dataflow():
                w1_t = R.permute_dims(w1, [1, 0])
                y = R.matmul(x, w1_t)
                R.output(y)
            return y

    @tvm.script.ir_module
    class Expected:
        # func1's weight parameter is replaced by a tuple of pre-transformed
        # params; the transpose no longer happens at inference time.
        @R.function
        def func1(
            x: R.Tensor((256, 256), dtype="float32"),
            params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
        ) -> R.Tensor((256, 256), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((256, 256), dtype="float32") = params[0]
                y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, lv, out_dtype="void")
                R.output(y)
            return y

        # Lifted transform for func1, named "<func>_transform_params" after
        # the source function (rather than a single global "transform_params").
        @R.function
        def func1_transform_params(
            params: R.Tuple(R.Tensor((256, 256), dtype="float32"))
        ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((256, 256), dtype="float32") = params[0]
                lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
                gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

        @R.function
        def func2(
            x: R.Tensor((256, 256), dtype="float32"),
            params: R.Tuple(R.Tensor((256, 128), dtype="float32")),
        ) -> R.Tensor((256, 128), dtype="float32"):
            with R.dataflow():
                lv1: R.Tensor((256, 128), dtype="float32") = params[0]
                y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, lv1, out_dtype="void")
                R.output(y)
            return y

        # Lifted transform for func2: takes the (128, 256) weight and
        # produces the transposed (256, 128) tensor consumed by func2.
        @R.function
        def func2_transform_params(
            params: R.Tuple(R.Tensor((128, 256), dtype="float32"))
        ) -> R.Tuple(R.Tensor((256, 128), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((128, 256), dtype="float32") = params[0]
                lv1: R.Tensor((256, 128), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
                gv: R.Tuple(R.Tensor((256, 128), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

        # func3 had no `num_input` attribute, so it is structurally unchanged.
        @R.function
        def func3(
            x: R.Tensor((256, 256), dtype="float32"), w1: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tensor((256, 256), dtype="float32"):
            with R.dataflow():
                w1_t: R.Tensor((256, 256), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
                y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_t, out_dtype="void")
                R.output(y)
            return y

    mod = Before
    # Run the pass and check the whole module (all three functions plus the
    # two newly lifted transform functions) against the expected IR.
    after = relax.transform.LiftTransformParams()(mod)
    tvm.ir.assert_structural_equal(after, Expected)


# Allow running this test file directly (outside of pytest).
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit 3f66edc

Please sign in to comment.