Skip to content

Commit

Permalink
reformat inject_ptx_async_copy.cc
Browse files (browse the repository at this point in the history)
  • Loading branch information
cblmemo committed Feb 16, 2023
1 parent 28b866f commit 1de7625
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions src/tir/transforms/inject_ptx_async_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class PTXAsyncCopyInjector : public StmtMutator {
return StmtMutator::VisitStmt_(attr);
}

Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store,
bool predicated = false, PrimExpr predicate_value = PrimExpr()) {
Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false,
PrimExpr predicate_value = PrimExpr()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());
Expand Down Expand Up @@ -79,8 +79,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {store->buffer->data,
tir::Mul(dst_offset, PrimExpr(index_factor)),
Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
// use arguments size to indicate whether or not to use predicated cp.async
if (predicated) {
Expand Down Expand Up @@ -115,10 +114,9 @@ class PTXAsyncCopyInjector : public StmtMutator {
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
}
}
Expand All @@ -134,9 +132,9 @@ class PTXAsyncCopyInjector : public StmtMutator {
// tir.if_then_else is a call to tir::builtin::if_then_else()
if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) {
if (auto* load = call->args[1].as<BufferLoadNode>()) {
// Only default value of 0 is supported since 0 is the default value used by cp.async ptx.
// @see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
// section 9.7.8.22.3.
// Only default value of 0 is supported since 0 is the default value used by cp.async
// ptx. @see section 9.7.8.22.3. of
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
bool else_value_is_zero = false;
if (auto* b = call->args[2].as<BroadcastNode>()) {
if (auto* f = b->value.as<FloatImmNode>()) {
Expand Down

0 comments on commit 1de7625

Please sign in to comment.