Skip to content

Commit

Permalink
Add column major input support in contractions_nt kernels with new ke…
Browse files Browse the repository at this point in the history
…rnel policy for it (#188)

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Thejaswi. N. S (https://github.com/teju85)
  - Divye Gala (https://github.com/divyegala)

URL: rapidsai/raft#188
  • Loading branch information
mdoijade authored Apr 5, 2021
1 parent 91844e2 commit f0cd81f
Showing 1 changed file with 184 additions and 31 deletions.
215 changes: 184 additions & 31 deletions cpp/include/raft/linalg/contractions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ struct KernelPolicy {
/** output tile size along cols */
Nblk = AccColsPerTh * AccThCols,
/** number of threads loading a single row */
LdgThK = Kblk / Veclen,
LdgThRow = Kblk / Veclen,
/** number of LDGs issued by a single thread for X */
LdgPerThX = Mblk * LdgThK / Nthreads,
LdgPerThX = Mblk * LdgThRow / Nthreads,
/** number of LDGs issued by a single thread for Y */
LdgPerThY = Nblk * LdgThK / Nthreads,
LdgPerThY = Nblk * LdgThRow / Nthreads,
/** number of rows of X covered per LDG */
LdgRowsX = Mblk / LdgPerThX,
/** number of rows of Y covered per LDG */
Expand All @@ -98,8 +98,54 @@ struct KernelPolicy {
/** size (in B) for smem needed */
SmemSize = 2 * SmemPage * sizeof(DataT),
}; // enum
}; // struct KernelPolicy

}; // struct KernelPolicy

template <typename DataT, int _veclen, int _kblk, int _rpt, int _cpt, int _tr,
int _tc>
struct ColKernelPolicy {
enum {
/** number of elements along K worked upon per main loop iteration */
Kblk = _kblk,
/** number of elements loaded per LDG */
Veclen = _veclen,
/** number of rows a thread works on for accumulation */
AccRowsPerTh = _rpt,
/** number of cols a thread works on for accumulation */
AccColsPerTh = _cpt,
/** number of threads working the same output col */
AccThRows = _tr,
/** number of threads working the same output row */
AccThCols = _tc,
/** total threads per block */
Nthreads = AccThRows * AccThCols,
/** output tile size along rows */
Mblk = AccRowsPerTh * AccThRows,
/** output tile size along cols */
Nblk = AccColsPerTh * AccThCols,
/** number of threads loading a single col */
LdgThRow = Mblk / Veclen,
/** number of LDGs issued by a single thread for X */
LdgPerThX = Kblk * LdgThRow / Nthreads,
/** number of LDGs issued by a single thread for Y */
LdgPerThY = Kblk * LdgThRow / Nthreads,
/** number of rows of X covered per LDG */
LdgRowsX = Kblk / LdgPerThX,
/** number of rows of Y covered per LDG */
LdgRowsY = Kblk / LdgPerThY,
/** stride for accessing X/Y data in shared mem */
SmemStride = Mblk + Veclen,
/** size of one page for storing X data */
SmemPageX = SmemStride * Kblk,
/** size of one page for storing Y data */
SmemPageY = SmemStride * Kblk,
/** size of one smem page */
SmemPage = SmemPageX + SmemPageY,
/** size (in B) for smem needed */
SmemSize = 2 * SmemPage * sizeof(DataT),
}; // colMajor enum
static_assert(Mblk == Nblk, "Mblk should be equal to Nblk");
};
/**
* @defgroup Policy4x4 16 elements per thread Policy with k-block = 32
* @{
Expand All @@ -110,11 +156,13 @@ struct Policy4x4 {};
template <int _veclen>
struct Policy4x4<float, _veclen> {
typedef KernelPolicy<float, _veclen, 32, 4, 4, 16, 16> Policy;
typedef ColKernelPolicy<float, _veclen, 32, 4, 4, 16, 16> ColPolicy;
};

template <int _veclen>
struct Policy4x4<double, _veclen> {
typedef KernelPolicy<double, _veclen, 16, 4, 4, 16, 16> Policy;
typedef ColKernelPolicy<double, _veclen, 16, 4, 4, 16, 16> ColPolicy;
};
/** @} */

Expand All @@ -132,7 +180,8 @@ struct Policy4x4<double, _veclen> {
* @tparam Policy policy used to customize memory access behavior.
* See documentation for `KernelPolicy` to know more.
*/
template <typename DataT, typename IdxT, typename Policy>
template <typename DataT, typename IdxT, typename Policy,
bool isRowMajor = true>
struct Contractions_NT {
protected:
typedef Policy P;
Expand All @@ -143,6 +192,13 @@ struct Contractions_NT {
IdxT n;
/** number of columns in X and Y */
IdxT k;
/** leading dimension in X */
IdxT lda;
/** leading dimension in Y */
IdxT ldb;
/** leading dimension in Output D */
IdxT ldd;

/** current thread's global mem row id for X data */
IdxT xrowid;
/** current thread's global mem row id for Y data */
Expand Down Expand Up @@ -196,19 +252,59 @@ struct Contractions_NT {
: m(_m),
n(_n),
k(_k),
xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThK),
yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThK),
x(_x + xrowid * k),
y(_y + yrowid * k),
srowid(threadIdx.x / P::LdgThK),
scolid((threadIdx.x % P::LdgThK) * P::Veclen),
lda(_k),
ldb(_k),
xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow),
yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow),
x(_x + xrowid * lda),
y(_y + yrowid * ldb),
srowid(threadIdx.x / P::LdgThRow),
scolid((threadIdx.x % P::LdgThRow) * P::Veclen),
accrowid(threadIdx.x / P::AccThCols),
acccolid(threadIdx.x % P::AccThCols),
sx((DataT*)_smem),
sy(&(sx[P::SmemPageX])),
pageWr(0),
pageRd(0) {}

/**
* @brief Ctor
* @param[in] _x X matrix. [on device] [dim = _m x _k] [row-major]
* @param[in] _y Y matrix. [on device] [dim = _n x _k] [row-major]
* @param[in] _m number of rows of X
* @param[in] _n number of rows of Y
* @param[in] _k number of cols of X and Y
* @param[in] _smem shared memory region used during computations
*/
DI Contractions_NT(const DataT* _x, const DataT* _y, IdxT _m, IdxT _n,
IdxT _k, IdxT _lda, IdxT _ldb, IdxT _ldd, char* _smem)
: m(_m),
n(_n),
k(_k),
lda(_lda),
ldb(_ldb),
ldd(_ldd),
srowid(threadIdx.x / P::LdgThRow),
scolid((threadIdx.x % P::LdgThRow) * P::Veclen),
accrowid(threadIdx.x / P::AccThCols),
acccolid(threadIdx.x % P::AccThCols),
sx((DataT*)_smem),
sy(&(sx[P::SmemPageX])),
pageWr(0),
pageRd(0) {
if (isRowMajor) {
xrowid = IdxT(blockIdx.x) * P::Mblk + srowid;
yrowid = IdxT(blockIdx.y) * P::Nblk + srowid;
x = _x + xrowid * lda;
y = _y + yrowid * ldb;
} else {
xrowid = IdxT(blockIdx.x) * P::Mblk;
yrowid = IdxT(blockIdx.y) * P::Nblk;
x = _x + xrowid + srowid * lda;
y = _y + yrowid + srowid * ldb;
}
}

protected:
/**
* @brief Load current block of X/Y from global memory to registers
Expand Down Expand Up @@ -239,28 +335,62 @@ struct Contractions_NT {

private:
DI void ldgX(IdxT kidx) {
auto koffset = kidx + scolid;
for (int i = 0; i < P::LdgPerThX; ++i) {
if (koffset < k && (xrowid + i * P::LdgRowsX) < m) {
ldg(ldgDataX[i], x + i * P::LdgRowsX * k + koffset);
} else {
if (isRowMajor) {
auto numRows = m;
auto koffset = kidx + scolid;
for (int i = 0; i < P::LdgPerThX; ++i) {
if (koffset < lda && (xrowid + i * P::LdgRowsX) < numRows) {
ldg(ldgDataX[i], x + i * P::LdgRowsX * lda + koffset);
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
}
}
}
} else {
const auto numRows = k;
auto koffset = scolid;
for (int i = 0; i < P::LdgPerThX; ++i) {
if ((koffset + xrowid) < lda &&
(srowid + kidx + i * P::LdgRowsX) < numRows) {
ldg(ldgDataX[i], x + (kidx + i * P::LdgRowsX) * lda + koffset);
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
}
}
}
}
}

DI void ldgY(IdxT kidx) {
auto koffset = kidx + scolid;
for (int i = 0; i < P::LdgPerThY; ++i) {
if (koffset < k && (yrowid + i * P::LdgRowsY) < n) {
ldg(ldgDataY[i], y + i * P::LdgRowsY * k + koffset);
} else {
if (isRowMajor) {
auto numRows = n;
auto koffset = kidx + scolid;
for (int i = 0; i < P::LdgPerThY; ++i) {
if (koffset < ldb && (yrowid + i * P::LdgRowsY) < numRows) {
ldg(ldgDataY[i], y + i * P::LdgRowsY * ldb + koffset);
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
}
}
}
} else {
auto numRows = k;
auto koffset = scolid;
for (int i = 0; i < P::LdgPerThY; ++i) {
if ((koffset + yrowid) < ldb &&
(srowid + kidx + i * P::LdgRowsY) < numRows) {
ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset);
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
}
}
}
}
Expand All @@ -283,20 +413,43 @@ struct Contractions_NT {
}

DI void ldsX(int kidx, DataT* smem) {
auto* saddr = smem + accrowid * P::SmemStride + kidx;
if (isRowMajor) {
auto* saddr = smem + accrowid * P::SmemStride + kidx;
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
lds(regx[i], saddr + i * P::AccThRows * P::SmemStride);
}
} else {
auto* saddr = smem + accrowid + kidx * P::SmemStride;
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
lds(regx[i], saddr + i * P::AccThRows * P::SmemStride);
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int v = 0; v < P::Veclen; ++v) {
regx[i][v] = saddr[i * P::AccThRows + v * P::SmemStride];
}
}
}
}

DI void ldsY(int kidx, DataT* smem) {
auto* saddr = smem + acccolid * P::SmemStride + kidx;
if (isRowMajor) {
auto* saddr = smem + acccolid * P::SmemStride + kidx;
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
lds(regy[i], saddr + i * P::AccThCols * P::SmemStride);
}
} else {
auto* saddr = smem + acccolid + kidx * P::SmemStride;
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
lds(regy[i], saddr + i * P::AccThCols * P::SmemStride);
for (int i = 0; i < P::AccColsPerTh; ++i) {
#pragma unroll
for (int v = 0; v < P::Veclen; ++v) {
regy[i][v] = saddr[i * P::AccThCols + v * P::SmemStride];
}
}
}
}

}; // struct Contractions_NT

} // namespace linalg
Expand Down

0 comments on commit f0cd81f

Please sign in to comment.