From e551ca41648db98bff1a3c0f75a64eaa0dec7ae0 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sun, 21 Jul 2024 00:19:12 +0000
Subject: [PATCH] upd

---
Review notes (this section is ignored by git-am):
- The hunk stores the adjacent-difference flags in a new local array,
  greater_than_u_diff, instead of overwriting greater_than_u in place, and
  the atomicMin guard now reads that new array.
- SubtractLeft takes (input, output, op) whereas FlagHeads takes
  (head_flags, input, op). The SubtractLeft call is corrected here to pass
  greater_than_u as the input and greater_than_u_diff as the output;
  previously it read the uninitialized greater_than_u_diff as its input.
- Template argument lists and the author address appear to have been
  stripped from this copy of the patch; they are left as found.
  Context-line indentation is reconstructed and should be checked against
  the target file before applying.

 include/flashinfer/sampling.cuh | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 2df38d24..1a2aebd1 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -112,18 +112,23 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
       greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
     }
 
+    // Keep the adjacent-difference result in its own array rather than
+    // overwriting greater_than_u in place.
+    bool greater_than_u_diff[VEC_SIZE];
 #ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED
+    // SubtractLeft argument order is (input, output, op).
     BlockAdjacentDifference(temp_storage->block_prim.adj_diff)
-        .SubtractLeft(greater_than_u, greater_than_u, BoolDiffOp());
+        .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
 #else
+    // FlagHeads argument order is (head_flags output, input, op).
     BlockAdjacentDifference(temp_storage->block_prim.adj_diff)
-        .FlagHeads(greater_than_u, greater_than_u, BoolDiffOp());
+        .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp());
 #endif
     __syncthreads();
 
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      if (greater_than_u[j] && valid[j]) {
+      if (greater_than_u_diff[j] && valid[j]) {
         atomicMin(&(temp_storage->data.sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
       }
     }