Skip to content

Commit

Permalink
[MetaSchedule] Handle 'warp_execution' implied extent of threadIdx.x …
Browse files Browse the repository at this point in the history
…in VerifyGpuCode
  • Loading branch information
vinx13 committed Jun 29, 2022
1 parent c9d0d25 commit 07a4cbc
Show file tree
Hide file tree
Showing 3 changed files with 416 additions and 45 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,12 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori
/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";

/*!
* \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
* the warp size.
*/
constexpr const char* warp_execution = "warp_execution";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
17 changes: 14 additions & 3 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ namespace tir {

class ThreadExtentChecker : private StmtVisitor {
public:
static bool Check(const Stmt& stmt) {
static bool Check(const Stmt& stmt, int thread_warp_size) {
try {
ThreadExtentChecker().VisitStmt(stmt);
ICHECK(thread_warp_size > 0);
ThreadExtentChecker(thread_warp_size).VisitStmt(stmt);
return true;
} catch (const dmlc::Error& e) {
return false;
}
}

private:
// Construct a checker that remembers `thread_warp_size`; the value is used as the
// threadIdx.x extent while visiting blocks annotated with attr::warp_execution.
explicit ThreadExtentChecker(int thread_warp_size) : thread_warp_size_(thread_warp_size) {}

void VisitStmt_(const ForNode* loop) {
runtime::ThreadScope thread_scope = GetThreadScope(loop);
if (IsThreadIdx(thread_scope)) {
Expand Down Expand Up @@ -64,6 +67,10 @@ class ThreadExtentChecker : private StmtVisitor {
}

void VisitStmt_(const BlockNode* block) {
int old_thread_idx_x = thread_idx_x;
if (block->annotations.count(attr::warp_execution)) {
thread_idx_x = thread_warp_size_;
}
if (Optional<Integer> low_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
if (Optional<Integer> high_inclusive =
Expand All @@ -77,11 +84,13 @@ class ThreadExtentChecker : private StmtVisitor {
}
}
StmtVisitor::VisitStmt_(block);
thread_idx_x = old_thread_idx_x;
}

// Extent of threadIdx.x tracked during the visit. Starts at 1 and is temporarily
// overridden with thread_warp_size_ while visiting a block annotated with
// attr::warp_execution, then restored afterwards.
int64_t thread_idx_x = 1;
// NOTE(review): presumably the tracked extents of threadIdx.y / threadIdx.z; they are
// updated in code not visible in this view -- confirm against the ForNode visitor.
int64_t thread_idx_y = 1;
int64_t thread_idx_z = 1;
// Warp size supplied through the constructor; -1 until it is set.
int thread_warp_size_ = -1;
};

} // namespace tir
Expand All @@ -104,6 +113,7 @@ Integer Extract(const Target& target, const char* name) {
class VerifyGPUCodeNode : public PostprocNode {
public:
Map<String, PrimExpr> target_constraints_{nullptr};
int thread_warp_size_ = -1;

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Expand All @@ -114,6 +124,7 @@ class VerifyGPUCodeNode : public PostprocNode {
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)},
};
thread_warp_size_ = Extract(target, "thread_warp_size");
}

bool Verify(const IRModule& mod) const {
Expand All @@ -133,7 +144,7 @@ class VerifyGPUCodeNode : public PostprocNode {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) {
return false;
}
IRModule lowered{nullptr};
Expand Down
Loading

0 comments on commit 07a4cbc

Please sign in to comment.