Move contractions tiling logic outside of Contractions_NT #837

Merged
Changes from all commits (20 commits)
15 changes: 13 additions & 2 deletions cpp/CMakeLists.txt
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
@@ -284,7 +284,18 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/cluster/update_centroids_double.cu
src/distance/cluster/cluster_cost_float.cu
src/distance/cluster/cluster_cost_double.cu
src/distance/neighbors/refine.cu
src/distance/neighbors/refine_d_uint64_t_float.cu
src/distance/neighbors/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/refine_h_uint64_t_float.cu
src/distance/neighbors/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_float.cu
src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_float.cu
src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/ivfpq_search.cu
src/distance/cluster/kmeans_fit_float.cu
src/distance/cluster/kmeans_fit_double.cu
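Each of the new refine sources compiles one concrete data-type / index-type combination (device or host memory × float / int8_t / uint8_t, with uint64_t indices) into the distance library, so consumers link precompiled symbols instead of re-instantiating the templates themselves. Below is a minimal sketch of that per-type instantiation pattern; the refine_impl template is a hypothetical stand-in, not RAFT's actual refine signature.

// Illustrative sketch of the pattern behind files such as
// src/distance/neighbors/refine_d_uint64_t_float.cu (hypothetical signature).
#include <cstdint>

template <typename DataT, typename IdxT>
void refine_impl(const DataT* dataset,
                 const DataT* queries,
                 const IdxT* candidates,
                 IdxT* indices,
                 DataT* distances,
                 IdxT n_queries,
                 IdxT k)
{
  // The real kernel launch lives in a detail header; omitted in this sketch.
  (void)dataset; (void)queries; (void)candidates;
  (void)indices; (void)distances; (void)n_queries; (void)k;
}

// Explicit instantiation: emit the float / uint64_t variant into this
// translation unit once, so targets linking the library do not pay the
// template instantiation cost again.
template void refine_impl<float, std::uint64_t>(const float*,
                                                const float*,
                                                const std::uint64_t*,
                                                std::uint64_t*,
                                                float*,
                                                std::uint64_t,
                                                std::uint64_t);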
10 changes: 5 additions & 5 deletions cpp/bench/distance/distance_common.cuh
@@ -24,14 +24,14 @@

namespace raft::bench::distance {

struct distance_inputs {
struct distance_params {
int m, n, k;
bool isRowMajor;
}; // struct distance_inputs
}; // struct distance_params

template <typename T, raft::distance::DistanceType DType>
struct distance : public fixture {
distance(const distance_inputs& p)
distance(const distance_params& p)
: params(p),
x(p.m * p.k, stream),
y(p.n * p.k, stream),
@@ -63,13 +63,13 @@ struct distance : public fixture {
}

private:
distance_inputs params;
distance_params params;
rmm::device_uvector<T> x, y, out;
rmm::device_uvector<char> workspace;
size_t worksize;
}; // struct Distance

const std::vector<distance_inputs> dist_input_vecs{
const std::vector<distance_params> dist_input_vecs{
{32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true},
{256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true},
{16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true},
3 changes: 3 additions & 0 deletions cpp/bench/neighbors/knn.cuh
@@ -32,6 +32,9 @@
#include <raft/spatial/knn/specializations.cuh>
#if defined RAFT_DISTANCE_COMPILED
#include <raft/cluster/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#else
#pragma message("NN / Distance specializations are not enabled; expect very long building times.")
#endif
#endif

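The added #pragma message makes builds that lack the precompiled specializations warn loudly instead of silently falling into very long compile times. A sketch of the full guard structure follows; the outer branch is truncated in the diff above, so treating it as RAFT_NN_COMPILED is an assumption here.

// Sketch of the two-level specialization guard (outer branch assumed).
#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.cuh>
#if defined RAFT_DISTANCE_COMPILED
// Pull in the precompiled template instantiations from the distance library.
#include <raft/cluster/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#else
// Fall back to header-only templates: every kernel is instantiated in this
// translation unit, which is why the warning mentions long build times.
#pragma message("NN / Distance specializations are not enabled; expect very long building times.")
#endif
#endif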
23 changes: 12 additions & 11 deletions cpp/bench/neighbors/refine.cu
@@ -27,6 +27,7 @@

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#endif

#if defined RAFT_NN_COMPILED
@@ -52,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs<IdxT>& p) -> std::os
return os;
}

RefineInputs<int64_t> p;
RefineInputs<uint64_t> p;

template <typename DataT, typename DistanceT, typename IdxT>
class RefineAnn : public fixture {
@@ -98,24 +99,24 @@ class RefineAnn : public fixture {
RefineHelper<DataT, DistanceT, IdxT> data;
};

std::vector<RefineInputs<int64_t>> getInputs()
std::vector<RefineInputs<uint64_t>> getInputs()
{
std::vector<RefineInputs<int64_t>> out;
std::vector<RefineInputs<uint64_t>> out;
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded;
for (bool host_data : {true, false}) {
for (int64_t n_queries : {1000, 10000}) {
for (int64_t dim : {128, 512}) {
out.push_back(RefineInputs<int64_t>{n_queries, 2000000, dim, 32, 128, metric, host_data});
out.push_back(RefineInputs<int64_t>{n_queries, 2000000, dim, 10, 40, metric, host_data});
for (uint64_t n_queries : {1000, 10000}) {
for (uint64_t dim : {128, 512}) {
out.push_back(RefineInputs<uint64_t>{n_queries, 2000000, dim, 32, 128, metric, host_data});
out.push_back(RefineInputs<uint64_t>{n_queries, 2000000, dim, 10, 40, metric, host_data});
}
}
}
return out;
}

using refine_float_int64 = RefineAnn<float, float, int64_t>;
RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs());
using refine_float_uint64 = RefineAnn<float, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs());

using refine_uint8_int64 = RefineAnn<uint8_t, float, int64_t>;
RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs());
using refine_uint8_uint64 = RefineAnn<uint8_t, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs());
} // namespace raft::bench::neighbors
207 changes: 96 additions & 111 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -59,6 +59,7 @@ namespace detail {
* @param core_op the core accumulation operation lambda
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/

template <bool useNorms,
@@ -87,6 +88,11 @@ struct PairwiseDistances : public BaseClass {
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;

const IdxT grid_stride_m;
const IdxT grid_stride_n;
const IdxT grid_offset_m;
const IdxT grid_offset_n;

AccT acc[P::AccRowsPerTh][P::AccColsPerTh];

public:
@@ -116,96 +122,83 @@ struct PairwiseDistances : public BaseClass {
core_op(_core_op),
epilog_op(_epilog_op),
fin_op(_fin_op),
rowEpilog_op(_rowEpilog_op)
rowEpilog_op(_rowEpilog_op),
grid_stride_m(P::Mblk * gridDim.y),
grid_stride_n(P::Nblk * gridDim.x),
grid_offset_m(P::Mblk * blockIdx.y),
grid_offset_n(P::Nblk * blockIdx.x)
{
}

DI void run()
{
for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m;
gridStrideY += P::Mblk * gridDim.y) {
for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n;
gridStrideX += P::Nblk * gridDim.x) {
prolog(gridStrideX, gridStrideY);
loop();
epilog(gridStrideX, gridStrideY);
for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) {
this->ldgXY(tile_idx_m, grid_offset_n, 0);
for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) {
// Prolog:
reset_accumulator();
this->stsXY();
__syncthreads();
this->switch_write_buffer();

// Main loop:
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
// accumulate in registers.
accumulate();
this->stsXY();
__syncthreads();
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
// The pre-condition for the loop over tile_idx_n is that write_buffer
// and read_buffer point to the same buffer. This flips read_buffer back
// so that it satisfies the pre-condition of this loop.
this->switch_read_buffer();

// Epilog:
if (useNorms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
}
rowEpilog_op(gridStrideY);
rowEpilog_op(tile_idx_m);
}
}

private:
DI void updateIndicesY()
{
const auto stride = P::Nblk * gridDim.x;
if (isRowMajor) {
this->y += stride * this->ldb;
} else {
this->y += stride;
}
this->yrowid += stride;
}

DI void updateIndicesXY()
{
const auto stride = P::Mblk * gridDim.y;
if (isRowMajor) {
this->x += stride * this->lda;
this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid;
this->y = yBase + this->yrowid * this->ldb;
} else {
this->x += stride;
this->yrowid = IdxT(blockIdx.x) * P::Nblk;
this->y = yBase + this->yrowid + this->srowid * this->ldb;
}
this->xrowid += stride;
}

DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY)
DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n)
{
// Fetch next grid stride ldg if within range
if ((gridStrideX + gridDim.x * P::Nblk) < this->n) {
updateIndicesY();
this->ldgXY(0);
} else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) {
updateIndicesXY();
this->ldgXY(0);
const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n;
const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m;
if ((next_tile_tile_idx_n) < this->n) {
this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0);
} else if ((next_tile_tile_idx_m) < this->m) {
this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0);
}
}

DI void prolog(IdxT gridStrideX, IdxT gridStrideY)
DI void reset_accumulator()
{
if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); }

// Reset accumulator registers to zero.
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
}
}

this->stsXY();
__syncthreads();
this->pageWr ^= 1;
}

DI void loop()
{
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(kidx);
accumulate(); // on the previous k-block
this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->pageRd ^= 1;
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
// non-norm based metrics uses previously accumulated buffer so
// it doesn't make shmem dirty until previous iteration
// is complete.
this->pageRd ^= 1;
}

DI void accumulate()
@@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void epilog(IdxT gridStrideX, IdxT gridStrideY)
DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
DataT (&regyn)[P::AccColsPerTh])
{
if (useNorms) {
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (gridStrideX == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = gridStrideY + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = gridStrideX + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (tile_idx_n == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = tile_idx_m + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

__syncthreads();
for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = tile_idx_n + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
}
__syncthreads();

DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}

// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY);
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}
}

if (writeOut) {
IdxT starty = gridStrideY + this->accrowid;
IdxT startx = gridStrideX + this->acccolid;
DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n)
{
IdxT starty = tile_idx_m + this->accrowid;
IdxT startx = tile_idx_n + this->acccolid;

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
}
}
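The restructured run() above replaces the old prolog()/loop()/epilog() trio with explicit tile indices (grid_offset_m/n and grid_stride_m/n computed once in the constructor) and small named helpers (reset_accumulator, accumulate, load_norms, store_output). A host-side sketch of just that control flow and the read/write shared-memory buffer invariant follows; the policy types, ldgXY/stsXY helpers, and the distance math itself are stubbed out as assumptions.

// Host-side model of the device-side tiling in PairwiseDistances::run().
// It only mirrors the tile iteration order and the buffer flips, not the
// actual global loads, shared-memory stores, or accumulation.
#include <cstdio>

struct TileLoop {
  int m, n, k;                 // problem size
  int Mblk, Nblk, Kblk;        // tile sizes (P::Mblk, P::Nblk, P::Kblk)
  int grid_dim_x, grid_dim_y;  // gridDim.x / gridDim.y
  int block_x, block_y;        // blockIdx.x / blockIdx.y
  int read_buf = 0, write_buf = 0;

  void switch_write_buffer() { write_buf ^= 1; }
  void switch_read_buffer() { read_buf ^= 1; }

  void run()
  {
    const int grid_stride_m = Mblk * grid_dim_y;
    const int grid_stride_n = Nblk * grid_dim_x;
    const int grid_offset_m = Mblk * block_y;
    const int grid_offset_n = Nblk * block_x;

    for (int tile_m = grid_offset_m; tile_m < m; tile_m += grid_stride_m) {
      // ldgXY(tile_m, grid_offset_n, 0): fetch the first k-block of this row.
      for (int tile_n = grid_offset_n; tile_n < n; tile_n += grid_stride_n) {
        // Prolog: reset_accumulator(), stage the fetched k-block in shared
        // memory (stsXY + __syncthreads), then flip the write buffer.
        switch_write_buffer();

        // Main loop: prefetch the next k-block (ldgXY) while accumulating the
        // one already staged, then flip both buffers.
        for (int kidx = Kblk; kidx < k; kidx += Kblk) {
          switch_write_buffer();
          switch_read_buffer();
        }
        // accumulate() once more for the last k-block, then flip read_buf
        // back so the next tile_n iteration starts with read_buf == write_buf
        // (the loop pre-condition noted in the comment above).
        switch_read_buffer();

        // Epilog: load_norms() if useNorms, ldgNextGridStride() to overlap
        // the next fetch with epilog_op, then store_output() if writeOut.
        std::printf("tile (%d, %d) done, buffers in sync: %d\n",
                    tile_m, tile_n, read_buf == write_buf);
      }
      // rowEpilog_op(tile_m) runs once per row of tiles.
    }
  }
};

int main()
{
  // Small made-up sizes: a 2x2 "grid" of blocks, 64x64x32 tiles.
  TileLoop t{256, 192, 96, 64, 64, 32, 2, 2, 0, 0};
  t.run();
  return 0;
}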