Skip to content

Commit

Permalink
[TIR] Handle nullptr returned by FindEntryFunc (#13852)
Browse files Browse the repository at this point in the history
The FindEntryFunc function can return a null pointer. In development I got this situation, which appears as a segfault.
  • Loading branch information
Icemist authored Feb 1, 2023
1 parent 7db77ad commit a1229f6
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 27 deletions.
39 changes: 20 additions & 19 deletions src/tir/analysis/stmt_finding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,30 +111,31 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) {
std::vector<const BlockNode*> blocks;
};

auto prim_func = FindEntryFunc(mod, nullptr);
if (auto prim_func = FindEntryFunc(mod, nullptr)) {
ReductionBlockCollector collector;
collector(prim_func->body);

ReductionBlockCollector collector;
collector(prim_func->body);
const auto& candidates = collector.blocks;

const auto& candidates = collector.blocks;

if (candidates.empty()) {
return nullptr;
} else if (candidates.size() == 1) {
return candidates[0];
}
if (candidates.empty()) {
return nullptr;
} else if (candidates.size() == 1) {
return candidates[0];
}

double best_flops = -1;
int best_idx = 0;
for (size_t i = 0; i < candidates.size(); ++i) {
auto loop = GetEnclosingLoop(candidates[i], prim_func->body);
auto flops = EstimateTIRFlops(loop);
if (flops > best_flops) {
best_flops = flops;
best_idx = i;
double best_flops = -1;
int best_idx = 0;
for (size_t i = 0; i < candidates.size(); ++i) {
auto loop = GetEnclosingLoop(candidates[i], prim_func->body);
auto flops = EstimateTIRFlops(loop);
if (flops > best_flops) {
best_flops = flops;
best_idx = i;
}
}
return candidates[best_idx];
}
return candidates[best_idx];
return nullptr;
}

TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) {
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ConcreteScheduleNode : public ScheduleNode {
// `error_render_level_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visited
// `rgnd_state_` is not visited
// `rand_state_` is not visited
}

virtual ~ConcreteScheduleNode() = default;
Expand Down
10 changes: 6 additions & 4 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,12 @@ inline std::unordered_set<std::string> GetBlockNames(const IRModule& mod) {
std::unordered_set<std::string> block_names;
};

auto prim_func = tir::FindEntryFunc(mod, nullptr);
BlockNameCollector collector;
collector(prim_func->body);
return collector.block_names;
if (auto prim_func = tir::FindEntryFunc(mod, nullptr)) {
BlockNameCollector collector;
collector(prim_func->body);
return collector.block_names;
}
return {};
}

/*! \brief Query if the given block name exists in the module associated with the schedule */
Expand Down
34 changes: 31 additions & 3 deletions tests/python/unittest/test_meta_schedule_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tvm.ir.module import IRModule
from tvm.script import tir as T
from tvm.tir import Schedule

from tvm import relay

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off
Expand Down Expand Up @@ -93,10 +93,10 @@ def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Sched
return sch


def _create_tmp_database(tmpdir: str) -> ms.database.JSONDatabase:
def _create_tmp_database(tmpdir: str, mod_eq: str = "structural") -> ms.database.JSONDatabase:
path_workload = osp.join(tmpdir, "workloads.json")
path_tuning_record = osp.join(tmpdir, "tuning_records.json")
return ms.database.JSONDatabase(path_workload, path_tuning_record)
return ms.database.JSONDatabase(path_workload, path_tuning_record, module_equality=mod_eq)


def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord):
Expand Down Expand Up @@ -583,5 +583,33 @@ def test_json_database_get_top_k(k, expected):
assert result == expected


def MatmulFunc() -> IRModule:
a = relay.var("a", relay.TensorType((1024, 1024), "float32"))
b = relay.var("b", relay.TensorType((1024, 1024), "float32"))
func = relay.Function([a, b], relay.nn.matmul(a, b))
return tvm.IRModule.from_expr(func)


def MatmulPrimFunc() -> IRModule:
return Matmul


@pytest.mark.parametrize("f_mod", [MatmulPrimFunc, MatmulFunc])
@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"])
def test_json_database_commit_workload(f_mod, mod_eq):
mod: IRModule = f_mod()
with tempfile.TemporaryDirectory() as tmpdir:
database = _create_tmp_database(tmpdir, mod_eq)
database.commit_workload(mod)


@pytest.mark.parametrize("f_mod", [MatmulPrimFunc, MatmulFunc])
@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"])
def test_memory_database_commit_workload(f_mod, mod_eq):
mod: IRModule = f_mod()
database = ms.database.MemoryDatabase(module_equality=mod_eq)
database.commit_workload(mod)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit a1229f6

Please sign in to comment.