Skip to content

Commit

Permalink
Move contractions tiling logic outside of Contractions_NT (rapidsai#837)
Browse files Browse the repository at this point in the history
The main functionality of Contractions_NT involves loading tiles of data into shared memory to enable fast GEMM-like kernels. In practice, this requires keeping track of tiles of data (2D submatrices of a bigger matrix) and distributing the data in the tiles over shared memory and registers of thread in a thread block.

Currently, Contractions_NT performs indexing logic for both: 
1. The distribution of data in a tile over registers and shared memory; 
2. Looping over tiles of data in a 2D matrix.

In this PR, we move functionality 2 out of Contractions_NT. Moving over the tiles of data and keeping track of the grid stride loop is now the responsibility of the calling code. 

Splitting these responsibilities is helpful when non-trivial tiling logic is required, as in the upcoming sparseL2NN functionality.

**Note**: This PR also cleans up one unfortunate wart in the current implementation. Depending on which of the two overloaded constructors was called, the tiling logic was transposed leading to extremely difficult to track down bugs.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#837
  • Loading branch information
Allard Hendriksen authored Jan 27, 2023
1 parent afece4f commit c58d00a
Show file tree
Hide file tree
Showing 23 changed files with 605 additions and 230 deletions.
15 changes: 13 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down Expand Up @@ -284,7 +284,18 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/cluster/update_centroids_double.cu
src/distance/cluster/cluster_cost_float.cu
src/distance/cluster/cluster_cost_double.cu
src/distance/neighbors/refine.cu
src/distance/neighbors/refine_d_uint64_t_float.cu
src/distance/neighbors/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/refine_h_uint64_t_float.cu
src/distance/neighbors/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_float.cu
src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_float.cu
src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/ivfpq_search.cu
src/distance/cluster/kmeans_fit_float.cu
src/distance/cluster/kmeans_fit_double.cu
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/distance/distance_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

namespace raft::bench::distance {

struct distance_inputs {
struct distance_params {
int m, n, k;
bool isRowMajor;
}; // struct distance_inputs
}; // struct distance_params

template <typename T, raft::distance::DistanceType DType>
struct distance : public fixture {
distance(const distance_inputs& p)
distance(const distance_params& p)
: params(p),
x(p.m * p.k, stream),
y(p.n * p.k, stream),
Expand Down Expand Up @@ -63,13 +63,13 @@ struct distance : public fixture {
}

private:
distance_inputs params;
distance_params params;
rmm::device_uvector<T> x, y, out;
rmm::device_uvector<char> workspace;
size_t worksize;
}; // struct Distance

const std::vector<distance_inputs> dist_input_vecs{
const std::vector<distance_params> dist_input_vecs{
{32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true},
{256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true},
{16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true},
Expand Down
3 changes: 3 additions & 0 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include <raft/spatial/knn/specializations.cuh>
#if defined RAFT_DISTANCE_COMPILED
#include <raft/cluster/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#else
#pragma message("NN / Distance specializations are not enabled; expect very long building times.")
#endif
#endif

Expand Down
23 changes: 12 additions & 11 deletions cpp/bench/neighbors/refine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#endif

#if defined RAFT_NN_COMPILED
Expand All @@ -52,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs<IdxT>& p) -> std::os
return os;
}

RefineInputs<int64_t> p;
RefineInputs<uint64_t> p;

template <typename DataT, typename DistanceT, typename IdxT>
class RefineAnn : public fixture {
Expand Down Expand Up @@ -98,24 +99,24 @@ class RefineAnn : public fixture {
RefineHelper<DataT, DistanceT, IdxT> data;
};

std::vector<RefineInputs<int64_t>> getInputs()
std::vector<RefineInputs<uint64_t>> getInputs()
{
std::vector<RefineInputs<int64_t>> out;
std::vector<RefineInputs<uint64_t>> out;
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded;
for (bool host_data : {true, false}) {
for (int64_t n_queries : {1000, 10000}) {
for (int64_t dim : {128, 512}) {
out.push_back(RefineInputs<int64_t>{n_queries, 2000000, dim, 32, 128, metric, host_data});
out.push_back(RefineInputs<int64_t>{n_queries, 2000000, dim, 10, 40, metric, host_data});
for (uint64_t n_queries : {1000, 10000}) {
for (uint64_t dim : {128, 512}) {
out.push_back(RefineInputs<uint64_t>{n_queries, 2000000, dim, 32, 128, metric, host_data});
out.push_back(RefineInputs<uint64_t>{n_queries, 2000000, dim, 10, 40, metric, host_data});
}
}
}
return out;
}

using refine_float_int64 = RefineAnn<float, float, int64_t>;
RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs());
using refine_float_uint64 = RefineAnn<float, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs());

using refine_uint8_int64 = RefineAnn<uint8_t, float, int64_t>;
RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs());
using refine_uint8_uint64 = RefineAnn<uint8_t, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs());
} // namespace raft::bench::neighbors
207 changes: 96 additions & 111 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,6 +59,7 @@ namespace detail {
* @param core_op the core accumulation operation lambda
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/

template <bool useNorms,
Expand Down Expand Up @@ -87,6 +88,11 @@ struct PairwiseDistances : public BaseClass {
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;

const IdxT grid_stride_m;
const IdxT grid_stride_n;
const IdxT grid_offset_m;
const IdxT grid_offset_n;

AccT acc[P::AccRowsPerTh][P::AccColsPerTh];

public:
Expand Down Expand Up @@ -116,96 +122,83 @@ struct PairwiseDistances : public BaseClass {
core_op(_core_op),
epilog_op(_epilog_op),
fin_op(_fin_op),
rowEpilog_op(_rowEpilog_op)
rowEpilog_op(_rowEpilog_op),
grid_stride_m(P::Mblk * gridDim.y),
grid_stride_n(P::Nblk * gridDim.x),
grid_offset_m(P::Mblk * blockIdx.y),
grid_offset_n(P::Nblk * blockIdx.x)
{
}

DI void run()
{
for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m;
gridStrideY += P::Mblk * gridDim.y) {
for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n;
gridStrideX += P::Nblk * gridDim.x) {
prolog(gridStrideX, gridStrideY);
loop();
epilog(gridStrideX, gridStrideY);
for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) {
this->ldgXY(tile_idx_m, grid_offset_n, 0);
for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) {
// Prolog:
reset_accumulator();
this->stsXY();
__syncthreads();
this->switch_write_buffer();

// Main loop:
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
// accumulate in registers.
accumulate();
this->stsXY();
__syncthreads();
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
// The pre-condition for the loop over tile_idx_n is that write_buffer
// and read_buffer point to the same buffer. This flips read_buffer back
// so that it satisfies the pre-condition of this loop.
this->switch_read_buffer();

// Epilog:
if (useNorms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
}
rowEpilog_op(gridStrideY);
rowEpilog_op(tile_idx_m);
}
}

private:
DI void updateIndicesY()
{
const auto stride = P::Nblk * gridDim.x;
if (isRowMajor) {
this->y += stride * this->ldb;
} else {
this->y += stride;
}
this->yrowid += stride;
}

DI void updateIndicesXY()
{
const auto stride = P::Mblk * gridDim.y;
if (isRowMajor) {
this->x += stride * this->lda;
this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid;
this->y = yBase + this->yrowid * this->ldb;
} else {
this->x += stride;
this->yrowid = IdxT(blockIdx.x) * P::Nblk;
this->y = yBase + this->yrowid + this->srowid * this->ldb;
}
this->xrowid += stride;
}

DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY)
DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n)
{
// Fetch next grid stride ldg if within range
if ((gridStrideX + gridDim.x * P::Nblk) < this->n) {
updateIndicesY();
this->ldgXY(0);
} else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) {
updateIndicesXY();
this->ldgXY(0);
const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n;
const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m;
if ((next_tile_tile_idx_n) < this->n) {
this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0);
} else if ((next_tile_tile_idx_m) < this->m) {
this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0);
}
}

DI void prolog(IdxT gridStrideX, IdxT gridStrideY)
DI void reset_accumulator()
{
if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); }

// Reset accumulator registers to zero.
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
}
}

this->stsXY();
__syncthreads();
this->pageWr ^= 1;
}

DI void loop()
{
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(kidx);
accumulate(); // on the previous k-block
this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->pageRd ^= 1;
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
// non-norm based metrics uses previously accumulated buffer so
// it doesn't make shmem dirty until previous iteration
// is complete.
this->pageRd ^= 1;
}

DI void accumulate()
Expand All @@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void epilog(IdxT gridStrideX, IdxT gridStrideY)
DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
DataT (&regyn)[P::AccColsPerTh])
{
if (useNorms) {
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (gridStrideX == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = gridStrideY + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = gridStrideX + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (tile_idx_n == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = tile_idx_m + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

__syncthreads();
for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = tile_idx_n + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
}
__syncthreads();

DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}

// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY);
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}
}

if (writeOut) {
IdxT starty = gridStrideY + this->accrowid;
IdxT startx = gridStrideX + this->acccolid;
DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n)
{
IdxT starty = tile_idx_m + this->accrowid;
IdxT startx = tile_idx_n + this->acccolid;

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
}
}
Expand Down
Loading

0 comments on commit c58d00a

Please sign in to comment.