fix_matmul_op_int8_plugin
Wangzheee committed Nov 24, 2021
1 parent d5c51e6 commit c3d454f
Showing 1 changed file with 4 additions and 4 deletions.
@@ -299,13 +299,13 @@ void MatmulPlugin::configurePlugin(const nvinfer1::PluginTensorDesc* inputs,
       matmulDesc_, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model,
       sizeof(matmul_model)));

-  float alpha_tem[n_];
+  std::vector<float> alpha_tem(n_, 0);
   for (int i = 0; i < n_; i++) {
     alpha_tem[i] = alpha_ * inscale_0 * inscale_1 / outscale;
   }
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaMalloc((void**)&alpha_scale_, n_ * sizeof(float)));
-  cudaMemcpyAsync(alpha_scale_, alpha_tem, n_ * sizeof(float),
+  cudaMemcpyAsync(alpha_scale_, &alpha_tem[0], n_ * sizeof(float),
                   cudaMemcpyHostToDevice);
   float zero_tem = zero;
   PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -624,13 +624,13 @@ void MatmulPluginDynamic::configurePlugin(
           sizeof(int8_t) * ((m_max + 32 - 1) / 32 * 32) / 32 * ldctransform));

   if (type_ == nvinfer1::DataType::kINT8) {
-    float alpha_tem[n_max];
+    std::vector<float> alpha_tem(n_max, 0);
     for (int i = 0; i < n_max; i++) {
       alpha_tem[i] = alpha_ * inscale_0 * inscale_1 / outscale;
     }
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaMalloc((void**)&alpha_scale_, n_max * sizeof(float)));
-    cudaMemcpyAsync(alpha_scale_, alpha_tem, n_max * sizeof(float),
+    cudaMemcpyAsync(alpha_scale_, &alpha_tem[0], n_max * sizeof(float),
                     cudaMemcpyHostToDevice);
     float zero_tem = zero;
     PADDLE_ENFORCE_CUDA_SUCCESS(
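Both hunks make the same change: the per-channel alpha buffer was declared as a C-style variable-length array (float alpha_tem[n_]), which is a compiler extension rather than standard C++ because n_ and n_max are only known at runtime; it is now a std::vector<float>, and the host-to-device copy takes the address of the vector's first element. Below is a minimal standalone sketch of the same pattern, not the plugin code itself: the buffer size and scale values are made up for illustration, the variable names (alpha_scale, n) are placeholders, and a plain cudaError_t check stands in for PADDLE_ENFORCE_CUDA_SUCCESS.

// Sketch: build per-channel alpha scales in a std::vector (no VLA extension
// needed) and copy them to device memory with cudaMemcpyAsync.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

int main() {
  const int n = 128;  // runtime-sized, playing the role of n_ / n_max
  const float alpha = 1.0f, inscale_0 = 0.5f, inscale_1 = 0.25f, outscale = 2.0f;

  // Host-side buffer: std::vector handles a runtime size portably.
  std::vector<float> alpha_tem(n, 0.0f);
  for (int i = 0; i < n; ++i) {
    alpha_tem[i] = alpha * inscale_0 * inscale_1 / outscale;
  }

  // Device-side buffer that would later be handed to cuBLASLt.
  float* alpha_scale = nullptr;
  if (cudaMalloc(reinterpret_cast<void**>(&alpha_scale), n * sizeof(float)) != cudaSuccess) {
    std::fprintf(stderr, "cudaMalloc failed\n");
    return 1;
  }

  // alpha_tem.data() (equivalently &alpha_tem[0]) is the raw host pointer the copy needs.
  cudaMemcpyAsync(alpha_scale, alpha_tem.data(), n * sizeof(float),
                  cudaMemcpyHostToDevice);
  cudaDeviceSynchronize();  // ensure the copy has finished before the local vector is destroyed

  cudaFree(alpha_scale);
  return 0;
}

&alpha_tem[0] and alpha_tem.data() are interchangeable here; the explicit cudaDeviceSynchronize() in the sketch only guards against the host buffer going out of scope before the asynchronous copy completes.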

1 comment on commit c3d454f

@paddle-bot-old
Congratulations! Your pull request passed all required CI checks. You can now ask reviewer(s) to approve and merge. 🎉
