Skip to content

Commit

Permalink
[Fix] Use proper target in VerifyGPUCode (#13548)
Browse files Browse the repository at this point in the history
Previously, the VerifyGPUCode post-processor uses hardcoded target `Target("cuda")` for applying pass LowerIntrin. This is a bit problematic since the actual target can be other GPU target (e.g., Metal). Therefore, this PR changes the hardcoded target to be the actual target.
  • Loading branch information
MasterJH5574 authored Dec 4, 2022
1 parent 7bc41ec commit 3a81aef
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,20 @@ Integer Extract(const Target& target, const char* name) {
/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
public:
Target target_{nullptr};
Map<String, PrimExpr> target_constraints_{nullptr};
int thread_warp_size_ = -1;

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Target target = context->target.value();
this->target_ = context->target.value();
this->target_constraints_ = Map<String, PrimExpr>{
{"max_shared_memory_per_block", Extract(target, "max_shared_memory_per_block")},
{"max_threads_per_block", Extract(target, "max_threads_per_block")},
{"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")},
{"max_threads_per_block", Extract(this->target_, "max_threads_per_block")},
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)},
};
thread_warp_size_ = Extract(target, "thread_warp_size").IntValue();
thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue();
}

bool Verify(const IRModule& mod) const {
Expand Down Expand Up @@ -180,7 +181,7 @@ class VerifyGPUCodeNode : public PostprocNode {
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin
f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for LowerIntrin
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
Expand Down

0 comments on commit 3a81aef

Please sign in to comment.