Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Allow DeadCodeElimination within ApplyPassToFunction #16801

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions src/relax/transform/dead_code_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,22 @@ class CallTracer : public ExprVisitor {
explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {}

void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(GetRef<GlobalVar>(op));
auto func = mod_->Lookup(op->name_hint);
if (const auto* function_node = func.as<FunctionNode>()) {
VisitExpr(GetRef<Function>(function_node));
auto gvar = GetRef<GlobalVar>(op);
called_funcs_.insert(gvar);
if (auto func = mod_->functions.Get(gvar)) {
if (const auto* function_node = func.as<FunctionNode>()) {
VisitExpr(GetRef<Function>(function_node));
}
// else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
} else {
// The GlobalVar is not contained in the IRModule. While the
// input IRModule is ill-formed, this specific case is allowed
// for use with `relax.transform.ApplyPassToFunction`. If this
// occurs, DCE should not remove any internal functions from the
// IRModule, as their removal is only valid if we have a
// complete call graph.
all_callees_found_ = false;
}
// else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
}

void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); }
Expand All @@ -77,11 +87,24 @@ class CallTracer : public ExprVisitor {
VisitExpr(main_func);
}

bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; }
/* \brief Check if a function is unreachable
*
* \param gvar The function to be checked
*
* \return True if the function can be proven to be unreachable,
* either directly or indirectly, from an external caller.
* Otherwise, false.
*/
bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const {
return all_callees_found_ && !called_funcs_.count(gvar);
}

private:
IRModule mod_;

/* \brief Whether all callees could be located within the IRModule */
bool all_callees_found_{true};

// Record the names of all encountered functions.
std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;

Expand All @@ -101,7 +124,7 @@ IRModule RemoveUnusedFunctions(
// The tracer contains all user-provided entry functions, all
// externally-callable functions, and anything that is directly or
// indirectly accessible from an entry function.
if (!tracer.check_if_called(kv.first)) {
if (tracer.CheckIfProvablyUnreachable(kv.first)) {
to_remove.push_back(kv.first);
}
}
Expand Down
22 changes: 17 additions & 5 deletions tests/python/relax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ def pytest_configure(config):
"markers",
(
"skip_well_formed_check_before_transform: "
"Only check for well-formed IRModule after a transform"
"Suppress the default well-formed check before a IRModule transform"
),
)
config.addinivalue_line(
"markers",
(
"skip_well_formed_check_after_transform: "
"Suppress the default well-formed check after a IRModule transform"
),
)

Expand All @@ -54,15 +61,20 @@ def pytest_configure(config):
# `@pytest.mark.skip_well_formed_check_before_transform`
@pytest.fixture(autouse=True)
def apply_instrument_well_formed(unit_test_marks):

validate_before_transform = "skip_well_formed_check_before_transform" not in unit_test_marks
validate_after_transform = "skip_well_formed_check_after_transform" not in unit_test_marks

instrument = WellFormedInstrument(validate_before_transform=validate_before_transform)
current = tvm.transform.PassContext.current()
instruments = list(current.instruments)

if validate_before_transform or validate_after_transform:
instruments.append(
WellFormedInstrument(validate_before_transform=validate_before_transform)
)

override = tvm.transform.PassContext(
# Append the new instrument
instruments=[*current.instruments, instrument],
# With the new WellFormedInstrument appended
instruments=instruments,
# Forward all other parameters
opt_level=current.opt_level,
required_pass=current.required_pass,
Expand Down
155 changes: 155 additions & 0 deletions tests/python/relax/test_transform_dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
import tvm.testing
from tvm.relax.transform import DeadCodeElimination
Expand Down Expand Up @@ -507,5 +509,158 @@ def test_extern_func():
verify(before, before)


@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_compatibility_with_apply_pass_to_function():
"""DeadCodeElimination can be used with ApplyPassToFunction

The `ApplyPassToFunction` utility calls another transform, where
only the specified functions are exposed to the internal
transform. This intermediate does not contain `cls.subroutine`,
and so the intermediate is ill-formed.

In general, IRModule transformations may assume that their inputs
are well-formed. In specific cases, IRModule transformations may
accept IRModules that are ill-formed. The `DeadCodeElimination`
transform allows IRModule arguments that are ill-formed due to
a dangling GlobalVar.

After `DeadCodeElimination` completes, the resulting function is
inserted in the original IRModule, providing a well-formed output
from `ApplyPassToFunction`.

"""

@I.ir_module
class Before:
@R.function
def to_be_transformed(A: R.Tensor):
cls = Before

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function
def to_be_ignored(A: R.Tensor):
cls = Before

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subroutine(arg: R.Tensor) -> R.Tensor:
return R.add(arg, arg)

@I.ir_module
class Expected:
@R.function
def to_be_transformed(A: R.Tensor):
cls = Expected

B = R.add(A, A)
C = cls.subroutine(B)
return C

@R.function
def to_be_ignored(A: R.Tensor):
cls = Expected

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subroutine(arg: R.Tensor) -> R.Tensor:
return R.add(arg, arg)

# The well-formed check in conftest.py must be disabled, to avoid
# triggering on the ill-formed intermediate, so this unit test
# checks it explicitly.
assert tvm.relax.analysis.well_formed(Before)
After = tvm.ir.transform.ApplyPassToFunction(
tvm.relax.transform.DeadCodeElimination(),
"to_be_transformed",
)(Before)
assert tvm.relax.analysis.well_formed(After)
tvm.ir.assert_structural_equal(Expected, After)


@pytest.mark.skip_well_formed_check_before_transform
@pytest.mark.skip_well_formed_check_after_transform
def test_well_formed_output_with_restricted_scope():
"""DeadCodeElimination can be used with ApplyPassToFunction

If the call graph cannot be completely traced, private functions
should not be removed.

See `test_compatibility_with_apply_pass_to_function` for full
description of `DeadCodeElimination` and `ApplyPassToFunction`.

"""

@I.ir_module
class Before:
@R.function
def main(A: R.Tensor):
cls = Before

B = R.add(A, A)
C = cls.subroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subroutine(A: R.Tensor) -> R.Tensor:
cls = Before

B = R.add(A, A)
C = cls.subsubroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subsubroutine(A: R.Tensor) -> R.Tensor:
B = R.add(A, A)
C = R.multiply(B, B)
return B

@I.ir_module
class Expected:
@R.function
def main(A: R.Tensor):
cls = Expected

B = R.add(A, A)
C = cls.subroutine(B)
return C

@R.function(private=True)
def subroutine(A: R.Tensor) -> R.Tensor:
cls = Expected

B = R.add(A, A)
C = cls.subsubroutine(B)
D = R.multiply(C, C)
return C

@R.function(private=True)
def subsubroutine(A: R.Tensor) -> R.Tensor:
B = R.add(A, A)
return B

assert tvm.relax.analysis.well_formed(Before)
After = tvm.ir.transform.ApplyPassToFunction(
tvm.relax.transform.DeadCodeElimination(),
"main|subsubroutine",
)(Before)
assert tvm.relax.analysis.well_formed(After)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()
Loading