[TIR] Add cp.async support for tir.if_then_else #13966

Merged (23 commits) on Feb 17, 2023
8 changes: 7 additions & 1 deletion src/target/source/codegen_cuda.cc
@@ -914,7 +914,13 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    // use size of argument list to indicate whether or not to use predicated cp.async
    if (op->args.size() == 5) {
      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
    } else {
      this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size,
                                                     this->PrintExpr(op->args[5]));
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
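The predicate is threaded through as an optional sixth argument of the existing ptx_cp_async builtin rather than a new intrinsic. As a minimal sketch (variable names such as dst_data, predicate, and dtype are illustrative, not from the diff), the two call forms this dispatch distinguishes are built like so:

    // Unconditional copy: exactly 5 arguments.
    Array<PrimExpr> args = {dst_data, dst_offset, src_data, src_offset, PrimExpr(bytes)};
    // Predicated copy: a 6th argument carries the guard taken from tir.if_then_else,
    // so CodeGenCUDA sees op->args.size() == 6 and emits the predicated template.
    args.push_back(predicate);
    Stmt copy = Evaluate(Call(dtype, tvm::tir::builtin::ptx_cp_async(), args));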
31 changes: 31 additions & 0 deletions src/target/source/ptx.cc
@@ -659,5 +659,36 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
  return asm_code;
}

std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                           const std::string& shared_elem_offset,
                                           const std::string& global_ptr,
                                           const std::string& global_elem_offset,
                                           const std::string& bytes,
                                           const std::string& predicate_value) {
  std::string predicated_asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    int src_bytes = {pred_guard} ? {bytes} : 0;
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
      :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
    );
  }
)";
  Replacer replacer;
  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
  replacer.register_rule("{bytes}", bytes);
  replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
  replacer.register_rule("{pred_guard}", predicate_value);
  predicated_asm_code = replacer.rewrite(predicated_asm_code);
  return predicated_asm_code;
}

} // namespace codegen
} // namespace tvm
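To make the template concrete: the .cg (cache-global, L1-bypassing) variant of cp.async only permits 16-byte copies, which is why the cg_or_ca rule switches on bytes == "16". For a hypothetical 16-byte copy guarded by predicate p, from global gptr + goff into shared buf + soff (all names illustrative), the template expands to roughly:

    {
      unsigned int addr;
      __asm__ __volatile__(
        "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
        : "=r"(addr)
        : "l"((void *)(buf + soff))
      );
      // When p is false, src_bytes is 0: the instruction still completes a
      // 16-byte transfer to shared memory but reads no source bytes and
      // zero-fills the destination instead.
      int src_bytes = p ? 16 : 0;
      __asm__ __volatile__(
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
        :: "r"(addr), "l"((void*)(gptr + goff)), "n"(16), "r"(src_bytes)
      );
    }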
16 changes: 16 additions & 0 deletions src/target/source/ptx.h
@@ -92,6 +92,22 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes);

/*!
 * \brief Print predicated ptx cp.async assembly string given parameters.
 * \param shared_ptr: The pointer to the destination shared memory.
 * \param shared_elem_offset: The offset into the shared memory.
 * \param global_ptr: The pointer to the global memory.
 * \param global_elem_offset: The offset into the global memory.
 * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
 * \param predicate_value: The value of predicate `@p`.
 */
std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                           const std::string& shared_elem_offset,
                                           const std::string& global_ptr,
                                           const std::string& global_elem_offset,
                                           const std::string& bytes,
                                           const std::string& predicate_value);

} // namespace codegen
} // namespace tvm

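A hypothetical call site (the argument strings here are illustrative) shows the contract: every parameter is a C expression rendered into the generated source, not a runtime value.

    // Emits inline asm for a guarded 16-byte cp.async from gptr + goff into
    // smem + soff, predicated on the C expression "p".
    std::string code =
        PrintPredicatedCpAsyncAssembly("smem", "soff", "gptr", "goff", "16", "p");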
156 changes: 94 additions & 62 deletions src/tir/transforms/inject_ptx_async_copy.cc
@@ -47,73 +47,105 @@ class PTXAsyncCopyInjector : public StmtMutator {
    return StmtMutator::VisitStmt_(attr);
  }

  Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false,
                 PrimExpr predicate_value = PrimExpr()) {
    if (load->buffer.scope() == "global") {
      ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
      ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());

      const int indices_lanes = load->indices[0]->dtype.lanes();
      const int bytes = indices_lanes * load->buffer->dtype.bytes();

      if (bytes == 4 || bytes == 8 || bytes == 16) {
        auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
        auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
        ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
            << "Both store and load buffers should have a pointer type annotation.";

        int index_factor = 1;
        if (dst_elem_type.value() != src_elem_type.value()) {
          // The only case where src and dst have different dtypes is when the dst shared memory
          // is a byte buffer generated by merging dynamic shared memory.
          ICHECK(store->buffer.scope() == "shared.dyn");
          ICHECK(dst_elem_type.value() == DataType::UInt(8));
          // BufferStore/Load have "pointer reinterpret" semantics according to their
          // "value" dtype. Their "indices" are supposed to be applied after such a pointer cast,
          // for example: ((float16*)(byte_buffer))[buffer->indices] = fp16_value;
          // To replace BufferStore/Load with cp.async, we need to multiply the store index by
          // the byte size of the "value" dtype to get the correct offset into the byte buffer.
          index_factor = src_elem_type->bytes();
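          // Worked example (values illustrative, not from this PR): storing
          // float16 data through the merged uint8 "shared.dyn" buffer with
          // store index 10 must land at byte offset 10 * sizeof(float16) = 20,
          // so index_factor becomes 2 here.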
        }

        if (indices_lanes == 1) {
          auto src_offset = load->indices[0];
          auto dst_offset = store->indices[0];
          Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
                                  load->buffer->data, src_offset, PrimExpr(bytes)};
          // use argument list size to indicate whether or not to use predicated cp.async
          if (predicated) {
            args.push_back(predicate_value);
          }
          return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args));
        }

        // Predicated loads do not support vectorized indexing yet.
        if (!predicated) {
          // Only some vectorized indexing patterns are supported for now.
          auto src_offset = [=]() -> PrimExpr {
            if (load->indices[0]->IsInstance<RampNode>()) {
              return load->indices[0].as<RampNode>()->base;
            }
            return PrimExpr();
          }();

          auto dst_offset = [=]() -> PrimExpr {
            if (store->indices[0].as<RampNode>()) {
              return store->indices[0].as<RampNode>()->base;
            } else if (store->indices[0].as<AddNode>()) {
              // The case where the dst buffer is a byte buffer generated by merging dynamic
              // shared memory, e.g.:
              // A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
              auto* add = store->indices[0].as<AddNode>();
              if (!add->a->IsInstance<RampNode>()) return PrimExpr();
              if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
              return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
            }
            return PrimExpr();
          }();

          if (src_offset.defined() && dst_offset.defined()) {
            return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
                                 {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
                                  load->buffer->data, src_offset, PrimExpr(bytes)}));
          }
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }

  Stmt VisitStmt_(const BufferStoreNode* store) {
    if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
      if (auto* load = store->value.as<BufferLoadNode>()) {
        return InjectPTX(load, store);
      } else if (auto* call = store->value.as<CallNode>()) {
        // tir.if_then_else is a call to tir::builtin::if_then_else()
        if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) {
          if (auto* load = call->args[1].as<BufferLoadNode>()) {
            // Only an else-value of 0 is supported, since 0 is what cp.async writes to
            // shared memory when the source read is predicated off. See section 9.7.8.22.3 of
            // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
            bool else_value_is_zero = false;
            if (auto* b = call->args[2].as<BroadcastNode>()) {
              if (auto* f = b->value.as<FloatImmNode>()) {
                else_value_is_zero = f->value == 0.0f;
              }
            }
            if (auto* f = call->args[2].as<FloatImmNode>()) {
              else_value_is_zero = f->value == 0.0f;
            }
            if (else_value_is_zero) {
              return InjectPTX(load, store, true, call->args[0]);
            }
          }
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }
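Taken together, the pass now rewrites a zero-padded conditional copy into the predicated builtin, which CodeGenCUDA then lowers to the predicated inline assembly. Schematically (buffer names illustrative, TIR shown in shorthand):

    // Before: A_shared[i] = if_then_else(pred, A_global[j], 0f32)
    // After:  ptx_cp_async(A_shared_data, i * index_factor, A_global_data, j, bytes, pred)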