From a1229f6fe4a533308e6f6c8e2c965e356b83f5bf Mon Sep 17 00:00:00 2001 From: Alexey Date: Wed, 1 Feb 2023 17:05:18 +0300 Subject: [PATCH] [TIR] Handle nullptr returned by FindEntryFunc (#13852) The FindEntryFunc function can return a null pointer. I ran into this case during development, where it showed up as a segfault. --- src/tir/analysis/stmt_finding.cc | 39 ++++++++++--------- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/utils.h | 10 +++-- .../unittest/test_meta_schedule_database.py | 34 ++++++++++++++-- 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 1d8cb462c14b..300f779ae951 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -111,30 +111,31 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) { std::vector blocks; }; - auto prim_func = FindEntryFunc(mod, nullptr); + if (auto prim_func = FindEntryFunc(mod, nullptr)) { + ReductionBlockCollector collector; + collector(prim_func->body); - ReductionBlockCollector collector; - collector(prim_func->body); + const auto& candidates = collector.blocks; - const auto& candidates = collector.blocks; - - if (candidates.empty()) { - return nullptr; - } else if (candidates.size() == 1) { - return candidates[0]; - } + if (candidates.empty()) { + return nullptr; + } else if (candidates.size() == 1) { + return candidates[0]; + } - double best_flops = -1; - int best_idx = 0; - for (size_t i = 0; i < candidates.size(); ++i) { - auto loop = GetEnclosingLoop(candidates[i], prim_func->body); - auto flops = EstimateTIRFlops(loop); - if (flops > best_flops) { - best_flops = flops; - best_idx = i; + double best_flops = -1; + int best_idx = 0; + for (size_t i = 0; i < candidates.size(); ++i) { + auto loop = GetEnclosingLoop(candidates[i], prim_func->body); + auto flops = EstimateTIRFlops(loop); + if (flops > best_flops) { + best_flops = flops; + best_idx = i; + } } + return candidates[best_idx]; } - 
return candidates[best_idx]; + return nullptr; } TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 95d5fe9c2e44..44d9e9b69c94 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -56,7 +56,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `error_render_level_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visited - // `rgnd_state_` is not visited + // `rand_state_` is not visited } virtual ~ConcreteScheduleNode() = default; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index d40906209fb9..a6aced4632fb 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -456,10 +456,12 @@ inline std::unordered_set GetBlockNames(const IRModule& mod) { std::unordered_set block_names; }; - auto prim_func = tir::FindEntryFunc(mod, nullptr); - BlockNameCollector collector; - collector(prim_func->body); - return collector.block_names; + if (auto prim_func = tir::FindEntryFunc(mod, nullptr)) { + BlockNameCollector collector; + collector(prim_func->body); + return collector.block_names; + } + return {}; } /*! 
\brief Query if the given block name exists in the module associated with the schedule */ diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index d4681d40111b..806ea2d1827b 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -30,7 +30,7 @@ from tvm.ir.module import IRModule from tvm.script import tir as T from tvm.tir import Schedule - +from tvm import relay # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off @@ -93,10 +93,10 @@ def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Sched return sch -def _create_tmp_database(tmpdir: str) -> ms.database.JSONDatabase: +def _create_tmp_database(tmpdir: str, mod_eq: str = "structural") -> ms.database.JSONDatabase: path_workload = osp.join(tmpdir, "workloads.json") path_tuning_record = osp.join(tmpdir, "tuning_records.json") - return ms.database.JSONDatabase(path_workload, path_tuning_record) + return ms.database.JSONDatabase(path_workload, path_tuning_record, module_equality=mod_eq) def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord): @@ -583,5 +583,33 @@ def test_json_database_get_top_k(k, expected): assert result == expected +def MatmulFunc() -> IRModule: + a = relay.var("a", relay.TensorType((1024, 1024), "float32")) + b = relay.var("b", relay.TensorType((1024, 1024), "float32")) + func = relay.Function([a, b], relay.nn.matmul(a, b)) + return tvm.IRModule.from_expr(func) + + +def MatmulPrimFunc() -> IRModule: + return Matmul + + +@pytest.mark.parametrize("f_mod", [MatmulPrimFunc, MatmulFunc]) +@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"]) +def test_json_database_commit_workload(f_mod, mod_eq): + mod: IRModule = f_mod() + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir, mod_eq) + 
database.commit_workload(mod) + + +@pytest.mark.parametrize("f_mod", [MatmulPrimFunc, MatmulFunc]) +@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"]) +def test_memory_database_commit_workload(f_mod, mod_eq): + mod: IRModule = f_mod() + database = ms.database.MemoryDatabase(module_equality=mod_eq) + database.commit_workload(mod) + + if __name__ == "__main__": tvm.testing.main()