
add 1D TMA UBLKCP #3739

Merged
merged 6 commits into main from llu/1dtma on Jan 23, 2025
Conversation

liqiangxl (Collaborator) commented Jan 21, 2025

Why add this?
There are two types of TMA ops:

 (1) 1D TMA, with PTX `cp.async.bulk` and SASS `UBLKCP`, added in this PR.
 (2) n-D TMA, with PTX `cp.async.bulk.tensor.{1-5}d` and SASS `UTMALDG`, which already exists.
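
The two PTX forms differ roughly as sketched below (simplified from the PTX ISA; optional qualifiers such as multicast and cache hints are omitted, and the 2-D tensor variant is just one representative shape):

// 1D bulk copy (this PR): no tensor map, the byte count is passed explicitly.
cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes
    [smem_addr], [gmem_addr], bytes, [mbarrier];

// n-D tensor-tile copy (existing): coordinates are translated through a tensor map.
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
    [smem_addr], [tensor_map, {x, y}], [mbarrier];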

The second type copies n-D data with a box limited to 256 elements per dimension and requires a tensor map.
The first type copies 1-D data whose length may exceed 256 elements and needs no tensor map, which makes it better suited for non-matmul fusions where 2-D tiling is unnecessary. In these fusions each block typically loads n elements, with n ranging from 1K to 32K or more, and a single TMA instruction covers the whole transfer.
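
As a rough, hypothetical illustration (not code from this PR), sizing one block's transfer for such a fusion fits a single 1D TMA copy as long as the byte count is a multiple of 16, which cp.async.bulk requires and the lowering checks (see is_multiple_of_16B further below):

#include <cstdint>

// Hypothetical per-block sizing for a pointwise fusion: 8K floats per block.
constexpr int64_t elems_per_block = 8 * 1024; // within the 1K-32K range mentioned above
constexpr int64_t bytes_per_copy =
    elems_per_block * static_cast<int64_t>(sizeof(float)); // 32 KiB, moved by one cp.async.bulk
static_assert(
    bytes_per_copy % 16 == 0,
    "cp.async.bulk transfer size must be a multiple of 16 bytes");
// The tensor-tile path caps each box dimension at 256 elements, so covering the
// same 8K contiguous elements would require extra tiling or multiple issues.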

Code changes
(1) Added a new load type LoadStoreOpType::CpAsyncBulk, lowered to cp.async.bulk; the existing n-D TMA uses LoadStoreOpType::CpAsyncBulkTensorTile and is lowered to cp.async.bulk.tensor.Nd.
(2) Added a unit test that loads 512 elements with a single TMA instruction (a scheduling sketch follows the lowered code below).
(3) 1D TMA is lowered to:

// Operands for a 1D TMA (bulk) copy from global to shared memory.
struct CpAsyncBulkG2SIndex {
  const void* raw_gmem_addr; // source address in global memory
  uint32_t bytes;            // transfer size in bytes (must be a multiple of 16)
  uint32_t mbarrier;         // shared-memory address of the mbarrier tracking completion
};

// Issues a single cp.async.bulk copy; completion is signaled on the given mbarrier.
__device__ inline void cpAsyncBulkG2S(
    const CpAsyncBulkG2SIndex& src,
    uint32_t smem_addr) {
  asm volatile(
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n"
      :
      : "r"(smem_addr),
        "l"(src.raw_gmem_addr),
        "r"(src.bytes),
        "r"(src.mbarrier)
      : "memory");
}

xxx
  // T4: shared-memory destination buffer; T6: mbarrier storage in shared memory
  float* T4 = reinterpret_cast<float*>(array + smem_offset + 0);
  uint64_t* T6 = reinterpret_cast<uint64_t*>(array + smem_offset + 2064);
  mbarrier::init(toSmem(T6), 1U);
  __syncthreads();
  if (b2) {
    uint64_t i3;
    // Tell the mbarrier to expect 2048 bytes (512 floats), then issue the single 1D TMA copy
    i3 = mbarrier::arriveExpectTX(toSmem(T6), 2048U);
    Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (T1.data + i0), 2048U, toSmem(T6) }), toSmem(T4));
    // Block until the TMA transfer has completed
    mbarrier::wait(toSmem(T6), i3);
  }
  __syncthreads();
  mbarrier::inval(toSmem(T6));
xxx
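
For orientation, here is a minimal sketch of how a test could request this lowering. The scheduling calls are assumptions modeled on the existing n-D TMA tests (set the op type on the LoadStoreOp and parallelize the copied axis with ParallelType::Bulk); it is not a verbatim copy of the unit test added in this PR:

// Hypothetical sketch; names and scheduling are assumptions, not the PR's test verbatim.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeContigTensor(1); // 1-D input, e.g. 512 floats
fusion.addInput(tv0);
TensorView* tv1 = set(tv0);            // gmem -> smem copy, to become the 1D TMA load
TensorView* tv2 = set(tv1);            // smem -> output
fusion.addOutput(tv2);

tv1->setMemoryType(MemoryType::Shared);
tv1->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsyncBulk);
// Mark the copied axis as the bulk (TMA) axis; a single instruction moves all elements.
tv1->axis(0)->parallelize(ParallelType::Bulk);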

github-actions bot commented Jan 21, 2025

PR Reviewer Guide 🔍

(Review updated until commit 5365076)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The function genCpAsyncBulkMaybeTensorTile has been modified to handle both CpAsyncBulk and CpAsyncBulkTensorTile operations. Review the logic changes to ensure correctness.

void genCpAsyncBulkMaybeTensorTile(const LoadStoreOp* ldst) {
  auto in = ldst->in()->as<kir::TensorIndex>();
  auto out = ldst->out()->as<kir::TensorIndex>();

  auto in_tv = in->view();
  auto out_tv = out->view();

  kir::TensorIndex* gmem_ti = nullptr;
  kir::TensorIndex* smem_ti = nullptr;
  std::string func_name;

  bool is_tensor_tile =
      ldst->opType() == LoadStoreOpType::CpAsyncBulkTensorTile;

  if (out->view()->getMemoryType() == MemoryType::Shared) {
    func_name = is_tensor_tile ? "Hopper::cpAsyncBulkTensorTileG2S"
                               : "Hopper::cpAsyncBulkG2S";
    NVF_ERROR(
        in_tv->getMemoryType() == MemoryType::Global,
        "Expected input in global for G2S operation");
    smem_ti = out;
    gmem_ti = in;
  } else {
    NVF_ERROR(
        in_tv->getMemoryType() == MemoryType::Shared,
        "Expected input in shared for S2G operation");
    NVF_ERROR(
        out_tv->getMemoryType() == MemoryType::Global,
        "Expected output in global for S2G operation");
    func_name = is_tensor_tile ? "Hopper::cpAsyncBulkTensorTileS2G"
                               : "Hopper::cpAsyncBulkS2G";
    smem_ti = in;
    gmem_ti = out;
  }
Potential Bug

The function getCpAsyncBulkGmemIndex has been modified to handle both CpAsyncBulk and CpAsyncBulkTensorTile operations. Verify that the indexing logic is correct for both cases.

NVF_ERROR(
    GpuLower::current()->consumerToTMAInfo().count(consumer_tv),
    "Unable to find TMA info for consumer_tv: ",
    consumer_tv->toString());
const TMAInfo& tma_info =
    GpuLower::current()->consumerToTMAInfo().at(consumer_tv);
int64_t dim = (int64_t)tma_info.dims().size();
Val* expected_bytes = SimplifyingIrBuilder::maybeCastExpr(
    DataType::UInt32, tma_info.tileSizeBytes());
expected_bytes =
    GpuLower::current()->commonScalarMap().hoistScalar(expected_bytes, loops);
Val* index = nullptr;

// 1D TMA without tensor map
if (ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
  NVF_ERROR(dim == 1L, "1D TMA but got more than one index.");
  if (is_load) {
    std::stringstream ss;
    ss << "Hopper::CpAsyncBulkG2SIndex";
    auto gmem_address = getProducerIndex(
        producer_tv, consumer_tv, loops, rotated_loops, {}, true);
    index = IrBuilder::structExpr(
        {{"raw_gmem_addr", gmem_address},
         {"bytes", expected_bytes},
         {"mbarrier", mbarrier}},
        ss.str());
  } else {
    NVF_THROW("S2G for CpAsyncBulk is not implemented yet.")
  }
} else {
  // ND TMA with tensor map
  ValGroups groups_to_index = tma_info.getTMADomain();
  // TensorIndexer needs IterDomain instead of ValGroup to work around
  // the resize indexing issue
  std::vector<IterDomain*> ids_to_index;
  ids_to_index.reserve(groups_to_index.size());
  const auto tma_all_ids = is_load ? consumer_tv->domain()->allIDs()
                                   : producer_tv->domain()->allIDs();
  for (const auto& group : groups_to_index) {
    auto it = std::find_if(
        tma_all_ids.begin(), tma_all_ids.end(), [&](IterDomain* gmem_id) {
          return group->has(gmem_id);
        });
    if (it != tma_all_ids.end()) {
      ids_to_index.push_back(*it);
    } else {
      ids_to_index.push_back(group->front()->as<IterDomain>());
    }
  }

  const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
  auto indices_inner_to_outer =
      indexer.getIndexFor(ldst, !is_load, ids_to_index, loops);

  auto coordinate = IrBuilder::arrayExpr(indices_inner_to_outer);
  auto descriptor = tma_info.tensorMap();
  if (is_load) {
    std::stringstream ss;
    ss << "Hopper::CpAsyncBulkTensorTileG2SIndex<" << dim << ">";
    index = IrBuilder::structExpr(
        {{"descriptor", IrBuilder::addressExpr(descriptor)},
         {"coordinate", coordinate},
         {"mbarrier", mbarrier}},
        ss.str());
  } else {
    std::stringstream ss;
    ss << "Hopper::CpAsyncBulkTensorTileS2GIndex<" << dim << ">";
    index = IrBuilder::structExpr(
        {{"descriptor", IrBuilder::addressExpr(descriptor)},
         {"coordinate", coordinate}},
        ss.str());
  }
}

index = GpuLower::current()->commonScalarMap().hoistScalar(index, loops);

auto is_multiple_of_16B = SimplifyingIrBuilder::eqExpr(
Logic Change

The function getConsumerToTMAInfoMap has been modified to handle both CpAsyncBulk and CpAsyncBulkTensorTile operations. Review the logic changes to ensure correctness.

std::unordered_map<TensorView*, const TMAInfo> getConsumerToTMAInfoMap(
    Fusion* fusion) {
  std::unordered_map<TensorView*, const TMAInfo> result;
  for (Expr* expr : fusion->exprs()) {
    if (auto ldst = dynamic_cast<LoadStoreOp*>(expr)) {
      if (ldst->opType() == LoadStoreOpType::CpAsyncBulkTensorTile ||
          ldst->opType() == LoadStoreOpType::CpAsyncBulk) {
        NVF_ERROR(
            result.emplace(ir_utils::getTvOutput(ldst), getTMAInfo(ldst))
                .second,
            "Ambiguous TMA information, likely something is wrong in the Fusion IR");
      }
    }
  }
  return result;
}

liqiangxl (Collaborator, Author)

!test

liqiangxl (Collaborator, Author)

!test

} else {
ids_to_index.push_back(group->front()->as<IterDomain>());
NVF_ERROR(true, "S2G not implemented yet.")
Review comment from a Collaborator:
nit: NVF_THROW

liqiangxl (Collaborator, Author)

!build

liqiangxl merged commit 8ea30c7 into main on Jan 23, 2025
18 checks passed
liqiangxl deleted the llu/1dtma branch on January 23, 2025 at 20:59