Skip to content

Commit

Permalink
[Relax] Allow DeadCodeElimination within ApplyPassToFunction (#16801)
Browse files Browse the repository at this point in the history
The `tvm.ir.transform.ApplyPassToFunction` allows a transform to be
applied selectively to some portions of a `IRModule`, without applying
to the entire `IRModule`.  For example, to apply an optimization
pass (e.g. `relax.transform.ExpandMatmulOfSum`) or an
interface-altering pass (e.g. `relax.transform.BundleModelParams`) to
specific functions.  It does so by generating an intermediate
`IRModule` containing only the functions specified, applying the
transform to that intermediate, then merging the results.

When using `ApplyPassToFunction` to apply `DeadCodeElimination`, or a
pipeline containing `DeadCodeElimination`, this intermediate
`IRModule` may contain calls to `GlobalVar` instances that are not
within the intermediate `IRModule`.  Prior to this commit, this
resulted in an error being thrown when collecting the call graph.
This commit updates `DeadCodeElimination` to instead handle incomplete
call-graph collection.
  • Loading branch information
Lunderberg authored Apr 3, 2024
1 parent 35c6143 commit 545e097
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 12 deletions.
37 changes: 30 additions & 7 deletions src/relax/transform/dead_code_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,22 @@ class CallTracer : public ExprVisitor {
explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {}

void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(GetRef<GlobalVar>(op));
auto func = mod_->Lookup(op->name_hint);
if (const auto* function_node = func.as<FunctionNode>()) {
VisitExpr(GetRef<Function>(function_node));
auto gvar = GetRef<GlobalVar>(op);
called_funcs_.insert(gvar);
if (auto func = mod_->functions.Get(gvar)) {
if (const auto* function_node = func.as<FunctionNode>()) {
VisitExpr(GetRef<Function>(function_node));
}
// else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
} else {
// The GlobalVar is not contained in the IRModule. While the
// input IRModule is ill-formed, this specific case is allowed
// for use with `relax.transform.ApplyPassToFunction`. If this
// occurs, DCE should not remove any internal functions from the
// IRModule, as their removal is only valid if we have a
// complete call graph.
all_callees_found_ = false;
}
// else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
}

void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); }
Expand All @@ -77,11 +87,24 @@ class CallTracer : public ExprVisitor {
VisitExpr(main_func);
}

bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; }
/* \brief Check if a function is unreachable
*
* \param gvar The function to be checked
*
* \return True if the function can be proven to be unreachable,
* either directly or indirectly, from an external caller.
* Otherwise, false.
*/
bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const {
return all_callees_found_ && !called_funcs_.count(gvar);
}

private:
IRModule mod_;

/* \brief Whether all callees could be located within the IRModule */
bool all_callees_found_{true};

// Record the names of all encountered functions.
std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;

Expand All @@ -101,7 +124,7 @@ IRModule RemoveUnusedFunctions(
// The tracer contains all user-provided entry functions, all
// externally-callable functions, and anything that is directly or
// indirectly accessible from an entry function.
if (!tracer.check_if_called(kv.first)) {
if (tracer.CheckIfProvablyUnreachable(kv.first)) {
to_remove.push_back(kv.first);
}
}
Expand Down
22 changes: 17 additions & 5 deletions tests/python/relax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ def pytest_configure(config):
"markers",
(
"skip_well_formed_check_before_transform: "
"Only check for well-formed IRModule after a transform"
"Suppress the default well-formed check before a IRModule transform"
),
)
config.addinivalue_line(
"markers",
(
"skip_well_formed_check_after_transform: "
"Suppress the default well-formed check after a IRModule transform"
),
)

Expand All @@ -54,15 +61,20 @@ def pytest_configure(config):
# `@pytest.mark.skip_well_formed_check_before_transform`
@pytest.fixture(autouse=True)
def apply_instrument_well_formed(unit_test_marks):

validate_before_transform = "skip_well_formed_check_before_transform" not in unit_test_marks
validate_after_transform = "skip_well_formed_check_after_transform" not in unit_test_marks

instrument = WellFormedInstrument(validate_before_transform=validate_before_transform)
current = tvm.transform.PassContext.current()
instruments = list(current.instruments)

if validate_before_transform or validate_after_transform:
instruments.append(
WellFormedInstrument(validate_before_transform=validate_before_transform)
)

override = tvm.transform.PassContext(
# Append the new instrument
instruments=[*current.instruments, instrument],
# With the new WellFormedInstrument appended
instruments=instruments,
# Forward all other parameters
opt_level=current.opt_level,
required_pass=current.required_pass,
Expand Down
155 changes: 155 additions & 0 deletions tests/python/relax/test_transform_dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
import tvm.testing
from tvm.relax.transform import DeadCodeElimination
Expand Down Expand Up @@ -507,5 +509,158 @@ def test_extern_func():
verify(before, before)


@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_compatibility_with_apply_pass_to_function():
"""DeadCodeElimination can be used with ApplyPassToFunction
The `ApplyPassToFunction` utility calls another transform, where
only the specified functions are exposed to the internal
transform. This intermediate does not contain `cls.subroutine`,
and so the intermediate is ill-formed.
In general, IRModule transformations may assume that their inputs
are well-formed. In specific cases, IRModule transformations may
accept IRModules that are ill-formed. The `DeadCodeElimination`
transform allows IRModule arguments that are ill-formed due to
a dangling GlobalVar.
After `DeadCodeElimination` completes, the resulting function is
inserted in the original IRModule, providing a well-formed output
from `ApplyPassToFunction`.
"""

@I.ir_module
class Before:
@R.function
def to_be_transformed(A: R.Tensor):
cls = Before

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function
def to_be_ignored(A: R.Tensor):
cls = Before

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subroutine(arg: R.Tensor) -> R.Tensor:
return R.add(arg, arg)

@I.ir_module
class Expected:
@R.function
def to_be_transformed(A: R.Tensor):
cls = Expected

B = R.add(A, A)
C = cls.subroutine(B)
return C

@R.function
def to_be_ignored(A: R.Tensor):
cls = Expected

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subroutine(arg: R.Tensor) -> R.Tensor:
return R.add(arg, arg)

# The well-formed check in conftest.py must be disabled, to avoid
# triggering on the ill-formed intermediate, so this unit test
# checks it explicitly.
assert tvm.relax.analysis.well_formed(Before)
After = tvm.ir.transform.ApplyPassToFunction(
tvm.relax.transform.DeadCodeElimination(),
"to_be_transformed",
)(Before)
assert tvm.relax.analysis.well_formed(After)
tvm.ir.assert_structural_equal(Expected, After)


@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_well_formed_output_with_restricted_scope():
"""DeadCodeElimination can be used with ApplyPassToFunction
If the call graph cannot be completely traced, private functions
should not be removed.
See `test_compatibility_with_apply_pass_to_function` for full
description of `DeadCodeElimination` and `ApplyPassToFunction`.
"""

@I.ir_module
class Before:
@R.function
def main(A: R.Tensor):
cls = Before

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subroutine(A: R.Tensor) -> R.Tensor:
cls = Before

B = R.add(A, A)
C = cls.subsubroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subsubroutine(A: R.Tensor) -> R.Tensor:
B = R.add(A, A)
C = R.multiply(B, B)
return B

@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor):
cls = Expected

B = R.add(A, A)
C = cls.subroutine(B)
return C

@R.function(private=True)
def subroutine(A: R.Tensor) -> R.Tensor:
cls = Expected

B = R.add(A, A)
C = cls.subsubroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subsubroutine(A: R.Tensor) -> R.Tensor:
B = R.add(A, A)
return B

assert tvm.relax.analysis.well_formed(Before)
After = tvm.ir.transform.ApplyPassToFunction(
tvm.relax.transform.DeadCodeElimination(),
"main|subsubroutine",
)(Before)
assert tvm.relax.analysis.well_formed(After)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 545e097

Please sign in to comment.