diff --git a/cub/block/specializations/block_reduce_warp_reductions.cuh b/cub/block/specializations/block_reduce_warp_reductions.cuh index 4a49b1277a..10ba303b4c 100644 --- a/cub/block/specializations/block_reduce_warp_reductions.cuh +++ b/cub/block/specializations/block_reduce_warp_reductions.cuh @@ -94,9 +94,9 @@ struct BlockReduceWarpReductions // Thread fields _TempStorage &temp_storage; - unsigned int linear_tid; - unsigned int warp_id; - unsigned int lane_id; + int linear_tid; + int warp_id; + int lane_id; /// Constructor @@ -169,13 +169,11 @@ struct BlockReduceWarpReductions T input, ///< [in] Calling thread's input partial reductions int num_valid) ///< [in] Number of valid elements (may be less than BLOCK_THREADS) { - cub::Sum reduction_op; - unsigned int warp_offset = warp_id * LOGICAL_WARP_SIZE; - unsigned int warp_num_valid = (FULL_TILE && EVEN_WARP_MULTIPLE) ? + cub::Sum reduction_op; + int warp_offset = (warp_id * LOGICAL_WARP_SIZE); + int warp_num_valid = ((FULL_TILE && EVEN_WARP_MULTIPLE) || (warp_offset + LOGICAL_WARP_SIZE <= num_valid)) ? LOGICAL_WARP_SIZE : - (warp_offset < num_valid) ? - num_valid - warp_offset : - 0; + num_valid - warp_offset; // Warp reduction in every warp T warp_aggregate = WarpReduce(temp_storage.warp_reduce[warp_id]).template Reduce<(FULL_TILE && EVEN_WARP_MULTIPLE)>( @@ -197,12 +195,10 @@ struct BlockReduceWarpReductions int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS) ReductionOp reduction_op) ///< [in] Binary reduction operator { - unsigned int warp_offset = warp_id * LOGICAL_WARP_SIZE; - unsigned int warp_num_valid = (FULL_TILE && EVEN_WARP_MULTIPLE) ? + int warp_offset = warp_id * LOGICAL_WARP_SIZE; + int warp_num_valid = ((FULL_TILE && EVEN_WARP_MULTIPLE) || (warp_offset + LOGICAL_WARP_SIZE <= num_valid)) ? LOGICAL_WARP_SIZE : - (warp_offset < static_cast(num_valid)) ? - num_valid - warp_offset : - 0; + num_valid - warp_offset; // Warp reduction in every warp T warp_aggregate = WarpReduce(temp_storage.warp_reduce[warp_id]).template Reduce<(FULL_TILE && EVEN_WARP_MULTIPLE)>(