Skip to content

Commit

Permalink
[Fix] Misc fixes from tlc-pack#32, tlc-pack#42, tlc-pack#48, tlc-pack#54
Browse files Browse the repository at this point in the history
 (tlc-pack#57)

* [FIX][Pass] concurrent modification in RemoveUnusedVars (tlc-pack#32)

* [Pass] Fix concurrent modification in RemoveUnusedVars

When running RemoveUnusedVars (i.e. remove_all_unused), in some cases the map users will raise Concurrent modification error. This commit fixed it by changing the logic to "iterate the map first and update it later".

* change the algorithm by store keys first

* Add an ICHECK before getting map value

* change the information of the ICHECK

* fix find include path (tlc-pack#42)

Co-authored-by: Hongyi Jin <[email protected]>

* [BUG] ExternFunc is not considered in attach_global_symbol.cc (tlc-pack#48)

* [Cherry-Pick] Minor fix for TaskScheduler and VerifyGPUCode (tlc-pack#54)

* [Fix] Task scheduler error prompt upon build/run failure (#13601)

* [Fix] Use proper target in VerifyGPUCode (#13548)

Previously, the VerifyGPUCode post-processor uses hardcoded target `Target("cuda")` for applying pass LowerIntrin. This is a bit problematic since the actual target can be other GPU target (e.g., Metal). Therefore, this PR changes the hardcoded target to be the actual target.

Co-authored-by: Chaosfan <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
  • Loading branch information
5 people authored and tqchen committed Dec 30, 2022
1 parent bf91a03 commit 41a0f88
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 14 deletions.
8 changes: 5 additions & 3 deletions python/tvm/_ffi/libinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ def find_include_path(name=None, search_path=None, optional=False):
include_path : list(string)
List of all found paths to header files.
"""
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
source_dir = os.path.join(ffi_dir, "..", "..", "..")

if os.environ.get("TVM_HOME", None):
source_dir = os.environ["TVM_HOME"]
else:
ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
source_dir = os.path.join(ffi_dir, "..", "..", "..")
third_party_dir = os.path.join(source_dir, "3rdparty")

header_path = []
Expand Down
11 changes: 6 additions & 5 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,20 @@ Integer Extract(const Target& target, const char* name) {
/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
public:
Target target_{nullptr};
Map<String, PrimExpr> target_constraints_{nullptr};
int thread_warp_size_ = -1;

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Target target = context->target.value();
this->target_ = context->target.value();
this->target_constraints_ = Map<String, PrimExpr>{
{"max_shared_memory_per_block", Extract(target, "max_shared_memory_per_block")},
{"max_threads_per_block", Extract(target, "max_threads_per_block")},
{"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")},
{"max_threads_per_block", Extract(this->target_, "max_threads_per_block")},
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)},
};
thread_warp_size_ = Extract(target, "thread_warp_size").IntValue();
thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue();
}

bool Verify(const IRModule& mod) const {
Expand Down Expand Up @@ -180,7 +181,7 @@ class VerifyGPUCodeNode : public PostprocNode {
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin
f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for LowerIntrin
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
Expand Down
4 changes: 3 additions & 1 deletion src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array<RunnerResult>& r
std::string err = error_msg.value();
TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) //
<< "[Task #" << task_id << ": " << name << "] Trial #" << trials
<< ": Error in building:\n"
<< ": Error in "
<< (builder_result->error_msg.defined() ? "building" : "running")
<< ":\n"
<< err << "\n"
<< tir::AsTVMScript(sch->mod()) << "\n"
<< Concat(sch->trace().value()->AsPython(false), "\n");
Expand Down
15 changes: 10 additions & 5 deletions src/relax/ir/binding_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class RemoveUnusedVars : public ExprMutator {
do {
prev_size = unused.size();

std::vector<Var> users_keys;
for (const auto& kv : users) {
// var -> [users...]
// var is unused iff
Expand All @@ -207,17 +208,21 @@ class RemoveUnusedVars : public ExprMutator {
if (kv.second.empty() && // kv.first is not used by fn outputs.
fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) {
unused.push_back(kv.first);
} else {
users_keys.push_back(kv.first);
}
}

for (size_t i = prev_size; i < unused.size(); ++i) {
users.erase(unused[i]);
// remove def site.
for (auto kv : users) { // remove use site.
auto it = std::find(kv.second.begin(), kv.second.end(), unused[i]);
if (it != kv.second.end()) {
kv.second.erase(it);
users.Set(kv.first, std::move(kv.second));
for (const auto& key: users_keys) { // remove use site.
ICHECK(users.count(key)) << "the key " << key << " is expected to be in the mapping users.";
Array<Var> cur_users = users[key];
auto it = std::find(cur_users.begin(), cur_users.end(), unused[i]);
if (it != cur_users.end()) {
cur_users.erase(it);
users.Set(key, std::move(cur_users));
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/relax/transform/attach_global_symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class GlobalSymbolAttacher {
func = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol", p.first->name_hint);
} else if (auto* relax_func = func.as<FunctionNode>()) {
func = WithAttr(GetRef<Function>(relax_func), "global_symbol", p.first->name_hint);
} else if (auto* extern_func = func.as<ExternFuncNode>()) {
func = WithAttr(GetRef<ExternFunc>(extern_func), "global_symbol", p.first->name_hint);
} else {
LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey();
throw;
Expand Down

0 comments on commit 41a0f88

Please sign in to comment.