diff --git a/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.cu b/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.cu
index 9f439a553e6e..8a95865c9b05 100644
--- a/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.cu
+++ b/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.cu
@@ -110,8 +110,6 @@ std::vector<paddle::Tensor> unified_decoding_kernel(
     const std::string& hidden_act,
     const bool early_stopping,
     const int min_length,
-    cublasHandle_t cublas_handle_,
-    cublasLtHandle_t cublaslt_handle_,
     cudaStream_t stream,
     const int tensor_para_size = 1,
     const int layer_para_size = 1,
@@ -147,8 +145,9 @@ std::vector<paddle::Tensor> unified_decoding_kernel(
   typedef typename traits_::data_t data_t_;
 
   DecodingInitParam<DataType_> decoding_params;
-  decoding_params.cublas_handle = cublas_handle_;
-  decoding_params.cublaslt_handle = cublaslt_handle_;
+  decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+  decoding_params.cublaslt_handle =
+      CublasHandle::GetInstance()->cublaslt_handle_;
 
   decoding_params.output_ids = output_ids.mutable_data<int>(input_ids.place());
   decoding_params.parent_ids = parent_ids.mutable_data<int>(input_ids.place());
@@ -228,8 +227,8 @@ std::vector<paddle::Tensor> unified_decoding_kernel(
                               i
                           : i;
     params[layer_idx].stream = stream;
-    params[layer_idx].cublas_handle = cublas_handle_;
-    params[layer_idx].cublaslt_handle = cublaslt_handle_;
+    params[layer_idx].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+    params[layer_idx].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;
 
     if (decoding_strategy == "beam_search" ||
         decoding_strategy == "beam_search_v2" ||
@@ -545,11 +544,8 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
     const int layer_para_size = 1,
     const int layer_para_batch_size = 1) {
   auto stream = input_ids.stream();
-  cublasHandle_t cublas_handle_;
-  cublasCreate(&cublas_handle_);
-  cublasSetStream(cublas_handle_, stream);
-  cublasLtHandle_t cublaslt_handle_;
-  cublasLtCreate(&cublaslt_handle_);
+
+  cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);
 
   std::vector<paddle::Tensor> ret;
 
@@ -618,8 +614,6 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
         hidden_act,
         early_stopping,
         min_length,
-        cublas_handle_,
-        cublaslt_handle_,
         stream,
         tensor_para_size,
         layer_para_size,
@@ -690,8 +684,6 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
         hidden_act,
         early_stopping,
         min_length,
-        cublas_handle_,
-        cublaslt_handle_,
         stream,
         tensor_para_size,
         layer_para_size,
@@ -706,7 +698,5 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
     }
   }
 
-  cublasDestroy(cublas_handle_);
-  cublasLtDestroy(cublaslt_handle_);
   return ret;
 }
diff --git a/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.h b/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.h
index 6eed1645dd51..2173563118d4 100644
--- a/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.h
+++ b/paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.h
@@ -20,6 +20,7 @@ limitations under the License. */
 // #include "fastertransformer/decoding_sampling.h"
 // #include "fastertransformer/open_decoder.h"
 // #include "fastertransformer/utils/common.h"
+#include "cublas_handle.h"
 
 #ifdef PADDLE_ON_INFERENCE
 #include "paddle/include/experimental/ext_all.h"
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu
index 2173af73cbaf..bfd680a6b24f 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu
@@ -40,7 +40,7 @@ void ker_curand_setupLauncher(curandState_t* state,
                               cudaStream_t stream) {
   dim3 block(256);
   dim3 grid((int)(ceil(args.batch_size_ * 1.0 / 256)));
-  int seed = args.seed_ != -1 ? args.seed_ : clock();
+  int seed = args.seed_ != -1 ? args.seed_ : clock() % INT_MAX;
 
   ker_curand_setup<<<grid, block, 0, stream>>>(state, args.batch_size_, seed);
 }