Skip to content

Commit

Permalink
Merge pull request #135 from NVIDIA/cutlass_2.3_final
Browse files Browse the repository at this point in the history
CUTLASS 2.3.0
  • Loading branch information
d-k-b authored Sep 25, 2020
2 parents c53f333 + 37a8f9e commit c2b80ad
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 15 deletions.
16 changes: 8 additions & 8 deletions include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,11 @@ class PredicatedTileIterator {

bool guard = row_guard && mask_.predicates[column];

-cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
-frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
-column],
-(void *)&memory_pointer[column * ThreadMap::Delta::kColumn /
-kElementsPerAccess],
-guard);
+if (guard) {
+
+memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
+frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
+}
}

if (row + 1 < ThreadMap::Iterations::kRow) {
Expand Down Expand Up @@ -691,8 +690,9 @@ class InterleavedPredicatedTileIterator {

bool guard = col_guard && mask_.predicates[iteration_contiguous_];

-cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
-*frag_ptr, (void *)memory_pointer, guard);
+if (guard) {
+*memory_pointer = *frag_ptr;
+}
}

/// Overrides the internal iteration index
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ struct Gemm {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;

//
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ struct GemmArray {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

int lane_idx = threadIdx.x % 32;

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm_batched.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ struct GemmBatched {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

int lane_idx = threadIdx.x % 32;

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm_planar_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ struct GemmPlanarComplex {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

int lane_idx = threadIdx.x % 32;

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm_planar_complex_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ struct GemmPlanarComplexArray {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;

//
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm_universal.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ struct GemmUniversal {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

int lane_idx = threadIdx.x % 32;

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/sparse_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ struct SparseGemm {

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
-int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;

//
Expand Down

0 comments on commit c2b80ad

Please sign in to comment.