[Paddle-TRT] Support PromptTuning in the transformer model for VarSeqlen (#57034)

Support PromptTuning in the transformer model for VarSeqlen
Wangzheee authored Sep 11, 2023
1 parent dbfa292 commit 6f26480
Showing 13 changed files with 1,882 additions and 25 deletions.
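For readers who want to exercise the new path: the pass and converter below only fire on the TensorRT varlen branch, which requires fp16. A minimal C++ sketch of the inference-side setup, assuming a typical ERNIE varlen deployment (model paths, shapes, and input names are illustrative, not from this commit):

#include "paddle/include/paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("ernie.pdmodel", "ernie.pdiparams");  // illustrative paths
  config.EnableUseGpu(100 /*MB*/, 0 /*device_id*/);
  // Varlen + fp16 are mandatory; the converter below rejects kFloat32.
  config.EnableTensorRtEngine(1 << 30 /*workspace*/, 1 /*max_batch*/,
                              5 /*min_subgraph_size*/,
                              paddle_infer::PrecisionType::kHalf,
                              false /*use_static*/, false /*use_calib*/);
  config.EnableVarseqlen();
  // In practice you also set TRT dynamic-shape ranges for the id inputs
  // (names depend on the model) before creating the predictor.
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... feed the Ids / DenseVector inputs and run ...
  return 0;
}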
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -156,6 +156,8 @@ if(WITH_TENSORRT)
   pass_library(preln_elementwise_groupnorm_act_pass inference)
   pass_library(groupnorm_act_pass inference)
   pass_library(trans_layernorm_fuse_pass inference)
+  pass_library(trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass
+               inference)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(split_layernorm_to_math_ops_pass inference)
15 changes: 9 additions & 6 deletions paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
@@ -23,15 +23,18 @@ namespace framework {
 namespace ir {
 namespace patterns {
 void EmbEltwiseLayernorm::operator()() {
-  // Create nodes for fused_embedding_eltwise_layernorm.
-  auto* emb_elt_layernorm_op =
-      pattern->NewNode(emb_elt_layernorm_op_repr())
-          ->assert_is_op("fused_embedding_eltwise_layernorm");
+  // Create nodes for fused_embedding_eltwise_layernorm or
+  // prompt_tuning_emb_eltwise_layernorm.
+  std::unordered_set<std::string> embedding_ops{
+      "fused_embedding_eltwise_layernorm",
+      "prompt_tuning_emb_eltwise_layernorm"};
+  auto* emb_elt_layernorm_op = pattern->NewNode(emb_elt_layernorm_op_repr())
+                                   ->assert_is_ops(embedding_ops);
   auto* emb_elt_layernorm_out =
       pattern->NewNode(emb_elt_layernorm_out_repr())
-          ->assert_is_op_output("fused_embedding_eltwise_layernorm", "Out");
+          ->assert_is_ops_output(embedding_ops, "Out");
 
-  // Add links for fused_embedding_eltwise_layernorm op.
+  // Add links for embedding_ops.
   emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
 }
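For context, patterns like EmbEltwiseLayernorm are consumed by GraphPatternDetector inside a pass's ApplyImpl. A minimal sketch of the usual driver code, assuming the standard PatternBase constructor (handler body elided, names illustrative):

GraphPatternDetector gpd;
patterns::EmbEltwiseLayernorm emb_pattern(gpd.mutable_pattern(),
                                          "remove_padding_recover_padding");
emb_pattern();  // builds the nodes and links declared above
gpd(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
  // GET_IR_NODE_FROM_SUBGRAPH(...) to fetch matched nodes, then rewrite.
});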

paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.cc (large diff not rendered by default)

118 changes: 118 additions & 0 deletions paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.h
@@ -0,0 +1,118 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <utility>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct TrtPromptTuningEmbedding2Eltwise1Pattern : public PatternBase {
TrtPromptTuningEmbedding2Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding2_eltwise1") {}

void operator()();
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(feed2);
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};

struct TrtPromptTuningEmbedding1Eltwise1Pattern : public PatternBase {
TrtPromptTuningEmbedding1Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding1_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(eltwise_add_in);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};

struct TrtPromptTuningSkipLayerNorm : public PatternBase {
TrtPromptTuningSkipLayerNorm(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "skip_layernorm") {}
void operator()();

PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
PATTERN_DECL_NODE(mul0_x);
PATTERN_DECL_NODE(mul0_y);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(eltadd0_b);
PATTERN_DECL_NODE(eltadd0);
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(relu);
PATTERN_DECL_NODE(relu_out);
PATTERN_DECL_NODE(mul1_y);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(eltadd1_b);
PATTERN_DECL_NODE(eltadd1);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(concat);
PATTERN_DECL_NODE(concat_out);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
};
} // namespace patterns

class TrtPromptTuningEmbeddingEltwiseLayerNormFusePass : public FusePassBase {
public:
TrtPromptTuningEmbeddingEltwiseLayerNormFusePass();
virtual ~TrtPromptTuningEmbeddingEltwiseLayerNormFusePass() {}

protected:
void ApplyImpl(Graph* graph) const;
int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const;
const std::string name_scope_{
"trt_prompt_tuning_embedding_eltwise_layernorm_fuse"};
};

} // namespace ir
} // namespace framework
} // namespace paddle
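Reading the TrtPromptTuningSkipLayerNorm node list, the pattern appears to match an eltwise_add whose output feeds a small two-layer MLP (mul0/eltadd0/relu, then mul1/eltadd1), with a concat merging the MLP output back in before layer_norm; that concat is presumably where the prompt-tuning dense vector enters the graph. The operator() bodies live in the large .cc diff above.

As an aside, PATTERN_DECL_NODE is the usual helper from graph_pattern_detector.h; roughly (a sketch from memory, not the verbatim macro), it expands to a name accessor used when registering and retrieving pattern nodes:

// Approximate expansion of PATTERN_DECL_NODE(name__):
#define PATTERN_DECL_NODE(name__)                        \
  std::string name__##_repr() const {                    \
    return PDNodeName(name_scope_, repr_, id_, #name__); \
  }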
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -2893,6 +2893,7 @@ USE_TRT_CONVERTER(sign);
 #endif
 USE_TRT_CONVERTER(rsqrt);
 USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
+USE_TRT_CONVERTER(prompt_tuning_emb_eltwise_layernorm);
 USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
 USE_TRT_CONVERTER(preln_skip_layernorm)
 USE_TRT_CONVERTER(fused_bias_dropout_residual_layer_norm)
35 changes: 18 additions & 17 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -89,23 +89,24 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
 
 const std::vector<std::string> kTRTSubgraphPasses({
       "trt_support_nhwc_pass",
-      "adaptive_pool2d_convert_global_pass",       //
-      "trt_map_ops_to_matrix_multiply_pass",       //
-      "shuffle_channel_detect_pass",               //
-      "quant_conv2d_dequant_fuse_pass",            //
-      "delete_quant_dequant_op_pass",              //
-      "delete_quant_dequant_filter_op_pass",       //
-      "trt_delete_weight_dequant_linear_op_pass",  //
-      "delete_quant_dequant_linear_op_pass",       //
-      "identity_op_clean_pass",                    //
-      "add_support_int8_pass",                     //
-      "simplify_with_basic_ops_pass",              //
-      "trt_embedding_eltwise_layernorm_fuse_pass",    //
-      "preln_embedding_eltwise_layernorm_fuse_pass",  //
-      "trt_multihead_matmul_fuse_pass_v2",            //
-      "trt_multihead_matmul_fuse_pass_v3",            //
-      "multihead_matmul_roformer_fuse_pass",          //
-      "constant_folding_pass",                        //
+      "adaptive_pool2d_convert_global_pass",                      //
+      "trt_map_ops_to_matrix_multiply_pass",                      //
+      "shuffle_channel_detect_pass",                              //
+      "quant_conv2d_dequant_fuse_pass",                           //
+      "delete_quant_dequant_op_pass",                             //
+      "delete_quant_dequant_filter_op_pass",                      //
+      "trt_delete_weight_dequant_linear_op_pass",                 //
+      "delete_quant_dequant_linear_op_pass",                      //
+      "identity_op_clean_pass",                                   //
+      "add_support_int8_pass",                                    //
+      "simplify_with_basic_ops_pass",                             //
+      "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass",  //
+      "trt_embedding_eltwise_layernorm_fuse_pass",                //
+      "preln_embedding_eltwise_layernorm_fuse_pass",              //
+      "trt_multihead_matmul_fuse_pass_v2",                        //
+      "trt_multihead_matmul_fuse_pass_v3",                        //
+      "multihead_matmul_roformer_fuse_pass",                      //
+      "constant_folding_pass",                                    //
 #ifdef PADDLE_WITH_TENSORRT
 #if !IS_TRT_VERSION_GE(8610)
       "trt_flash_multihead_matmul_fuse_pass",  //
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -115,7 +115,7 @@ list(
 
 if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
-       preln_emb_eltwise_layernorm.cc)
+       preln_emb_eltwise_layernorm.cc prompt_tuning_emb_eltwise_layernorm.cc)
 endif()
 
 if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
177 changes: 177 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/prompt_tuning_emb_eltwise_layernorm.cc
@@ -0,0 +1,177 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/utils.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class PromptTuningEmbEltwiseLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert fused_prompt_tuning_embedding_eltwise_layernorm op to "
"tensorrt layer";
// Get the persistable variable's data.
auto GetWeight = [&](const std::string& var_name,
framework::DDim* dim) -> TensorRTEngine::Weight {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
*dim = temp_tensor->dims();
auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
return weight;
};

framework::OpDesc op_desc(op, nullptr);
auto* dense_vector = engine_->GetITensor(op_desc.Input("DenseVector")[0]);

auto pos_id_name = engine_->tensorrt_transformer_posid();
auto mask_id_name = engine_->tensorrt_transformer_maskid();

// bool with_fp16 =
//     engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
// int hidden = 0;

// Declare inputs
std::vector<nvinfer1::ITensor*> input_ids;

// Declare inputs_weight
std::vector<nvinfer1::Weights> input_embs;
std::vector<int> emb_sizes;
TensorRTEngine::Weight weight;
framework::DDim emb_dims;
framework::DDim bias_dims, scale_dims;
TensorRTEngine::Weight bias_weight, scale_weight;

int64_t bias_size = phi::product(bias_dims);
int64_t scale_size = phi::product(scale_dims);
bool enable_int8 = op_desc.HasAttr("enable_int8");

std::vector<std::string> id_names = op_desc.Input("Ids");
std::vector<std::string> emb_names = op_desc.Input("Embs");
int input_num = id_names.size();

engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
for (int i = 0; i < input_num; i++) {
auto input_tensor = engine_->GetITensor(id_names[i]);
weight = GetWeight(emb_names[i], &emb_dims);
if (id_names[i] == pos_id_name) {
input_ids.insert(input_ids.begin(), input_tensor);
input_embs.insert(input_embs.begin(), weight.get());
emb_sizes.insert(emb_sizes.begin(), weight.get().count);
} else {
input_ids.push_back(input_tensor);
input_embs.push_back(weight.get());
emb_sizes.push_back(weight.get().count);
}
}
bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
bias_size = phi::product(bias_dims);
scale_size = phi::product(scale_dims);
// Register the first id other than pos_id as word_id.
engine_->SetITensor("word_id", input_ids[1]);

int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
if (enable_int8) {
output_fp16 = 1;
}
PADDLE_ENFORCE_EQ(
output_fp16,
1,
platform::errors::InvalidArgument(
"Only Precision::kHalf (fp16) is supported when inferring "
"an ernie (bert) model with config.EnableVarseqlen(), "
"but Precision::kFloat32 was set."));

std::vector<nvinfer1::PluginField> fields;
std::vector<std::string> temp_fields_keys;
fields.emplace_back("bert_embeddings_layernorm_beta",
bias_weight.get().values,
GetPluginFieldType(bias_weight.get().type),
static_cast<int32_t>(bias_size));
fields.emplace_back("bert_embeddings_layernorm_gamma",
scale_weight.get().values,
GetPluginFieldType(scale_weight.get().type),
static_cast<int32_t>(scale_size));
fields.emplace_back(
"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1);
for (int i = 0; i < input_num; ++i) {
temp_fields_keys.push_back("bert_embeddings_word_embeddings_" +
std::to_string(i));
fields.emplace_back(temp_fields_keys.rbegin()->c_str(),
input_embs[i].values,
GetPluginFieldType(input_embs[i].type),
static_cast<int32_t>(emb_sizes[i]));
}

nvinfer1::PluginFieldCollection* plugin_ptr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_ptr) +
fields.size() * sizeof(nvinfer1::PluginField)));
plugin_ptr->nbFields = static_cast<int>(fields.size());
plugin_ptr->fields = fields.data();
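// Note: plugin_ptr only borrows fields.data(); the plugin is expected to
// copy what it needs inside createPlugin(), so plugin_ptr is freed right
// after the plugin layer is created below.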

std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
plugin_inputs.emplace_back(
engine_->GetITensor("mask_id")); // input mask_id

plugin_inputs.emplace_back(dense_vector);  // prompt tuning's dense_vector

auto creator = GetPluginRegistry()->getPluginCreator(
"PromptTuningEmbLayerNormVarlenPluginDynamic", "1");
auto plugin_obj = creator->createPlugin(
"PromptTuningEmbLayerNormVarlenPluginDynamic", plugin_ptr);

auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);

plugin_layer->setName(
("PromptTuningEmbLayerNormVarlenPluginDynamicV1(Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
free(plugin_ptr);
if (enable_int8) {
float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(plugin_layer->getOutput(0),
out_scale); // output
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1),
out_scale); // mask
engine_->SetTensorDynamicRange(plugin_layer->getOutput(2),
out_scale); // max seqlen
}

engine_->DeleteITensor("mask_id", engine_->GetITensor("mask_id"));
engine_->DeleteITensor("pos_id", engine_->GetITensor("pos_id"));

auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(plugin_layer,
"PromptTuningEmbLayerNormVarlenPluginDynamicV1",
{output_name,
std::string("qkv_plugin_mask"),
std::string("max_seqlen_tensor"),
std::string("mask_id"),
std::string("pos_id")},
test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(prompt_tuning_emb_eltwise_layernorm,
PromptTuningEmbEltwiseLayerNormOpConverter);
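Note the pairing: REGISTER_TRT_OP_CONVERTER here registers the converter under the op type prompt_tuning_emb_eltwise_layernorm, while the matching USE_TRT_CONVERTER declaration added to analysis_predictor.cc forces the registration symbol to be linked into the predictor.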
(The remaining changed files in this commit are not rendered here.)
