diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 2df38d24..1a2aebd1 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -112,18 +112,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb( greater_than_u[j] = inclusive_cdf[j] + aggregate > u; } + bool greater_than_u_diff[VEC_SIZE]; #ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u, BoolDiffOp()); + .SubtractLeft(greater_than_u_diff, greater_than_u, BoolDiffOp()); #else BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u, greater_than_u, BoolDiffOp()); + .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp()); #endif __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - if (greater_than_u[j] && valid[j]) { + if (greater_than_u_diff[j] && valid[j]) { atomicMin(&(temp_storage->data.sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); } }