Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Handle 'warp_execution' implied extend of threadIdx.x in VerifyGpuCode #11949

Merged
merged 1 commit into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,12 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori
/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";

/*!
* \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is
* warp size.
*/
constexpr const char* warp_execution = "warp_execution";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
18 changes: 15 additions & 3 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@ namespace tir {

class ThreadExtentChecker : private StmtVisitor {
public:
static bool Check(const Stmt& stmt) {
static bool Check(const Stmt& stmt, int thread_warp_size) {
try {
ThreadExtentChecker().VisitStmt(stmt);
ICHECK(thread_warp_size > 0);
ThreadExtentChecker checker(thread_warp_size);
checker.VisitStmt(stmt);
return true;
} catch (const dmlc::Error& e) {
return false;
}
}

private:
explicit ThreadExtentChecker(int thread_warp_size) : thread_warp_size_(thread_warp_size) {}

void VisitStmt_(const ForNode* loop) {
runtime::ThreadScope thread_scope = GetThreadScope(loop);
if (IsThreadIdx(thread_scope)) {
Expand Down Expand Up @@ -64,6 +68,10 @@ class ThreadExtentChecker : private StmtVisitor {
}

void VisitStmt_(const BlockNode* block) {
int old_thread_idx_x = thread_idx_x;
if (block->annotations.count(attr::warp_execution)) {
thread_idx_x = thread_warp_size_;
}
if (Optional<Integer> low_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
if (Optional<Integer> high_inclusive =
Expand All @@ -77,11 +85,13 @@ class ThreadExtentChecker : private StmtVisitor {
}
}
StmtVisitor::VisitStmt_(block);
thread_idx_x = old_thread_idx_x;
}

int64_t thread_idx_x = 1;
int64_t thread_idx_y = 1;
int64_t thread_idx_z = 1;
int thread_warp_size_ = -1;
};

} // namespace tir
Expand All @@ -104,6 +114,7 @@ Integer Extract(const Target& target, const char* name) {
class VerifyGPUCodeNode : public PostprocNode {
public:
Map<String, PrimExpr> target_constraints_{nullptr};
int thread_warp_size_ = -1;

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Expand All @@ -114,6 +125,7 @@ class VerifyGPUCodeNode : public PostprocNode {
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)},
};
thread_warp_size_ = Extract(target, "thread_warp_size");
}

bool Verify(const IRModule& mod) const {
Expand All @@ -133,7 +145,7 @@ class VerifyGPUCodeNode : public PostprocNode {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) {
return false;
}
IRModule lowered{nullptr};
Expand Down
Loading