Skip to content

Commit

Permalink
contractions: Concentrate tile index calculations
Browse files Browse the repository at this point in the history
The calculation of the tile indices is now performed in ldgXY(). This
will make it possible to remove all state related to the tile index from
the class in the next commit.

Note that the calculation of the tile index can depend on which
overloaded constructor is called(!)
  • Loading branch information
ahendriksen authored and Allard Hendriksen committed Jan 24, 2023
1 parent 0076101 commit a3b8587
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 49 deletions.
27 changes: 7 additions & 20 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,27 +138,14 @@ struct PairwiseDistances : public BaseClass {
/**
 * @brief Advance the Y-side indices by one grid stride along N.
 *
 * The stride is P::Nblk * gridDim.x.
 *
 * NOTE(review): this listing looks like a diff rendered without +/- markers,
 * so the direct updates of `y`/`yrowid` and the increment_grid_idx_n() call
 * may be the before/after forms of the same step rather than statements that
 * coexist — confirm against the checked-in header.
 */
DI void updateIndicesY()
{
const auto stride = P::Nblk * gridDim.x;
// Row-major: the Y pointer moves by `stride` multiples of ldb;
// otherwise it moves by `stride` elements.
if (isRowMajor) {
this->y += stride * this->ldb;
} else {
this->y += stride;
}
this->yrowid += stride;
// Accumulate the same stride into the base class' grid_idx_n offset.
this->increment_grid_idx_n(stride);
}

/**
 * @brief Advance the X-side indices by one grid stride along M and rewind the
 *        Y-side indices to the tile selected by blockIdx.x.
 *
 * The stride is P::Mblk * gridDim.y.
 *
 * NOTE(review): this listing looks like a diff rendered without +/- markers,
 * so the direct updates of `x`/`y`/`xrowid`/`yrowid` and the
 * increment_grid_idx_m()/reset_grid_idx_n() calls may be the before/after
 * forms of the same step rather than statements that coexist — confirm
 * against the checked-in header.
 */
DI void updateIndicesXY()
{
const auto stride = P::Mblk * gridDim.y;
if (isRowMajor) {
// Row-major: X moves by `stride` multiples of lda; Y is recomputed from
// yBase using this thread's smem row id as part of the row index.
this->x += stride * this->lda;
this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid;
this->y = yBase + this->yrowid * this->ldb;
} else {
// Otherwise: X moves by `stride` elements; srowid scales with ldb instead
// of being folded into yrowid.
this->x += stride;
this->yrowid = IdxT(blockIdx.x) * P::Nblk;
this->y = yBase + this->yrowid + this->srowid * this->ldb;
}
this->xrowid += stride;
// Accumulate the M stride and zero the N offset kept by the base class.
this->increment_grid_idx_m(stride);
this->reset_grid_idx_n();
}

DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY)
Expand Down Expand Up @@ -187,7 +174,7 @@ struct PairwiseDistances : public BaseClass {

this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->switch_write_buffer();
}

DI void loop()
Expand All @@ -197,15 +184,15 @@ struct PairwiseDistances : public BaseClass {
accumulate(); // on the previous k-block
this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->pageRd ^= 1;
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
// non-norm based metrics uses previously accumulated buffer so
// it doesn't make shmem dirty until previous iteration
// is complete.
this->pageRd ^= 1;
this->switch_read_buffer();
}

DI void accumulate()
Expand Down
84 changes: 60 additions & 24 deletions cpp/include/raft/linalg/detail/contractions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ struct Contractions_NT {
/** leading dimension in Output D */
IdxT ldd;

/** current thread's global mem row id for X data */
IdxT xrowid;
/** current thread's global mem row id for Y data */
IdxT yrowid;
/** global memory pointer to X matrix */
const DataT* x;
const DataT* x_base;
/** global memory pointer to Y matrix */
const DataT* y;
const DataT* y_base;

/** Support variables to provide backward compatibility **/
IdxT grid_idx_m = 0;
IdxT grid_idx_n = 0;
bool first_constructor_called;

/** current thread's smem row id */
int srowid;
Expand Down Expand Up @@ -94,18 +95,17 @@ struct Contractions_NT {
k(_k),
lda(_k),
ldb(_k),
xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow),
yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow),
x(_x + xrowid * lda),
y(_y + yrowid * ldb),
x_base(_x),
y_base(_y),
srowid(threadIdx.x / P::LdgThRow),
scolid((threadIdx.x % P::LdgThRow) * P::Veclen),
accrowid(threadIdx.x / P::AccThCols),
acccolid(threadIdx.x % P::AccThCols),
sx((DataT*)_smem),
sy(&(sx[P::SmemPageX])),
pageWr(0),
pageRd(0)
pageRd(0),
first_constructor_called(true)
{
}

Expand Down Expand Up @@ -133,26 +133,18 @@ struct Contractions_NT {
lda(_lda),
ldb(_ldb),
ldd(_ldd),
x_base(_x),
y_base(_y),
srowid(threadIdx.x / P::LdgThRow),
scolid((threadIdx.x % P::LdgThRow) * P::Veclen),
accrowid(threadIdx.x / P::AccThCols),
acccolid(threadIdx.x % P::AccThCols),
sx((DataT*)_smem),
sy(&(sx[P::SmemPageX])),
pageWr(0),
pageRd(0)
pageRd(0),
first_constructor_called(false)
{
if (isRowMajor) {
xrowid = IdxT(blockIdx.y) * P::Mblk + srowid;
yrowid = IdxT(blockIdx.x) * P::Nblk + srowid;
x = _x + xrowid * lda;
y = _y + yrowid * ldb;
} else {
xrowid = IdxT(blockIdx.y) * P::Mblk;
yrowid = IdxT(blockIdx.x) * P::Nblk;
x = _x + xrowid + srowid * lda;
y = _y + yrowid + srowid * ldb;
}
}

protected:
Expand All @@ -166,6 +158,12 @@ struct Contractions_NT {
ldgY(kidx);
}

/**
 * @brief Forward explicit tile indices to ldgX() and ldgY()
 * @param[in] tile_idx_m tile index passed on to ldgX()
 * @param[in] tile_idx_n tile index passed on to ldgY()
 * @param[in] kidx k index passed on to both calls
 */
DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx)
{
  // X first, then Y: same call order as the single-argument overload.
  this->ldgX(tile_idx_m, kidx);
  this->ldgY(tile_idx_n, kidx);
}

/**
* @brief Store current block of X/Y from registers to smem
* @param[in] kidx current start index of k to be loaded
Expand All @@ -186,9 +184,35 @@ struct Contractions_NT {
ldsY(kidx, sy + pageRd * P::SmemPage);
}

/**
 * @brief Add `by` to grid_idx_m, the offset that ldgX(kidx) adds to
 *        blockIdx.y * P::Mblk when the second constructor was used.
 */
DI void increment_grid_idx_m(IdxT by)
{
  this->grid_idx_m += by;
}

/**
 * @brief Add `by` to grid_idx_n, the offset that ldgY(kidx) adds to
 *        blockIdx.x * P::Nblk when the second constructor was used.
 */
DI void increment_grid_idx_n(IdxT by)
{
  this->grid_idx_n += by;
}

/** @brief Set the grid_idx_n offset back to zero. */
DI void reset_grid_idx_n()
{
  this->grid_idx_n = 0;
}

/**
 * @brief Toggle pageRd between 0 and 1.
 *
 * pageRd selects which smem page ldsXY() reads from.
 */
DI void switch_read_buffer()
{
  pageRd ^= 1;
}

/**
 * @brief Toggle pageWr between 0 and 1.
 *
 * NOTE(review): presumably pageWr selects the smem page that stsXY() writes
 * to — confirm against stsX()/stsY().
 */
DI void switch_write_buffer()
{
  pageWr ^= 1;
}

private:
/**
 * @brief Backward-compatible entry point: derive the M tile index and forward
 *        to ldgX(tile_idx_m, kidx)
 * @param[in] kidx k index passed on unchanged
 */
DI void ldgX(IdxT kidx)
{
  // The tile index depends on which constructor built this object:
  //  - first constructor (used by epsilon_neighborhood.cuh): the tile is
  //    selected by blockIdx.x alone;
  //  - second constructor (used by pairwise_distance_base.cuh): the tile is
  //    selected by blockIdx.y plus the accumulated grid_idx_m offset.
  const IdxT tile_idx_m = first_constructor_called
                            ? IdxT(blockIdx.x) * P::Mblk
                            : grid_idx_m + IdxT(blockIdx.y) * P::Mblk;
  ldgX(tile_idx_m, kidx);
}

DI void ldgX(IdxT tile_idx_m, IdxT kidx)
{
IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m;
auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda;

if (isRowMajor) {
auto numRows = m;
auto koffset = kidx + scolid;
Expand Down Expand Up @@ -222,6 +246,18 @@ struct Contractions_NT {

/**
 * @brief Backward-compatible entry point: derive the N tile index and forward
 *        to ldgY(tile_idx_n, kidx)
 * @param[in] kidx k index passed on unchanged
 */
DI void ldgY(IdxT kidx)
{
  // The tile index depends on which constructor built this object:
  //  - first constructor: the tile is selected by blockIdx.y alone;
  //  - second constructor: the tile is selected by blockIdx.x plus the
  //    accumulated grid_idx_n offset.
  const IdxT tile_idx_n = first_constructor_called
                            ? IdxT(blockIdx.y) * P::Nblk
                            : grid_idx_n + IdxT(blockIdx.x) * P::Nblk;
  ldgY(tile_idx_n, kidx);
}

DI void ldgY(IdxT tile_idx_n, IdxT kidx)
{
IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n;
auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb;

if (isRowMajor) {
auto numRows = n;
auto koffset = kidx + scolid;
Expand Down Expand Up @@ -315,4 +351,4 @@ struct Contractions_NT {

} // namespace detail
} // namespace linalg
} // namespace raft
} // namespace raft
10 changes: 5 additions & 5 deletions cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
private:
DI void prolog()
{
this->ldgXY(0);
this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0);
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
Expand All @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
}
this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->switch_write_buffer();
}

/**
 * @brief Main k-loop: walk kidx from P::Kblk up to k in steps of P::Kblk,
 *        accumulating on the previous k-block while the next one is staged.
 *
 * NOTE(review): this listing looks like a diff rendered without +/- markers,
 * so the two ldgXY() calls and the `page* ^= 1` / switch_*_buffer() pairs may
 * be the before/after forms of the same step rather than statements that
 * coexist — confirm against the checked-in header.
 */
DI void loop()
{
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(kidx);
// Explicit-tile form: M tile from blockIdx.x, N tile from blockIdx.y.
this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx);
accumulate(); // on the previous k-block
this->stsXY();
// Block-wide barrier after stsXY() and before the page indices are flipped.
__syncthreads();
this->pageWr ^= 1;
this->pageRd ^= 1;
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
}
Expand Down

0 comments on commit a3b8587

Please sign in to comment.