Add FasterCodeGen/FasterGPTJ #3017

Merged · 28 commits · Sep 2, 2022
Changes from 1 commit

129 changes: 129 additions & 0 deletions examples/code_generation/codegen/codegen_server.py
@@ -0,0 +1,129 @@
import random
import string
import time
import uvicorn
import paddle
from paddlenlp.utils.log import logger
from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
from sse_starlette.sse import EventSourceResponse
from fastapi import FastAPI, Response, status
from pydantic import BaseModel


class DefaultConfig:
    model_name = "Salesforce/codegen-2B-mono"
    device = "gpu"
    temperature = 0.5
    top_k = 10
    top_p = 1.0
    repetition_penalty = 1.0
    min_length = 0
    max_length = 16
    decode_strategy = "sampling"
    use_fp16_decoding = True
    decoding_lib = "/home/gongenlei/faster_gptj/paddlenlp/ops/build/lib/libdecoding_op"
    use_faster = True
    default_dtype = "float16" if use_fp16_decoding else "float32"
    load_state_as_np = True


class Input(BaseModel):
    prompt: str
    stream: bool = False


class Output(BaseModel):
    id: str
    model: str = "codegen"
    object: str = "text_completion"
    created: int = int(time.time())
    choices: list = None
    usage = {
        "completion_tokens": None,
        "prompt_tokens": None,
        "total_tokens": None,
    }


generate_config = DefaultConfig()
paddle.set_device(generate_config.device)
paddle.set_default_dtype(generate_config.default_dtype)

tokenizer = CodeGenTokenizer.from_pretrained(generate_config.model_name)
model = CodeGenForCausalLM.from_pretrained(
    generate_config.model_name,
    load_state_as_np=generate_config.load_state_as_np)

app = FastAPI()


def random_completion_id():
    return 'cmpl-' + ''.join(
        random.choice(string.ascii_letters + string.digits) for _ in range(29))


@app.post("/v1/engines/codegen/completions", status_code=200)
async def gen(item: Input):
    item = item.dict()
    logger.info(f"Request: {item}")
    temperature = item.get("temperature", generate_config.temperature)
    top_k = item.get("top_k", generate_config.top_k)
    if temperature == 0.0:
        temperature = 1.0
        top_k = 1
    repetition_penalty = item.get('frequency_penalty',
                                  generate_config.repetition_penalty)

    start_time = time.time()
    logger.info("Start generating code")
    tokenized = tokenizer([item['prompt']], return_tensors='pd')
    output, _ = model.generate(
        tokenized["input_ids"],
        max_length=16,
        min_length=generate_config.min_length,
        decode_strategy=generate_config.decode_strategy,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        use_faster=generate_config.use_faster,
        use_fp16_decoding=generate_config.use_fp16_decoding,
        decoding_lib=generate_config.decoding_lib,
    )
    logger.info("Finish generating code")
    end_time = time.time()
    logger.info(f"Time cost: {end_time - start_time}")
    output = tokenizer.decode(output[0],
                              skip_special_tokens=True,
                              spaces_between_special_tokens=False)
    logger.info(f"Generated code: {output}")
    output_json = Output(
        id=random_completion_id(),
        choices=[{
            "text": output,
            "index": 0,
            "finish_reason": "stop",
            "logprobs": None,
        }],
        usage={
            "completion_tokens": None,
            "prompt_tokens": None,
            "total_tokens": None,
        },
    ).json()

    def stream_response(response):
        yield f"{response}\n\n"
        yield "data: [DONE]\n\n"

    if item.get("stream", False):
        return EventSourceResponse(stream_response(output_json))
    else:
        return Response(
            status_code=status.HTTP_200_OK,
            content=output_json,
            media_type="application/json",
        )


if __name__ == "__main__":
    uvicorn.run("codegen_server:app", host="0.0.0.0", port=8978)
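
The endpoint above can be smoke-tested with a small client. The following is a minimal sketch using only the Python standard library; the host, port, and route come from codegen_server.py, the "choices" field follows the Output model defined there, and the prompt string is just an illustrative placeholder.

import json
import urllib.request

# Default host/port from the uvicorn.run(...) call in codegen_server.py.
url = "http://localhost:8978/v1/engines/codegen/completions"
payload = json.dumps({"prompt": "def hello_world():", "stream": False}).encode("utf-8")

request = urllib.request.Request(
    url, data=payload, headers={"Content-Type": "application/json"})
with urllib.request.urlopen(request) as resp:
    completion = json.loads(resp.read())

# "choices" mirrors the Output model in codegen_server.py.
print(completion["choices"][0]["text"])
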
5 changes: 5 additions & 0 deletions examples/code_generation/codegen/requirements.txt
@@ -0,0 +1,5 @@
fastapi==0.79.0
pydantic==1.9.1
python-dotenv==0.20.0
sse_starlette==0.10.3
uvicorn==0.17.6
29 changes: 27 additions & 2 deletions paddlenlp/ops/CMakeLists.txt
@@ -36,6 +36,7 @@ option(WITH_BART "Compile with BART"
option(WITH_MBART "Compile with MBART" ON)
option(WITH_PARALLEL "Compile with model parallel for GPT" OFF)
option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF)
option(WITH_GPTJ "Compile with GPTJ" ON)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
if(WITH_PARALLEL)
@@ -85,8 +86,12 @@ if(WITH_MBART)
list(APPEND decoding_op_files fusion_mbart_decoding_op.cc fusion_mbart_decoding_op.cu)
endif()

if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART)
message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON must be set to use FasterTransformer. ")
if(WITH_GPTJ)
list(APPEND decoding_op_files fusion_gptj_op.cc fusion_gptj_op.cu)
endif()

if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART AND NOT WITH_GPTJ)
message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON or/and -DWITH_GPTJ=ON must be set to use FasterTransformer. ")
endif()

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
@@ -321,6 +326,21 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/opt.h opt_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/opt.h opt_h_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/gptj.h gptj_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/gptj.h gptj_h_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h masked_multihead_attention_utils_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention_utils.h masked_multihead_attention_utils_h_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h masked_multihead_attention_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention.h masked_multihead_attention_h_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu attention_kernels_cu_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cu attention_kernels_cu_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_dst)

# Encoder patches.
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)
@@ -360,6 +380,11 @@ set(FT_PATCH_COMMAND
&& cp ${fastertransformer_cmakelists_src} ${fastertransformer_cmakelists_dst}
&& cp ${gpt_h_src} ${gpt_h_dst}
&& cp ${opt_h_src} ${opt_h_dst}
&& cp ${gptj_h_src} ${gptj_h_dst}
&& cp ${masked_multihead_attention_h_src} ${masked_multihead_attention_h_dst}
&& cat ${masked_multihead_attention_utils_h_src} >> ${masked_multihead_attention_utils_h_dst}
&& cat ${attention_kernels_cu_src} >> ${attention_kernels_cu_dst}
&& cat ${attention_kernels_cuh_src} >> ${attention_kernels_cuh_dst}
&& cat blank_lines ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
&& cat blank_lines ${lightseq_kernels_cu_src} >> ${topk_kernels_dst}
&& cat blank_lines ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst}
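
With the new WITH_GPTJ option enabled at configure time (for example, passing -DWITH_GPTJ=ON, following the existing option pattern), the build produces the decoding library that the decoding_lib setting in codegen_server.py points at. Below is a minimal sketch of using that library outside the server; the library path is a placeholder, and the exact file name under build/lib may differ by platform.

import paddle
from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer

# Placeholder path: point this at the library produced by the build above
# (codegen_server.py uses <repo>/paddlenlp/ops/build/lib/libdecoding_op).
DECODING_LIB = "/path/to/paddlenlp/ops/build/lib/libdecoding_op.so"

paddle.set_device("gpu")
paddle.set_default_dtype("float16")  # matches use_fp16_decoding=True in the server config

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-2B-mono")

inputs = tokenizer(["def fib(n):"], return_tensors="pd")
output_ids, _ = model.generate(
    inputs["input_ids"],
    max_length=16,
    decode_strategy="sampling",
    top_k=10,
    use_faster=True,            # same flags codegen_server.py passes to generate()
    use_fp16_decoding=True,
    decoding_lib=DECODING_LIB,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
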
2 changes: 1 addition & 1 deletion paddlenlp/ops/ext_utils.py
@@ -247,7 +247,7 @@ def _write_setup_file(name, file_path, build_dir, **kwargs):
        f.write(content)


@file_lock(os.path.join(PPNLP_HOME, "load_ext.lock"))
# @file_lock(os.path.join(PPNLP_HOME, "load_ext.lock"))
def load(name, build_dir=None, force=False, verbose=False, **kwargs):
    # TODO(guosheng): Need better way to resolve unsupported such as CPU. Currently,
    # raise NotImplementedError and skip `_jit_compile`. Otherwise, `_jit_compile`
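
The load() helper whose lock decorator is commented out here is the JIT-compilation entry point for these custom ops. A rough sketch of calling it directly follows; the extension name "FasterTransformer" is an assumption (it is not shown in this diff), so adjust it to whatever name the faster generation path actually uses.

from paddlenlp.ops.ext_utils import load

# Assumed extension name; not confirmed by this diff. Presumably JIT-compiles
# the custom decoding ops when no prebuilt decoding_lib is supplied.
load("FasterTransformer", verbose=True)
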
181 changes: 181 additions & 0 deletions paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc
@@ -0,0 +1,181 @@
#include <string>
#include <vector>

#include "fusion_gptj_op.h"
#include "pd_traits.h"


std::vector<paddle::Tensor> GPTJForward(
    const paddle::Tensor& input,
    const paddle::Tensor& attn_mask,
    const paddle::Tensor& start_length,
    const paddle::Tensor& word_embedding,
    const std::vector<paddle::Tensor>& self_ln_weight,
    const std::vector<paddle::Tensor>& self_ln_bias,
    const std::vector<paddle::Tensor>& self_q_weight,
    const std::vector<paddle::Tensor>& self_out_weight,
    const std::vector<paddle::Tensor>& ffn_inter_weight,
    const std::vector<paddle::Tensor>& ffn_inter_bias,
    const std::vector<paddle::Tensor>& ffn_out_weight,
    const std::vector<paddle::Tensor>& ffn_out_bias,
    const paddle::Tensor& decoder_ln_weight,
    const paddle::Tensor& decoder_ln_bias,
    const paddle::Tensor& emb_weight,
    const paddle::Tensor& emb_bias,
    const int& topk,
    const float& topp,
    const int& max_len,
    const int& n_head,
    const int& size_per_head,
    const int& num_layer,
    const int& bos_id,
    const int& eos_id,
    const float& temperature,
    const int& rotary_embedding_dim,
    const float& repetition_penalty,
    const int& min_length,
    const bool& use_fp16 = false,
    const int& tensor_para_size = 1,
    const int& layer_para_size = 1,
    const int& layer_para_batch_size = 1) {
  int batch_size = input.shape()[0];
  int start_len = input.shape()[1];
  int total_len = max_len + start_len;
  std::vector<int64_t> output_dims({total_len, batch_size});
  auto output_ids = paddle::Tensor(input.place(), output_dims);

  if (word_embedding.place() == paddle::PlaceType::kGPU) {
    return GPTJCUDAForward(input,
                           attn_mask,
                           start_length,
                           word_embedding,
                           self_ln_weight,
                           self_ln_bias,
                           self_q_weight,
                           self_out_weight,
                           ffn_inter_weight,
                           ffn_inter_bias,
                           ffn_out_weight,
                           ffn_out_bias,
                           decoder_ln_weight,
                           decoder_ln_bias,
                           emb_weight,
                           emb_bias,
                           output_ids,
                           topk,
                           topp,
                           total_len,
                           n_head,
                           size_per_head,
                           num_layer,
                           bos_id,
                           eos_id,
                           temperature,
                           rotary_embedding_dim,
                           repetition_penalty,
                           min_length,
                           use_fp16,
                           tensor_para_size,
                           layer_para_size,
                           layer_para_batch_size);
  } else {
    PD_THROW("Not implemented place. Only GPU is supported. ");
  }
}

std::vector<std::vector<int64_t>> GPTJInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& attn_mask_shape,
    const std::vector<int64_t>& start_length,
    const std::vector<int64_t>& word_embedding_shape,
    const std::vector<std::vector<int64_t>>& self_ln_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_ln_bias_shapes,
    const std::vector<std::vector<int64_t>>& self_q_weight_shapes,
    const std::vector<std::vector<int64_t>>& self_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_inter_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_inter_bias_shapes,
    const std::vector<std::vector<int64_t>>& ffn_out_weight_shapes,
    const std::vector<std::vector<int64_t>>& ffn_out_bias_shapes,
    const std::vector<int64_t>& decoder_ln_weight_shape,
    const std::vector<int64_t>& decoder_ln_bias_shape,
    const std::vector<int64_t>& emb_weight_shape,
    const std::vector<int64_t>& emb_bias_shape,
    const int& topk,
    const float& topp,
    const int& max_len,
    const int& n_head,
    const int& size_per_head,
    const int& num_layer,
    const int& bos_id,
    const int& eos_id,
    const float& temperature,
    const int& rotary_embedding_dim,
    const float& repetition_penalty,
    const int& min_length,
    const bool& use_fp16 = false,
    const int& tensor_para_size = 1,
    const int& layer_para_size = 1,
    const int& layer_para_batch_size = 1) {
  int64_t batch_size = input_shape[0];
  int64_t start_len = input_shape[1];
  std::vector<int64_t> output_dims({max_len + start_len, batch_size});
  return {output_dims};
}

std::vector<paddle::DataType> GPTJInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& attn_mask_dtype,
    const paddle::DataType& start_length_dtype,
    const paddle::DataType& word_embedding_dtype,
    const std::vector<paddle::DataType>& self_ln_weight_dtype,
    const std::vector<paddle::DataType>& self_ln_bias_dtype,
    const std::vector<paddle::DataType>& self_q_weight_dtype,
    const std::vector<paddle::DataType>& self_out_weight_dtype,
    const std::vector<paddle::DataType>& ffn_inter_weight_dtype,
    const std::vector<paddle::DataType>& ffn_inter_bias_dtype,
    const std::vector<paddle::DataType>& ffn_out_weight_dtype,
    const std::vector<paddle::DataType>& ffn_out_bias_dtype,
    const paddle::DataType& decoder_ln_weight_dtype,
    const paddle::DataType& decoder_ln_bias_dtype,
    const paddle::DataType& emb_weight_dtype,
    const paddle::DataType& emb_bias_dtype) {
  return {paddle::DataType::INT32};
}

PD_BUILD_OP(fusion_gptj)
    .Inputs({"Input",
             "AttentionMask",
             "StartLength",
             "WordEmbedding",
             paddle::Vec("SelfLayernormWeight"),
             paddle::Vec("SelfLayernormBias"),
             paddle::Vec("SelfQueryWeight"),
             paddle::Vec("SelfOutWeight"),
             paddle::Vec("FFNInterWeight"),
             paddle::Vec("FFNInterBias"),
             paddle::Vec("FFNOutWeight"),
             paddle::Vec("FFNOutBias"),
             "DecoderLayernormWeight",
             "DecoderLayernormBias",
             "EmbWeight",
             "EmbBias"})
    .Outputs({"OutputIds"})
    .Attrs({"topk: int",
            "topp: float",
            "max_len: int",
            "n_head: int",
            "size_per_head: int",
            "num_layer: int",
            "bos_id: int",
            "eos_id: int",
            "temperature: float",
            "rotary_embedding_dim: int",
            "repetition_penalty: float",
            "min_length: int",
            "use_fp16: bool",
            "tensor_para_size: int",
            "layer_para_size: int",
            "layer_para_batch_size: int"})
    .SetKernelFn(PD_KERNEL(GPTJForward))
    .SetInferShapeFn(PD_INFER_SHAPE(GPTJInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GPTJInferDtype));
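
For orientation, the op's Attrs correspond to the generation arguments used on the Python side, and GPTJInferShape/GPTJInferDtype fix the output layout. The sketch below summarizes that reading; the mapping of generate() keyword names to op attribute names is an inference from codegen_server.py and the registration above, not something this diff spells out.

# Presumed correspondence between generate() kwargs (codegen_server.py) and
# fusion_gptj attributes registered above.
GENERATE_KWARG_TO_OP_ATTR = {
    "top_k": "topk",
    "top_p": "topp",
    "max_length": "max_len",
    "min_length": "min_length",
    "temperature": "temperature",
    "repetition_penalty": "repetition_penalty",
    "use_fp16_decoding": "use_fp16",
}

# Output layout per GPTJInferShape/GPTJInferDtype: int32 ids of shape
# [max_len + start_len, batch_size], where start_len is the prompt length
# and max_len the generation budget.
batch_size, start_len, max_len = 1, 8, 16
output_ids_shape = [max_len + start_len, batch_size]  # -> [24, 1]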