Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FT sampling performance issue #2910

Merged
merged 6 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions paddlenlp/ops/faster_transformer/src/fusion_unified_decoding_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ std::vector<paddle::Tensor> unified_decoding_kernel(
const std::string& hidden_act,
const bool early_stopping,
const int min_length,
cublasHandle_t cublas_handle_,
cublasLtHandle_t cublaslt_handle_,
cudaStream_t stream,
const int tensor_para_size = 1,
const int layer_para_size = 1,
Expand Down Expand Up @@ -147,8 +145,9 @@ std::vector<paddle::Tensor> unified_decoding_kernel(
typedef typename traits_::data_t data_t_;

DecodingInitParam<DataType_> decoding_params;
decoding_params.cublas_handle = cublas_handle_;
decoding_params.cublaslt_handle = cublaslt_handle_;
decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
decoding_params.cublaslt_handle =
CublasHandle::GetInstance()->cublaslt_handle_;

decoding_params.output_ids = output_ids.mutable_data<int>(input_ids.place());
decoding_params.parent_ids = parent_ids.mutable_data<int>(input_ids.place());
Expand Down Expand Up @@ -228,8 +227,8 @@ std::vector<paddle::Tensor> unified_decoding_kernel(
i
: i;
params[layer_idx].stream = stream;
params[layer_idx].cublas_handle = cublas_handle_;
params[layer_idx].cublaslt_handle = cublaslt_handle_;
params[layer_idx].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
params[layer_idx].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;

if (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2" ||
Expand Down Expand Up @@ -545,11 +544,8 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
const int layer_para_size = 1,
const int layer_para_batch_size = 1) {
auto stream = input_ids.stream();
cublasHandle_t cublas_handle_;
cublasCreate(&cublas_handle_);
cublasSetStream(cublas_handle_, stream);
cublasLtHandle_t cublaslt_handle_;
cublasLtCreate(&cublaslt_handle_);

cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);

std::vector<paddle::Tensor> ret;

Expand Down Expand Up @@ -618,8 +614,6 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
hidden_act,
early_stopping,
min_length,
cublas_handle_,
cublaslt_handle_,
stream,
tensor_para_size,
layer_para_size,
Expand Down Expand Up @@ -690,8 +684,6 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
hidden_act,
early_stopping,
min_length,
cublas_handle_,
cublaslt_handle_,
stream,
tensor_para_size,
layer_para_size,
Expand All @@ -706,7 +698,5 @@ std::vector<paddle::Tensor> UnifiedDecodingCUDAForward(
}
}

cublasDestroy(cublas_handle_);
cublasLtDestroy(cublaslt_handle_);
return ret;
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */
// #include "fastertransformer/decoding_sampling.h"
// #include "fastertransformer/open_decoder.h"
// #include "fastertransformer/utils/common.h"
#include "cublas_handle.h"

#ifdef PADDLE_ON_INFERENCE
#include "paddle/include/experimental/ext_all.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void ker_curand_setupLauncher(curandState_t* state,
cudaStream_t stream) {
dim3 block(256);
dim3 grid((int)(ceil(args.batch_size_ * 1.0 / 256)));
int seed = args.seed_ != -1 ? args.seed_ : clock();
int seed = args.seed_ != -1 ? args.seed_ : clock() % INT_MAX;
ker_curand_setup<<<grid, block, 0, stream>>>(state, args.batch_size_, seed);
}

Expand Down