diff --git a/examples/code_generation/codegen/README.md b/examples/code_generation/codegen/README.md
index 0d83dc8589bc..72430d775ed9 100644
--- a/examples/code_generation/codegen/README.md
+++ b/examples/code_generation/codegen/README.md
@@ -1,74 +1,219 @@
-# CodeGen: A Conversational Paradigm for Program Synthesis
+# Code Generation: An AI Assistant That Writes Code
-## 模型简介
+**Table of Contents**
+- [Code Generation](#code-generation-an-ai-assistant-that-writes-code)
+  - [Introduction](#introduction)
+    - [Highlights](#highlights)
+  - [Demo](#demo)
+    - [Out of the Box](#out-of-the-box)
+      - [Single and Batch Prediction](#single-and-batch-prediction)
+      - [Configurable Parameters](#configurable-parameters)
+  - [Custom Training](#custom-training)
+    - [Requirements](#requirements)
+    - [Code Structure](#code-structure)
+    - [Data Preparation](#data-preparation)
+      - [Creating a Dataset from Local Files](#creating-a-dataset-from-local-files)
+  - [GitHub Copilot Plugin Setup](#github-copilot-plugin-setup)
+    - [Plugin Requirements](#plugin-requirements)
+    - [Starting the Server](#starting-the-server)
+      - [Configuration Parameters](#configuration-parameters)
+    - [Testing the Server](#testing-the-server)
+    - [Configuring the Plugin](#configuring-the-plugin)
+    - [Notes](#notes)
+  - [Taskflow API](#taskflow-api)
+  - [Examples](#examples)
+  - [Model List](#model-list)
+  - [References](#references)
-[CodeGen](https://arxiv.org/pdf/2203.13474.pdf) (A Conversational Paradigm for Program Synthesis) proposes a conversational approach to program synthesis with large language models, turning the writing of a specification and a program into a multi-turn dialogue between the user and the system. It casts program synthesis as a sequence-prediction problem: the specification is expressed in natural language and the desired program is sampled conditionally. CodeGen (16B) already surpasses [OpenAI's Codex](https://arxiv.org/pdf/2107.03374.pdf) on the HumanEval benchmark.
-This project shows how to use CodeGen for code generation.
+## Introduction
+Code generation produces the code a programmer wants from the programmer's input. It can assist programmers, or even generate code on its own, improving programming efficiency.
-## Quick Start
+
+### Highlights
+
+This project builds code generation on top of the pretrained language model CodeGen and offers the following advantages:
+- **Strong results**. CodeGen (16B) already surpasses [OpenAI's Codex](https://arxiv.org/pdf/2107.03374.pdf) on the HumanEval benchmark.
+- **A free GitHub Copilot**. The model can be served behind the GitHub Copilot plugin, giving you a free AI coding assistant.
+- **High performance**. Inference is accelerated with [FasterGeneration](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/faster_generation) for millisecond-level responses. Detailed speedup numbers are available in [perf](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/faster_generation/README.md).
+- **Custom-dataset training**. You can finetune the model on your own code to make it smarter.
+- **Out of the box**. A Taskflow API is provided, so predictions take only a few lines of code and no training.
+
+
+## Demo
+
+## Custom Training
### Requirements
+- PaddleNLP >= 2.4.0
+- PaddlePaddle >= 2.3.1
+
+### Code Structure
+
+The main files in this project are organized as follows:
+
+```text
+codegen/
+├── requirements.txt   # dependencies
+├── codegen_server.py  # server startup script
+├── run_clm.py         # training and evaluation script
+├── run_clm.sh         # training launch script
+└── README.md          # this document
+```
+
+### Data Preparation
+
+#### Creating a Dataset from Local Files
+
+In many cases you will want to train the code generation model on your own data. This project supports training from local dataset files in a fixed format.
+
+The local dataset files look as follows:
+- train.json / test.json format:
+one JSON object per line (JSON Lines), each with a "code" field
+```text
+{
+ "code": "from paddlenlp.transformers import CodeGenForCausalLM\n\n\nmodel = CodeGenForCausalLM.from_pretrained('Salesforce/codegen-2B-mono')\n"
+}
+```
+
+For more ways to load datasets, see [Dataset Loading](https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_load.html#) and [Custom Datasets](https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_self_defined.html). A minimal script for converting local source files into this format is sketched below.
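+
+The following sketch (not part of this project) shows one way to collect local Python files into `train.json`/`test.json`; the source directory and the 9:1 split are illustrative:
+
+```python
+import glob
+import json
+import random
+
+# Gather local Python files and write them as JSON Lines with a "code" field,
+# the format expected by run_clm.py.
+files = glob.glob("my_python_project/**/*.py", recursive=True)
+random.shuffle(files)
+
+split = int(len(files) * 0.9)  # illustrative 9:1 train/test split
+for name, subset in [("train.json", files[:split]), ("test.json", files[split:])]:
+    with open(name, "w", encoding="utf-8") as f:
+        for path in subset:
+            with open(path, "r", encoding="utf-8") as src:
+                f.write(json.dumps({"code": src.read()}) + "\n")
+```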
+
+
+### Model Training
+Run the following command to finetune on the sample training set and evaluate on the sample validation set.
+
+```shell
+# Launch on GPU; the --gpus flag selects the GPU card id(s), one card or several.
+unset CUDA_VISIBLE_DEVICES
+
+python -m paddle.distributed.launch --gpus 0,1 run_clm.py \
+ --model_name_or_path Salesforce/codegen-350M-mono \
+ --block_size 1024 \
+ --output_dir output \
+ --train_file train.json \
+ --validation_file test.json \
+ --num_train_epochs 5 \
+ --logging_steps 1 \
+ --save_steps 10 \
+ --train_batch_size 2 \
+ --eval_batch_size 2 \
+ --learning_rate 1e-4 \
+ --warmup_proportion 0.1 \
+ --device gpu
+```
+For multi-GPU training, pass several card ids, e.g. --gpus "0,1".
+
+Key parameters:
+- `gpus`: the GPU card id(s) used for training.
+- `model_name_or_path`: the pretrained model to finetune. It can be one of the pretrained models provided by PaddleNLP (see the [Model List](#model-list)) or a local pretrained model. For a local model, set this to the model directory, e.g. ./checkpoints/model_xx/; the directory must contain the Paddle weights file model_state.pdparams. For a PaddleNLP model, pick one from the list below.
+- `block_size`: the length, in tokens, of the blocks the training data is split into.
+- `output_dir`: the directory where the model is saved.
+- `train_file`: path to the local training data; it must follow the format described in [Data Preparation](#data-preparation).
+- `validation_file`: path to the local validation data; it must follow the same format.
+- `num_train_epochs`: the number of training epochs.
+- `logging_steps`: the logging interval, in steps.
+- `save_steps`: the model-saving and evaluation interval, in steps.
+- `train_batch_size`: the number of samples **per card** during training.
+- `eval_batch_size`: the number of samples **per card** during evaluation.
+- `learning_rate`: the base learning rate; the current learning rate is this value multiplied by the factor produced by the learning rate scheduler.
+- `warmup_proportion`: the fraction of total steps used to warm the learning rate up to the base learning rate (the `learning_rate` above); see [this paper](https://arxiv.org/pdf/1706.02677.pdf) for an early use of warmup.
+- `device`: the device to use, either gpu or cpu.
- - python >= 3.6
- - paddlepaddle >= 2.3.0
- - paddlenlp >= 2.3.4
+Training can also be launched with `bash run_clm.sh`; see `run_clm.py` for the full list of parameters and their default values.
-### Using the API
+The program trains and evaluates automatically, and checkpoints are saved during training to the specified `output_dir`,
+for example:
+```text
+./output/
+│── model_config.json
+│── model_state.pdparams
+│── tokenizer_config.json
+│── special_tokens_map.json
+│── added_tokens.json
+│── vocab.json
+│── merges.txt
+└── ...
+```
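+
+After training, the saved checkpoint can be loaded like any other PaddleNLP model. Below is a small sketch, assuming the default `output` directory above:
+
+```python
+from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer
+
+# Load the finetuned checkpoint saved in ./output
+tokenizer = CodeGenTokenizer.from_pretrained("./output")
+model = CodeGenForCausalLM.from_pretrained("./output")
+model.eval()
+
+inputs = tokenizer(["def hello_world():"], return_tensors="pd")
+output_ids, _ = model.generate(inputs["input_ids"],
+                               max_length=64,
+                               decode_strategy="sampling",
+                               top_k=5)
+print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+```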
+
+**NOTE:** To resume training, simply set `model_name_or_path` to the directory of the saved local model.
+
+## GitHub Copilot Plugin Setup
+The following uses the VS Code plugin as an example.
+### Plugin Requirements
+- PaddleNLP >= 2.4.0
+- PaddlePaddle >= 2.3.1
+
+Other dependencies: `pip install -r requirements.txt`
+
+
+### Starting the Server
```python
-import re
-import paddle
-from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
+python codegen_server.py
+```
+
+#### Configuration Parameters
+The following parameters are configured in codegen_server.py; an example configuration is sketched after this list:
+- `model_name_or_path`: model name; defaults to "Salesforce/codegen-2B-mono"
+- `device`: device to run on; defaults to "gpu"
+- `temperature`: decoding temperature; defaults to 0.5
+- `top_k`: decoding top_k; defaults to 10
+- `top_p`: decoding top_p; defaults to 1.0
+- `repetition_penalty`: repetition penalty applied during decoding; defaults to 1.0
+- `min_length`: minimum generation length; defaults to 0
+- `max_length`: maximum generation length; defaults to 16
+- `decode_strategy`: decoding strategy; defaults to "sampling"
+- `load_state_as_np`: load the model weights as numpy arrays to save GPU memory; defaults to True
+- `use_faster`: whether to use FasterGeneration to speed up inference; defaults to True
+- `use_fp16_decoding`: whether to decode in FP16, saving memory and speeding up inference; defaults to True
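+
+For example, to serve the smaller 350M model with longer, more deterministic completions, the `DefaultConfig` class in `codegen_server.py` could be edited as follows (a sketch; the values are only illustrative):
+
+```python
+class DefaultConfig:
+    model_name_or_path = "Salesforce/codegen-350M-mono"  # smaller model, less GPU memory
+    device = "gpu"
+    temperature = 0.2        # lower temperature -> more deterministic completions
+    top_k = 10
+    top_p = 1.0
+    repetition_penalty = 1.0
+    min_length = 0
+    max_length = 32          # generate more tokens than the default 16
+    decode_strategy = "sampling"
+    load_state_as_np = True
+    use_faster = True
+    use_fp16_decoding = True
+    default_dtype = "float16" if use_faster and use_fp16_decoding else "float32"
+```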
+
+### Testing the Server
+`pip install --upgrade openai`
+
+```python
+import openai
+openai.api_key = 'dummy'
+openai.api_base = 'http://127.0.0.1:8978/v1'
+result = openai.Completion.create(
+ engine='codegen', prompt='def hello', max_tokens=16, temperature=0.1)
+print(result)
+'''
+ JSON: {
+ "id": "cmpl-dmhoeHmcw9DJ4NeqOJDQVKv3iivJ0",
+ "choices": [
+ {
+ "text": "_world():\n print(\"Hello World!\")\n\n\n#",
+ "index": 0,
+ "finish_reason": "stop",
+ "logprobs": null,
+ }
+ ],
+ "usage": {
+ "completion_tokens": null,
+ "prompt_tokens": null,
+ "total_tokens": null
+ }
+}
+'''
-# The supported models are shown in the following table
-model_name = 'Salesforce/codegen-350M-mono'
-# Init tokenizer
-tokenizer = CodeGenTokenizer.from_pretrained(model_name)
-# Init model
-model = CodeGenForCausalLM.from_pretrained(model_name)
-inputs = tokenizer(["def hello_world():"])
-inputs = {k: paddle.to_tensor(v) for (k, v) in inputs.items()}
-# Generate
-output, score = model.generate(inputs['input_ids'],
- max_length=128,
- decode_strategy='sampling',
- top_k=5,
- repetition_penalty=1.1,
- temperature=0.6)
-# Decode the result
-print(
- re.split(
- "\nclass|\ndef|\n#|\n@|\nprint|\nif",
- tokenizer.decode(output[0],
- skip_special_tokens=True,
- spaces_between_special_tokens=False))[0].rstrip())
```
-Parameter descriptions:
-- `max_length`: maximum decoding length; defaults to 128.
-- `decode_strategy`: decoding strategy; defaults to sampling.
-- `top_k`: decoding top_k; defaults to 5.
-- `repetition_penalty`: repetition penalty applied during decoding; defaults to 1.1.
-- `temperature`: decoding temperature; defaults to 0.6.
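+
+The service can also be tested without the `openai` package. The sketch below calls the completion endpoint exposed by `codegen_server.py` directly with `requests` (the default port 8978 is assumed):
+
+```python
+import requests
+
+# codegen_server.py exposes an OpenAI-style completion endpoint.
+resp = requests.post(
+    "http://127.0.0.1:8978/v1/engines/codegen/completions",
+    json={"prompt": "def hello", "stream": False},
+)
+resp.raise_for_status()
+print(resp.json()["choices"][0]["text"])
+```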
+### Configuring the Plugin
+Open your user settings ([settings.json](https://code.visualstudio.com/docs/getstarted/settings#_settings-file-locations)) and add the following configuration:
+```json
+ "github.copilot.advanced": {
+ "debug.overrideEngine": "codegen",
+ "debug.testOverrideProxyUrl": "http://127.0.0.1:8978",
+ "debug.overrideProxyUrl": "http://127.0.0.1:8978"
+ },
+```
-Model list
-| Model Name                          | Description                      |
-| :--------------------------------- | -------------------------------- |
-| Salesforce/codegen-350M-mono        | Trained on the Python dataset BIGPYTHON |
-| Salesforce/codegen-2B-mono          | Trained on the Python dataset BIGPYTHON |
-| Salesforce/codegen-6B-mono          | Trained on the Python dataset BIGPYTHON |
-| Salesforce/codegen-16B-mono         | Trained on the Python dataset BIGPYTHON |
-| Salesforce/codegen-350M-nl          | Trained on the natural-language dataset THEPILE |
-| Salesforce/codegen-2B-nl            | Trained on the natural-language dataset THEPILE |
-| Salesforce/codegen-6B-nl            | Trained on the natural-language dataset THEPILE |
-| Salesforce/codegen-16B-nl           | Trained on the natural-language dataset THEPILE |
-| Salesforce/codegen-350M-multi       | Trained on the multi-programming-language dataset BIGQUERY |
-| Salesforce/codegen-2B-multi         | Trained on the multi-programming-language dataset BIGQUERY |
-| Salesforce/codegen-6B-multi         | Trained on the multi-programming-language dataset BIGQUERY |
-| Salesforce/codegen-16B-multi        | Trained on the multi-programming-language dataset BIGQUERY |
+Now you can enjoy using it 😊.
+#### Notes
+- If FasterGeneration is used, set `use_faster=True` in [codegen_server.py](#configuration-parameters); the first request triggers JIT compilation and takes some extra time. See [here](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/ops/README.md#%E4%BD%BF%E7%94%A8%E7%8E%AF%E5%A2%83%E8%AF%B4%E6%98%8E) for the FasterGeneration environment requirements.
+- To serve a model you trained yourself, set `model_name_or_path` in [codegen_server.py](#configuration-parameters) to the local model path.
-### Taskflow API
+## Taskflow API
See the [Taskflow documentation](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/model_zoo/taskflow.md) for details.
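+
+A typical call looks like the sketch below (assuming the `code_generation` task name and a Salesforce/codegen model, as described in the Taskflow documentation):
+
+```python
+from paddlenlp import Taskflow
+
+# Code generation via Taskflow; the model is downloaded on first use.
+codegen = Taskflow("code_generation",
+                   model="Salesforce/codegen-350M-mono",
+                   max_length=64)
+print(codegen("def hello_world():"))
+```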
## Examples
@@ -157,4 +302,24 @@ def hello_world():
hello_world()
```
-We welcome you to try more fun generations, and also to build your own code-completion plugins.
+
+## Model List
+| Model Name                          | Description                      |
+| :--------------------------------- | -------------------------------- |
+| Salesforce/codegen-350M-mono        | Trained on the Python dataset BIGPYTHON |
+| Salesforce/codegen-2B-mono          | Trained on the Python dataset BIGPYTHON |
+| Salesforce/codegen-6B-mono          | Trained on the Python dataset BIGPYTHON |
+| Salesforce/codegen-16B-mono         | Trained on the Python dataset BIGPYTHON |
+| Salesforce/codegen-350M-nl          | Trained on the natural-language dataset THEPILE |
+| Salesforce/codegen-2B-nl            | Trained on the natural-language dataset THEPILE |
+| Salesforce/codegen-6B-nl            | Trained on the natural-language dataset THEPILE |
+| Salesforce/codegen-16B-nl           | Trained on the natural-language dataset THEPILE |
+| Salesforce/codegen-350M-multi       | Trained on the multi-programming-language dataset BIGQUERY |
+| Salesforce/codegen-2B-multi         | Trained on the multi-programming-language dataset BIGQUERY |
+| Salesforce/codegen-6B-multi         | Trained on the multi-programming-language dataset BIGQUERY |
+| Salesforce/codegen-16B-multi        | Trained on the multi-programming-language dataset BIGQUERY |
+
+## References
+- Nijkamp, Erik, et al. "A conversational paradigm for program synthesis." arXiv preprint arXiv:2203.13474 (2022).
+- [https://github.com/features/copilot/](https://github.com/features/copilot/)
diff --git a/examples/code_generation/codegen/codegen_server.py b/examples/code_generation/codegen/codegen_server.py
new file mode 100644
index 000000000000..fd448749c7ac
--- /dev/null
+++ b/examples/code_generation/codegen/codegen_server.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+import string
+import time
+import uvicorn
+import paddle
+from paddlenlp.utils.log import logger
+from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
+from sse_starlette.sse import EventSourceResponse
+from fastapi import FastAPI, Response, status
+from pydantic import BaseModel
+
+
+class DefaultConfig:
+ model_name_or_path = "Salesforce/codegen-2B-mono"
+ device = "gpu"
+ temperature = 0.5
+ top_k = 10
+ top_p = 1.0
+ repetition_penalty = 1.0
+ min_length = 0
+ max_length = 16
+ decode_strategy = "sampling"
+ load_state_as_np = True
+ use_faster = True
+ use_fp16_decoding = True
+ default_dtype = "float16" if use_faster and use_fp16_decoding else "float32"
+
+
+class Input(BaseModel):
+ prompt: str
+ stream: bool = False
+
+
+class Output(BaseModel):
+ id: str
+ model: str = "codegen"
+ object: str = "text_completion"
+ created: int = int(time.time())
+ choices: list = None
+ usage = {
+ "completion_tokens": None,
+ "prompt_tokens": None,
+ "total_tokens": None,
+ }
+
+
+generate_config = DefaultConfig()
+paddle.set_device(generate_config.device)
+paddle.set_default_dtype(generate_config.default_dtype)
+
+tokenizer = CodeGenTokenizer.from_pretrained(generate_config.model_name_or_path)
+model = CodeGenForCausalLM.from_pretrained(
+ generate_config.model_name_or_path,
+ load_state_as_np=generate_config.load_state_as_np)
+
+app = FastAPI()
+
+
+def random_completion_id():
+ return 'cmpl-' + ''.join(
+ random.choice(string.ascii_letters + string.digits) for _ in range(29))
+
+
+@app.post("/v1/engines/codegen/completions", status_code=200)
+async def gen(item: Input):
+ item = item.dict()
+ logger.info(f"Request: {item}")
+ temperature = item.get("temperature", generate_config.temperature)
+ top_k = item.get("top_k", generate_config.top_k)
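+ # Sampling cannot use temperature == 0, so approximate greedy decoding by
+ # keeping a valid temperature and restricting sampling to the top-1 token.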
+ if temperature == 0.0:
+ temperature = 1.0
+ top_k = 1
+ repetition_penalty = item.get('frequency_penalty',
+ generate_config.repetition_penalty)
+
+ start_time = time.time()
+ logger.info("Start generating code")
+ tokenized = tokenizer([item['prompt']],
+ truncation=True,
+ return_tensors='pd')
+ output, _ = model.generate(
+ tokenized["input_ids"],
+ max_length=generate_config.max_length,
+ min_length=generate_config.min_length,
+ decode_strategy=generate_config.decode_strategy,
+ top_k=top_k,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ use_faster=generate_config.use_faster,
+ use_fp16_decoding=generate_config.use_fp16_decoding)
+ logger.info("Finish generating code")
+ end_time = time.time()
+ logger.info(f"Time cost: {end_time - start_time}")
+ output = tokenizer.decode(output[0],
+ skip_special_tokens=True,
+ spaces_between_special_tokens=False)
+ logger.info(f"Generated code: {output}")
+ output_json = Output(
+ id=random_completion_id(),
+ choices=[{
+ "text": output,
+ "index": 0,
+ "finish_reason": "stop",
+ "logprobs": None,
+ }],
+ usage={
+ "completion_tokens": None,
+ "prompt_tokens": None,
+ "total_tokens": None,
+ },
+ ).json()
+
+ def stream_response(response):
+ yield f"{response}\n\n"
+ yield "data: [DONE]\n\n"
+
+ if item.get("stream", False):
+ return EventSourceResponse(stream_response(output_json))
+ else:
+ return Response(
+ status_code=status.HTTP_200_OK,
+ content=output_json,
+ media_type="application/json",
+ )
+
+
+if __name__ == "__main__":
+ uvicorn.run("codegen_server:app", host="0.0.0.0", port=8978)
diff --git a/examples/code_generation/codegen/requirements.txt b/examples/code_generation/codegen/requirements.txt
new file mode 100644
index 000000000000..ded81d0ca44e
--- /dev/null
+++ b/examples/code_generation/codegen/requirements.txt
@@ -0,0 +1,5 @@
+fastapi==0.79.0
+pydantic==1.9.1
+python-dotenv==0.20.0
+sse_starlette==0.10.3
+uvicorn==0.17.6
\ No newline at end of file
diff --git a/examples/code_generation/codegen/run_clm.py b/examples/code_generation/codegen/run_clm.py
new file mode 100644
index 000000000000..1166d22af69f
--- /dev/null
+++ b/examples/code_generation/codegen/run_clm.py
@@ -0,0 +1,375 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import argparse
+import random
+import time
+import distutils.util
+from pprint import pprint
+from functools import partial
+import numpy as np
+from itertools import chain
+from datasets import load_dataset
+import math
+import paddle
+import paddle.nn as nn
+from paddle.io import BatchSampler, DistributedBatchSampler, DataLoader
+from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer
+from paddlenlp.transformers import LinearDecayWithWarmup
+from paddlenlp.utils.log import logger
+from paddlenlp.data import DataCollatorWithPadding
+from paddle.metric import Accuracy
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument("--model_name_or_path",
+ default="Salesforce/codegen-350M-mono",
+ type=str,
+ required=True,
+ help="Path to pre-trained model. ")
+ parser.add_argument(
+ "--output_dir",
+ default="output",
+ type=str,
+ required=True,
+ help=
+ "The output directory where the model predictions and checkpoints will be written."
+ )
+ parser.add_argument("--train_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The input training data file.")
+ parser.add_argument("--validation_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The input validation data file.")
+ parser.add_argument(
+ "--block_size",
+ default=None,
+ type=int,
+ help=
+ "The training dataset will be truncated in block of this size for training. "
+ )
+ parser.add_argument("--learning_rate",
+ default=1e-4,
+ type=float,
+ help="The initial learning rate for Adam.")
+ parser.add_argument(
+ "--num_train_epochs",
+ default=3,
+ type=int,
+ help="Total number of training epochs to perform.",
+ )
+ parser.add_argument("--logging_steps",
+ type=int,
+ default=100,
+ help="Log every X updates steps.")
+ parser.add_argument("--save_steps",
+ type=int,
+ default=100,
+ help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--train_batch_size",
+ default=20,
+ type=int,
+ help="Batch size per GPU/CPU for training.",
+ )
+ parser.add_argument(
+ "--eval_batch_size",
+ default=12,
+ type=int,
+ help="Batch size per GPU/CPU for evaluation.",
+ )
+ parser.add_argument("--weight_decay",
+ default=0.0,
+ type=float,
+ help="Weight decay if we apply some.")
+ parser.add_argument(
+ "--warmup_steps",
+ default=0,
+ type=int,
+ help=
+ "Linear warmup over warmup_steps. If > 0: Override warmup_proportion")
+ parser.add_argument("--warmup_proportion",
+ default=0.1,
+ type=float,
+ help="Linear warmup proportion over total steps.")
+ parser.add_argument("--adam_epsilon",
+ default=1e-6,
+ type=float,
+ help="Epsilon for Adam optimizer.")
+ parser.add_argument(
+ "--max_steps",
+ default=-1,
+ type=int,
+ help=
+ "If > 0: set total number of training steps to perform. Override num_train_epochs.",
+ )
+ parser.add_argument("--seed",
+ default=42,
+ type=int,
+ help="random seed for initialization")
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ type=str,
+ choices=["cpu", "gpu", "xpu"],
+ help="The device to select to train the model, is must be cpu/gpu/xpu.")
+ parser.add_argument("--overwrite_cache",
+ action="store_true",
+ help="Whether to overwrite cache for dataset.")
+ parser.add_argument("--use_amp",
+ default=False,
+ type=distutils.util.strtobool,
+ help="Enable mixed precision training.")
+ parser.add_argument("--scale_loss",
+ default=2**15,
+ type=float,
+ help="The value of scale_loss for fp16.")
+ args = parser.parse_args()
+ return args
+
+
+def set_seed(args):
+ # Use the same data seed(for data shuffle) for all procs to guarantee data
+ # consistency after sharding.
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ # Maybe different op seeds(for dropout) for different procs is better. By:
+ # `paddle.seed(args.seed + paddle.distributed.get_rank())`
+ paddle.seed(args.seed)
+
+
+@paddle.no_grad()
+def evaluate(model, data_loader, loss_fct):
+ model.eval()
+ metric = Accuracy()
+ metric.reset()
+ losses = []
+ model = model._layers if isinstance(model, paddle.DataParallel) else model
+ for batch in data_loader:
+ labels = batch.pop("labels")
+ logits, _ = model(**batch)
+ loss = loss_fct(logits[:, :-1, :], labels[:, 1:])
+ correct = metric.compute(paddle.max(logits[:, :-1, :], axis=-1),
+ labels[:, 1:])
+ losses.append(loss)
+ losses = paddle.concat(losses)
+ eval_loss = paddle.mean(losses)
+ perplexity = math.exp(eval_loss)
+ accuracy = metric.accumulate()
+ logger.info("[validation] accuracy: %f, loss: %f, ppl: %f" %
+ (accuracy, eval_loss, perplexity))
+ model.train()
+ return perplexity
+
+
+def convert_example(examples, tokenizer):
+ """convert examples into necessary features"""
+ # Convert raw text to feature
+ tokenized_examples = tokenizer(examples["code"],
+ return_attention_mask=True,
+ return_position_ids=False,
+ return_token_type_ids=False)
+ return tokenized_examples
+
+
+def group_texts(examples, block_size):
+ concatenated_examples = {
+ k: list(chain(*examples[k]))
+ for k in examples.keys()
+ }
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
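+ # Drop the remainder so that every training block contains exactly block_size tokens.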
+ if total_length >= block_size:
+ total_length = (total_length // block_size) * block_size
+ result = {
+ k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ result["labels"] = result["input_ids"].copy()
+ return result
+
+
+def process_ds(dataset, tokenizer, overwrite_cache, block_size):
+ trans_func = partial(convert_example, tokenizer=tokenizer)
+ dataset = dataset.map(trans_func,
+ batched=True,
+ remove_columns=dataset.column_names,
+ load_from_cache_file=overwrite_cache)
+ trans_func = partial(group_texts, block_size=block_size)
+ dataset = dataset.map(trans_func,
+ batched=True,
+ load_from_cache_file=overwrite_cache)
+ return dataset
+
+
+def do_train(args):
+ paddle.set_device(args.device)
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ set_seed(args)
+
+ tokenizer = CodeGenTokenizer.from_pretrained(args.model_name_or_path)
+
+ train_set = load_dataset("json", data_files=args.train_file, split="train")
+ dev_set = load_dataset("json",
+ data_files=args.validation_file,
+ split="train")
+
+ if args.block_size is None:
+ block_size = tokenizer.model_max_length
+ if block_size > 1024:
+ logger.warning(
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
+ )
+ block_size = 1024
+ else:
+ if args.block_size > tokenizer.model_max_length:
+ logger.warning(
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
+ )
+ block_size = min(args.block_size, tokenizer.model_max_length)
+
+ train_set = process_ds(train_set, tokenizer, args.overwrite_cache,
+ block_size)
+ dev_set = process_ds(dev_set, tokenizer, args.overwrite_cache, block_size)
+
+ batchify_fn = DataCollatorWithPadding(tokenizer)
+
+ train_batch_sampler = DistributedBatchSampler(
+ train_set, batch_size=args.train_batch_size, shuffle=True)
+
+ train_data_loader = DataLoader(dataset=train_set,
+ batch_sampler=train_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+
+ dev_batch_sampler = BatchSampler(dev_set,
+ batch_size=args.eval_batch_size,
+ shuffle=False)
+ dev_data_loader = DataLoader(dataset=dev_set,
+ batch_sampler=dev_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+
+ model = CodeGenForCausalLM.from_pretrained(args.model_name_or_path)
+
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.DataParallel(model)
+
+ if args.max_steps > 0:
+ num_training_steps = args.max_steps
+ num_train_epochs = math.ceil(num_training_steps /
+ len(train_data_loader))
+ else:
+ num_training_steps = len(train_data_loader) * args.num_train_epochs
+ num_train_epochs = args.num_train_epochs
+
+ warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
+
+ lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
+ warmup)
+
+ # Generate parameter names needed to perform weight decay.
+ # All bias and LayerNorm parameters are excluded.
+ decay_params = [
+ p.name for n, p in model.named_parameters()
+ if not any(nd in n for nd in ["bias", "norm"])
+ ]
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=args.adam_epsilon,
+ parameters=model.parameters(),
+ weight_decay=args.weight_decay,
+ apply_decay_param_fun=lambda x: x in decay_params)
+
+ loss_fct = nn.CrossEntropyLoss()
+ if args.use_amp:
+ scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
+ global_step = 0
+ best_eval_ppl = float("inf")
+ tic_train = time.time()
+ for epoch in range(num_train_epochs):
+ for step, batch in enumerate(train_data_loader):
+ global_step += 1
+ labels = batch.pop("labels")
+ with paddle.amp.auto_cast(
+ args.use_amp,
+ custom_white_list=["layer_norm", "softmax", "gelu"]):
+ logits, _ = model(**batch)
+ loss = loss_fct(logits[:, :-1, :], labels[:, 1:])
+ if args.use_amp:
+ scaled_loss = scaler.scale(loss)
+ scaled_loss.backward()
+ scaler.minimize(optimizer, scaled_loss)
+ else:
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.clear_grad()
+ if global_step % args.logging_steps == 0:
+ logger.info(
+ "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, ppl: %f, lr: %.10f, speed: %.4f step/s"
+ % (global_step, num_training_steps, epoch, step,
+ paddle.distributed.get_rank(), loss, math.exp(loss),
+ optimizer.get_lr(), args.logging_steps /
+ (time.time() - tic_train)))
+ tic_train = time.time()
+ if global_step % args.save_steps == 0 or global_step == num_training_steps:
+ tic_eval = time.time()
+ ppl = evaluate(model, dev_data_loader, loss_fct)
+ logger.info("eval done total : %s s" % (time.time() - tic_eval))
+ if best_eval_ppl > ppl and paddle.distributed.get_rank() == 0:
+ best_eval_ppl = ppl
+ output_dir = args.output_dir
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ # Need better way to get inner model of DataParallel
+ model_to_save = model._layers if isinstance(
+ model, paddle.DataParallel) else model
+ model_to_save.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+ if global_step >= num_training_steps:
+ break
+ if global_step >= num_training_steps:
+ break
+
+ if paddle.distributed.get_rank() == 0:
+ output_dir = os.path.join(args.output_dir,
+ "codegen_model_final_%d" % global_step)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ # Need better way to get inner model of DataParallel
+ model_to_save = model._layers if isinstance(
+ model, paddle.DataParallel) else model
+ model_to_save.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ pprint(args)
+ do_train(args)
diff --git a/examples/code_generation/codegen/run_clm.sh b/examples/code_generation/codegen/run_clm.sh
new file mode 100644
index 000000000000..083abcf4cf47
--- /dev/null
+++ b/examples/code_generation/codegen/run_clm.sh
@@ -0,0 +1,11 @@
+python -m paddle.distributed.launch --gpus 0,1 run_clm.py \
+ --model_name_or_path Salesforce/codegen-350M-mono \
+ --output_dir output \
+ --train_file train.json \
+ --validation_file test.json \
+ --num_train_epochs 5 \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --train_batch_size 2 \
+ --eval_batch_size 2 \
+ --device gpu
diff --git a/faster_generation/README.md b/faster_generation/README.md
index 6adb349cd1ca..58cb2a6f4f99 100644
--- a/faster_generation/README.md
+++ b/faster_generation/README.md
@@ -11,7 +11,7 @@ FasterGeneration是PaddleNLP v2.2版本加入的文本生成高性能加速功
## Features
-- Full support for generative pretrained models, including GPT, OPT, BART, mBART, UnifiedTransformer and UNIMO-text.
+- Full support for generative pretrained models, including GPT, OPT, CodeGen, GPTJ, BART, mBART, UnifiedTransformer and UNIMO-text.
- Support for most mainstream decoding strategies, including Beam Search, Sampling and Greedy Search, plus sub-strategies such as Diverse Sibling Search and Length Penalty.
- Fast decoding: up to **18x** faster than the non-accelerated generate function, **with FP16 mixed-precision support**.
- Easy to use: the entry point is `model.generate`, with the same usage as the non-accelerated generation API. When the acceleration conditions are met, high-performance operators are JIT-compiled and used for generation; otherwise it automatically falls back to the non-accelerated API.
@@ -20,17 +20,17 @@ FasterGeneration是PaddleNLP v2.2版本加入的文本生成高性能加速功
### Inference Model Support
The table below shows which pretrained models and decoding strategies PaddleNLP FasterGeneration supports (GPU).
-| Model Name | GPT2 | OPT | BART | mBART | UnifiedTransformer |
-|------------------------|---------|---------|-----------------|-----------------|--------------------|
-| Model Structure | Decoder | Decoder | Encoder-Decoder | Encoder-Decoder | Prefix-LM |
-| Beam Search | ❌ | ❌ | ✅ | ✅ | ✅ |
-| Top-K Sampling | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Top-P Sampling | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Diverse Sibling Search | ❌ | ❌ | ✅ | ✅ | ✅ |
-| Forced Decoding | ❌ | ❌ | ❌ | ✅ | ❌ |
-| Length Penalty | ❌ | ❌ | ✅ | ✅ | ✅ |
-| Temperature | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Repetition Penalty | ✅ | ✅ | ❌ | ❌ | ❌ |
+| Model Name | GPT2 | OPT | CodeGen| GPTJ| BART | mBART | UnifiedTransformer |
+|------------------------|---------|---------| ---------| ---------|-----------------|-----------------|--------------------|
+| Model Structure | Decoder | Decoder |Decoder|Decoder| Encoder-Decoder | Encoder-Decoder | Prefix-LM |
+| Beam Search | ❌ | ❌ |❌|❌| ✅ | ✅ | ✅ |
+| Top-K Sampling | ✅ | ✅ |✅|✅| ✅ | ✅ | ✅ |
+| Top-P Sampling | ✅ | ✅ |✅|✅| ✅ | ✅ | ✅ |
+| Diverse Sibling Search | ❌ | ❌ |❌|❌| ✅ | ✅ | ✅ |
+| Forced Decoding | ❌ | ❌ |❌|❌| ❌ | ✅ | ❌ |
+| Length Penalty | ❌ | ❌ |❌|❌| ✅ | ✅ | ✅ |
+| Temperature | ✅ | ✅ |✅|✅| ✅ | ✅ | ✅ |
+| Repetition Penalty | ✅ | ✅ |✅|✅| ❌ | ❌ | ❌ |
## Performance
@@ -61,6 +61,26 @@ FasterGeneration的高性能解码相比原版generate方法加速明显,并
+**CodeGen:**
+* Environment and hyperparameters
+- Platform: Tesla V100-SXM2-32GB
+- CUDA 10.1
+- CUDNN 7.6.5
+- PaddlePaddle-gpu 2.3.1.post101
+- transformers==4.21.1
+- torch==1.11.0
+- Batch Size: 1
+- Input Length: 60
+- Output Length: 20
+
+
+
+
+- Platform: A100-40G
+
+
+
+
More detailed performance data can be found [here](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/faster_generation/perf)
## Quick Start
diff --git a/faster_generation/perf/README.md b/faster_generation/perf/README.md
index c1c1e1d53090..2f89147651ba 100644
--- a/faster_generation/perf/README.md
+++ b/faster_generation/perf/README.md
@@ -154,6 +154,49 @@ transformers: 4.20.1
| | top_k=16 | 110.80 | 78.19 | 488.34 | 4.41 | 6.25 |
| | top_p=0.4 | 128.33 | 92.57 | 544.18 | 4.24 | 5.88 |
+**CodeGen:**
+* Environment and hyperparameters
+
+- Platform: Tesla V100-SXM2-32GB
+- CUDA 10.1
+- CUDNN 7.6.5
+- PaddlePaddle-gpu 2.3.1.post101
+- transformers==4.21.1
+- torch==1.11.0
+- Batch Size: 1
+- Input Length: 60
+- Output Length: 20
+
+* Model parameters
+
+| Model Name | num_layers | num_attention_heads | hidden_size |
+|------------|------------|---------------------|-------------|
+| Salesforce/codegen-350M-mono | 20 | 16 | 1024 |
+| Salesforce/codegen-2B-mono | 32 | 32 | 2560 |
+| Salesforce/codegen-6B-mono | 33 | 16 | 4096 |
+| Salesforce/codegen-16B-mono | 34 | 24 | 6144 |
+
+
+
+* Performance results
+
+| Model | Decoding Strategy | Faster Generation(FP32)(ms) | Faster Generation(FP16)(ms) | HF Generation(ms) | Speed Up Rate(Faster32/HF) | Speed Up Rate(Faster16/HF) |
+|:--------:|:-------------------:|:-----------------------------:|:-----------------------------:|:-------------------:|:----------------------------:|:----------------------------:|
+| Salesforce/codegen-350M-mono | top_k=1 | 57.76 | 51.35 | 709.62 | 12.29 | 13.82 |
+| | top_k=4 | 57.42 | 50.88 | 639.58 | 11.14 | 12.57 |
+| | top_k=8 | 57.24 | 51.67 | 685.82 | 11.98 | 13.27 |
+| | top_k=16 | 57.57 | 51.61 | 686.62 | 11.93 | 13.30 |
+| | top_p=0.4 | 67.26 | 57.35 | 656.12 | 9.75 | 11.44 |
+| Salesforce/codegen-2B-mono| top_k=1 | 319.03 | 207.41 | 1040.71 | 3.26 | 5.02 |
+| | top_k=4 | 318.98 | 207.37 | 1014.32 | 3.18 | 4.89 |
+| | top_k=8 | 319.66 | 207.26 | 1084.09 | 3.39 | 5.23 |
+| | top_k=16 | 320.04 | 207.74 | 1040.28 | 3.25 | 5.01 |
+| | top_p=0.4 | 329.07 | 213.97 | 1055.55 | 3.21 | 4.93 |
+| Salesforce/codegen-6B-mono| top_k=1 | 762.91 | 411.94 | 1384.90 | 1.82 | 3.36 |
+| | top_k=4 | 762.58 | 412.79 | 1378.32 | 1.81 | 3.34 |
+| | top_k=8 | 763.43 | 413.32 | 1366.45 | 1.79 | 3.31 |
+| | top_k=16 | 762.79 | 413.83 | 1376.69 | 1.80 | 3.33 |
+| | top_p=0.4 | 771.77 | 419.16 | 1366.49 | 1.77 | 3.26 |
## Test Method
Run the following command to run the BART performance test:
diff --git a/faster_generation/perf/codegen_perf.py b/faster_generation/perf/codegen_perf.py
new file mode 100644
index 000000000000..dc391956d0a1
--- /dev/null
+++ b/faster_generation/perf/codegen_perf.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import time
+from pprint import pprint
+import numpy as np
+import paddle
+from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
+
+import pynvml
+
+pynvml.nvmlInit()
+
+
+def query_by_id(gpu_id=2):
+ handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
+ meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
+ return meminfo.used // 1024 // 1024
+
+
+def perf_pd(args):
+ start_mem = query_by_id()
+ place = "gpu"
+ place = paddle.set_device(place)
+ tokenizer = CodeGenTokenizer.from_pretrained(args.model_name_or_path)
+ model = CodeGenForCausalLM.from_pretrained(args.model_name_or_path,
+ load_state_as_np=True)
+ model.eval()
+ load_mem = query_by_id()
+
+ input_ids_np = [
+ np.random.choice(list(tokenizer.decoder.keys())[:-1], args.input_len)
+ for _ in range(args.batch_size)
+ ]
+ input_ids = paddle.to_tensor(input_ids_np)
+
+ num_loop = 100
+ with paddle.no_grad():
+ for i in range(num_loop):
+ # For warmup.
+ if num_loop // 2 == i:
+ # PaddlePaddle >= 2.2
+ paddle.device.cuda.synchronize(place)
+ start = time.perf_counter()
+ output, _ = model.generate(input_ids=input_ids,
+ max_length=args.generate_len,
+ min_length=args.generate_len,
+ decode_strategy="sampling",
+ top_k=args.top_k,
+ top_p=args.top_p,
+ use_faster=args.use_faster,
+ use_fp16_decoding=args.use_fp16_decoding)
+ generate_mem = query_by_id()
+ paddle.device.cuda.synchronize(place)
+ pd_cost = (time.perf_counter() - start) / (num_loop -
+ num_loop // 2) * 1000
+ return pd_cost, load_mem - start_mem, generate_mem - start_mem
+
+
+def perf_hf(args):
+ import torch
+ from transformers import CodeGenTokenizer as hf_tokenizer, CodeGenForCausalLM as hf_codegen
+ start_mem = query_by_id()
+ device = torch.device("cuda")
+ tokenizer = hf_tokenizer.from_pretrained(args.model_name_or_path)
+ model = hf_codegen.from_pretrained(args.model_name_or_path)
+ model.to(device)
+ model.eval()
+ load_mem = query_by_id()
+
+ input_ids_np = [
+ np.random.choice(list(tokenizer.decoder.keys()), args.input_len)
+ for _ in range(args.batch_size)
+ ]
+ input_ids = torch.tensor(input_ids_np)
+ input_ids = input_ids.to(device)
+ num_loop = 100
+ with torch.no_grad():
+ for i in range(num_loop):
+ # For warmup.
+ if num_loop // 2 == i:
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+ output = model.generate(
+ input_ids,
+ do_sample=True,
+ max_length=args.generate_len + input_ids.shape[-1],
+ min_length=args.generate_len + input_ids.shape[-1],
+ top_k=args.top_k,
+ top_p=args.top_p)
+ generate_mem = query_by_id()
+ torch.cuda.synchronize()
+ hf_cost = (time.perf_counter() - start) / (num_loop -
+ num_loop // 2) * 1000
+ return hf_cost, load_mem - start_mem, generate_mem - start_mem
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--perf_type",
+ default="pd",
+ type=str,
+ choices=['pd', 'pd_faster_fp32', 'pd_faster_fp16', 'hf'],
+ help="The type of perf. ")
+ parser.add_argument("--model_name_or_path",
+ default="Salesforce/codegen-350M-mono",
+ type=str,
+ choices=[
+ 'Salesforce/codegen-350M-mono',
+ 'Salesforce/codegen-2B-mono',
+ 'Salesforce/codegen-6B-mono',
+ 'Salesforce/codegen-16B-mono'
+ ],
+ help="The model name to specify the bart to use. ")
+ parser.add_argument(
+ "--top_k",
+ default=4,
+ type=int,
+ help="The number of candidate to procedure topk sampling. ")
+ parser.add_argument(
+ "--top_p",
+ default=1.0,
+ type=float,
+ help="The probability threshold to procedure topp sampling. ")
+ parser.add_argument("--batch_size",
+ default=1,
+ type=int,
+ help="The size of input batch. ")
+ parser.add_argument("--input_len",
+ default=60,
+ type=int,
+ help="The size of model input. ")
+ parser.add_argument("--generate_len",
+ default=20,
+ type=int,
+ help="Length of output . ")
+ parser.add_argument(
+ '--use_faster',
+ action='store_true',
+ help='Whether to process inference using faster codegen. ')
+
+ parser.add_argument("--use_fp16_decoding",
+ action="store_true",
+ help="Whether to use fp16 decoding to predict. ")
+ args = parser.parse_args()
+ return args
+
+
+def do_predict(args):
+ try:
+ if args.perf_type == 'pd':
+ args.use_faster = False
+ cost, load_mem, generate_mem = perf_pd(args)
+ elif args.perf_type == 'pd_faster_fp32':
+ args.use_faster = True
+ args.use_fp16_decoding = False
+ cost, load_mem, generate_mem = perf_pd(args)
+ elif args.perf_type == 'pd_faster_fp16':
+ args.use_faster = True
+ args.use_fp16_decoding = True
+ paddle.set_default_dtype('float16')
+ cost, load_mem, generate_mem = perf_pd(args)
+ else:
+ cost, load_mem, generate_mem = perf_hf(args)
+ pprint(args)
+ print(
+ f'CodeGenPerfResult: cost_time: {cost} ms, load_mem: {load_mem} MB, generate_mem:{generate_mem} MB, args:{args}\n'
+ )
+ except Exception as e:
+ pprint(args)
+ print(f'CodeGenPerfResult: ERROR: {e}, args:{args}\n')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ do_predict(args)
diff --git a/faster_generation/perf/run_perf_codegen.sh b/faster_generation/perf/run_perf_codegen.sh
new file mode 100644
index 000000000000..0d02575af92d
--- /dev/null
+++ b/faster_generation/perf/run_perf_codegen.sh
@@ -0,0 +1,47 @@
+export CUDA_VISIBLE_DEVICES=2
+
+for model_name in Salesforce/codegen-350M-mono Salesforce/codegen-2B-mono Salesforce/codegen-6B-mono;
+ do
+ for top_k in 1 4 8 16;
+ do
+ for input_len in 60;
+ do
+ for generate_len in 20;
+ do
+ for perf_type in pd pd_faster_fp32 pd_faster_fp16 hf;
+ do
+ echo model_name: $model_name, perf_type: $perf_type, top_k: $top_k, top_p: 1.0, input_len: $input_len, generate_len: $generate_len
+ python codegen_perf.py \
+ --model_name_or_path=$model_name \
+ --perf_type=$perf_type \
+ --top_k=$top_k \
+ --top_p=1.0 \
+ --input_len=$input_len \
+ --generate_len=$generate_len
+ sleep 3s
+ done
+ done
+ done
+ done
+ for top_p in 0.4;
+ do
+ for input_len in 60;
+ do
+ for generate_len in 20;
+ do
+ for perf_type in pd pd_faster_fp32 pd_faster_fp16 hf;
+ do
+ echo model_name: $model_name, perf_type: $perf_type, top_k: 0, top_p: $top_p, input_len: $input_len, generate_len: $generate_len
+ python codegen_perf.py \
+ --model_name_or_path=$model_name \
+ --perf_type=$perf_type \
+ --top_k=0 \
+ --top_p=$top_p \
+ --input_len=$input_len \
+ --generate_len=$generate_len
+ sleep 3s
+ done
+ done
+ done
+ done
+ done
\ No newline at end of file
diff --git a/faster_generation/samples/codegen_sample.py b/faster_generation/samples/codegen_sample.py
new file mode 100644
index 000000000000..050073d8f95c
--- /dev/null
+++ b/faster_generation/samples/codegen_sample.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
+
+model_name = 'Salesforce/codegen-350M-mono'
+
+tokenizer = CodeGenTokenizer.from_pretrained(model_name)
+model = CodeGenForCausalLM.from_pretrained(model_name)
+model.eval()
+
+inputs = 'def hello'
+input_ids = tokenizer([inputs], return_tensors='pd')['input_ids']
+
+outputs, _ = model.generate(input_ids=input_ids,
+ max_length=128,
+ decode_strategy='greedy_search',
+ use_fp16_decoding=True,
+ use_faster=True)
+
+result = tokenizer.decode(outputs[0],
+ truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"])
+
+print("Model input:", inputs)
+print("Result:", result)
+# Result: _world():
+# print("Hello World")
+
+# hello_world()
diff --git a/faster_generation/samples/gptj_sample.py b/faster_generation/samples/gptj_sample.py
new file mode 100644
index 000000000000..90f041e0585d
--- /dev/null
+++ b/faster_generation/samples/gptj_sample.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddlenlp.transformers import GPTJTokenizer, GPTJForCausalLM
+
+paddle.set_default_dtype('float16')
+model_name = 'EleutherAI/gpt-j-6B'
+
+tokenizer = GPTJTokenizer.from_pretrained(model_name)
+model = GPTJForCausalLM.from_pretrained(model_name, load_state_as_np=True)
+model.eval()
+
+inputs = "What is PaddleNLP?"
+input_ids = tokenizer([inputs], return_tensors='pd')['input_ids']
+
+outputs, _ = model.generate(input_ids=input_ids,
+ max_length=100,
+ decode_strategy='sampling',
+ temperature=0.8,
+ top_p=0.9,
+ use_fp16_decoding=True,
+ use_faster=True)
+
+result = tokenizer.decode(outputs[0])
+
+print("Model input:", inputs)
+print("Result:", result)
\ No newline at end of file
diff --git a/paddlenlp/ops/CMakeLists.txt b/paddlenlp/ops/CMakeLists.txt
index c0c1e2d21e8b..194f32167ebe 100644
--- a/paddlenlp/ops/CMakeLists.txt
+++ b/paddlenlp/ops/CMakeLists.txt
@@ -36,6 +36,7 @@ option(WITH_BART "Compile with BART"
option(WITH_MBART "Compile with MBART" ON)
option(WITH_PARALLEL "Compile with model parallel for GPT" OFF)
option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF)
+option(WITH_GPTJ "Compile with GPTJ" ON)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
if(WITH_PARALLEL)
@@ -85,8 +86,12 @@ if(WITH_MBART)
list(APPEND decoding_op_files fusion_mbart_decoding_op.cc fusion_mbart_decoding_op.cu)
endif()
-if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART)
- message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON must be set to use FasterTransformer. ")
+if(WITH_GPTJ)
+ list(APPEND decoding_op_files fusion_gptj_op.cc fusion_gptj_op.cu)
+endif()
+
+if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART AND NOT WITH_GPTJ)
+ message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON or/and -DWITH_GPTJ=ON must be set to use FasterTransformer. ")
endif()
set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
@@ -321,6 +326,21 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/opt.h opt_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/opt.h opt_h_dst)
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/gptj.h gptj_h_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/gptj.h gptj_h_dst)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h masked_multihead_attention_utils_h_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention_utils.h masked_multihead_attention_utils_h_dst)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h masked_multihead_attention_h_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention.h masked_multihead_attention_h_dst)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu attention_kernels_cu_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cu attention_kernels_cu_dst)
+
+file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_src)
+file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_dst)
+
# Encoder patches.
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)
@@ -360,6 +380,11 @@ set(FT_PATCH_COMMAND
&& cp ${fastertransformer_cmakelists_src} ${fastertransformer_cmakelists_dst}
&& cp ${gpt_h_src} ${gpt_h_dst}
&& cp ${opt_h_src} ${opt_h_dst}
+ && cp ${gptj_h_src} ${gptj_h_dst}
+ && cp ${masked_multihead_attention_h_src} ${masked_multihead_attention_h_dst}
+ && cat blank_lines ${masked_multihead_attention_utils_h_src} >> ${masked_multihead_attention_utils_h_dst}
+ && cat blank_lines ${attention_kernels_cu_src} >> ${attention_kernels_cu_dst}
+ && cat blank_lines ${attention_kernels_cuh_src} >> ${attention_kernels_cuh_dst}
&& cat blank_lines ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
&& cat blank_lines ${lightseq_kernels_cu_src} >> ${topk_kernels_dst}
&& cat blank_lines ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst}
diff --git a/paddlenlp/ops/faster_transformer/src/CMakeLists.txt b/paddlenlp/ops/faster_transformer/src/CMakeLists.txt
index 6de779f002de..46588a22151a 100644
--- a/paddlenlp/ops/faster_transformer/src/CMakeLists.txt
+++ b/paddlenlp/ops/faster_transformer/src/CMakeLists.txt
@@ -221,6 +221,16 @@ else(ON_INFER)
set(PYTHON_PATH ${PY_CMD} CACHE STRING "Python path")
endif()
+ execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import paddle; print(paddle.__version__)"
+ RESULT_VARIABLE _INC_PYTHON_SUCCESS
+ OUTPUT_VARIABLE _INC_PYTHON_VALUES)
+ message(STATUS "PADDLE_VERSION: ${_INC_PYTHON_VALUES}")
+
+ # TODO(gongenlei): support PADDLE_NEW_ALLOCATOR for ON_INFER
+ if(_INC_PYTHON_VALUES VERSION_GREATER_EQUAL "2.3.0")
+ add_definitions(-DPADDLE_NEW_ALLOCATOR)
+ endif()
+
execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import paddle; print(paddle.sysconfig.get_include())"
RESULT_VARIABLE _INC_PYTHON_SUCCESS
OUTPUT_VARIABLE _INC_PYTHON_VALUES)
diff --git a/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc
new file mode 100644
index 000000000000..81d5411eb4be
--- /dev/null
+++ b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc
@@ -0,0 +1,203 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+
+#include "fusion_gptj_op.h"
+#include "pd_traits.h"
+
+
+std::vector<paddle::Tensor> GPTJForward(
+ const paddle::Tensor& input,
+ const paddle::Tensor& attn_mask,
+ const paddle::Tensor& start_length,
+ const paddle::Tensor& word_embedding,
+ const std::vector<paddle::Tensor>& self_ln_weight,
+ const std::vector<paddle::Tensor>& self_ln_bias,
+ const std::vector<paddle::Tensor>& self_q_weight,
+ const std::vector<paddle::Tensor>& self_out_weight,
+ const std::vector<paddle::Tensor>& ffn_inter_weight,
+ const std::vector<paddle::Tensor>& ffn_inter_bias,
+ const std::vector<paddle::Tensor>& ffn_out_weight,
+ const std::vector<paddle::Tensor>& ffn_out_bias,
+ const paddle::Tensor& decoder_ln_weight,
+ const paddle::Tensor& decoder_ln_bias,
+ const paddle::Tensor& emb_weight,
+ const paddle::Tensor& emb_bias,
+ const int topk,
+ const float topp,
+ const int max_len,
+ const int n_head,
+ const int size_per_head,
+ const int num_layer,
+ const int bos_id,
+ const int eos_id,
+ const float temperature,
+ const int rotary_embedding_dim,
+ const float repetition_penalty,
+ const int min_length,
+ const bool use_fp16 = false,
+ const int tensor_para_size = 1,
+ const int layer_para_size = 1,
+ const int layer_para_batch_size = 1) {
+ int batch_size = input.shape()[0];
+ int start_len = input.shape()[1];
+ int total_len = max_len + start_len;
+ std::vector<int64_t> output_dims({total_len, batch_size});
+
+#ifdef PADDLE_NEW_ALLOCATOR
+ // For PaddlePaddle>=2.3.0
+ auto output_ids = paddle::empty(output_dims, paddle::DataType::INT32, input.place());
+ auto gpu_place = paddle::GPUPlace();
+#else
+ auto output_ids = paddle::Tensor(input.place(), output_dims);
+ auto gpu_place = paddle::PlaceType::kGPU;
+#endif
+
+ if (word_embedding.place() == gpu_place) {
+ return GPTJCUDAForward(input,
+ attn_mask,
+ start_length,
+ word_embedding,
+ self_ln_weight,
+ self_ln_bias,
+ self_q_weight,
+ self_out_weight,
+ ffn_inter_weight,
+ ffn_inter_bias,
+ ffn_out_weight,
+ ffn_out_bias,
+ decoder_ln_weight,
+ decoder_ln_bias,
+ emb_weight,
+ emb_bias,
+ output_ids,
+ topk,
+ topp,
+ total_len,
+ n_head,
+ size_per_head,
+ num_layer,
+ bos_id,
+ eos_id,
+ temperature,
+ rotary_embedding_dim,
+ repetition_penalty,
+ min_length,
+ use_fp16,
+ tensor_para_size,
+ layer_para_size,
+ layer_para_batch_size);
+ } else {
+ PD_THROW("Not implemented place. Only GPU is supported. ");
+ }
+}
+
+std::vector<std::vector<int64_t>> GPTJInferShape(
+ const std::vector<int64_t>& input_shape,
+ const std::vector<int64_t>& attn_mask_shape,
+ const std::vector<int64_t>& start_length,
+ const std::vector<int64_t>& word_embedding_shape,
+ const std::vector<std::vector<int64_t>>& self_ln_weight_shapes,
+ const std::vector<std::vector<int64_t>>& self_ln_bias_shapes,
+ const std::vector<std::vector<int64_t>>& self_q_weight_shapes,
+ const std::vector<std::vector<int64_t>>& self_out_weight_shapes,
+ const std::vector<std::vector<int64_t>>& ffn_inter_weight_shapes,
+ const std::vector<std::vector<int64_t>>& ffn_inter_bias_shapes,
+ const std::vector<std::vector<int64_t>>& ffn_out_weight_shapes,
+ const std::vector<std::vector<int64_t>>& ffn_out_bias_shapes,
+ const std::vector<int64_t>& decoder_ln_weight_shape,
+ const std::vector<int64_t>& decoder_ln_bias_shape,
+ const std::vector<int64_t>& emb_weight_shape,
+ const std::vector<int64_t>& emb_bias_shape,
+ const int topk,
+ const float topp,
+ const int max_len,
+ const int n_head,
+ const int size_per_head,
+ const int num_layer,
+ const int bos_id,
+ const int eos_id,
+ const float temperature,
+ const int rotary_embedding_dim,
+ const float repetition_penalty,
+ const int min_length,
+ const bool use_fp16 = false,
+ const int tensor_para_size = 1,
+ const int layer_para_size = 1,
+ const int layer_para_batch_size = 1) {
+ int64_t batch_size = input_shape[0];
+ int64_t start_len = input_shape[1];
+ std::vector<int64_t> output_dims({max_len + start_len, batch_size});
+ return {output_dims};
+}
+
+std::vector<paddle::DataType> GPTJInferDtype(
+ const paddle::DataType& input_dtype,
+ const paddle::DataType& attn_mask_dtype,
+ const paddle::DataType& start_length_dtype,
+ const paddle::DataType& word_embedding_dtype,
+ const std::vector<paddle::DataType>& self_ln_weight_dtype,
+ const std::vector<paddle::DataType>& self_ln_bias_dtype,
+ const std::vector<paddle::DataType>& self_q_weight_dtype,
+ const std::vector<paddle::DataType>& self_out_weight_dtype,
+ const std::vector<paddle::DataType>& ffn_inter_weight_dtype,
+ const std::vector<paddle::DataType>& ffn_inter_bias_dtype,
+ const std::vector<paddle::DataType>& ffn_out_weight_dtype,
+ const std::vector<paddle::DataType>& ffn_out_bias_dtype,
+ const paddle::DataType& decoder_ln_weight_dtype,
+ const paddle::DataType& decoder_ln_bias_dtype,
+ const paddle::DataType& emb_weight_dtype,
+ const paddle::DataType& emb_bias_dtype) {
+ return {paddle::DataType::INT32};
+}
+
+PD_BUILD_OP(fusion_gptj)
+ .Inputs({"Input",
+ "AttentionMask",
+ "StartLength",
+ "WordEmbedding",
+ paddle::Vec("SelfLayernormWeight"),
+ paddle::Vec("SelfLayernormBias"),
+ paddle::Vec("SelfQueryWeight"),
+ paddle::Vec("SelfOutWeight"),
+ paddle::Vec("FFNInterWeight"),
+ paddle::Vec("FFNInterBias"),
+ paddle::Vec("FFNOutWeight"),
+ paddle::Vec("FFNOutBias"),
+ "DecoderLayernormWeight",
+ "DecoderLayernormBias",
+ "EmbWeight",
+ "EmbBias"})
+ .Outputs({"OutputIds"})
+ .Attrs({"topk: int",
+ "topp: float",
+ "max_len: int",
+ "n_head: int",
+ "size_per_head: int",
+ "num_layer: int",
+ "bos_id: int",
+ "eos_id: int",
+ "temperature: float",
+ "rotary_embedding_dim: int",
+ "repetition_penalty: float",
+ "min_length: int",
+ "use_fp16: bool",
+ "tensor_para_size: int",
+ "layer_para_size: int",
+ "layer_para_batch_size: int"})
+ .SetKernelFn(PD_KERNEL(GPTJForward))
+ .SetInferShapeFn(PD_INFER_SHAPE(GPTJInferShape))
+ .SetInferDtypeFn(PD_INFER_DTYPE(GPTJInferDtype));
diff --git a/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cu b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cu
new file mode 100644
index 000000000000..c2217a84a0c1
--- /dev/null
+++ b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cu
@@ -0,0 +1,334 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// Use the global cublas handle
+#include "cublas_handle.h"
+
+// TODO(guosheng): `HOST` conflict exists in float.h of paddle and mpi.h of mpi
+#include "fusion_gptj_op.h"
+#include "pd_traits.h"
+#ifdef HOST
+#undef HOST
+#endif
+#include "fastertransformer/cuda/cub/cub.cuh"
+#include "fastertransformer/utils/common.h"
+
+#ifdef BUILD_GPT // consistent with FasterTransformer
+#include "parallel_utils.h"
+#endif
+
+template <paddle::DataType D>
+std::vector<paddle::Tensor> gptj_kernel(
+ const paddle::Tensor& input,
+ const paddle::Tensor& attn_mask,
+ const paddle::Tensor& start_length,
+ const paddle::Tensor& word_emb,
+ const std::vector<paddle::Tensor>& self_ln_weight,
+ const std::vector<paddle::Tensor>& self_ln_bias,
+ const std::vector<paddle::Tensor>& self_q_weight,
+ const std::vector<paddle::Tensor>& self_out_weight,
+ const std::vector<paddle::Tensor>& ffn_inter_weight,
+ const std::vector<paddle::Tensor>& ffn_inter_bias,
+ const std::vector<paddle::Tensor>& ffn_out_weight,
+ const std::vector<paddle::Tensor>& ffn_out_bias,
+ const paddle::Tensor& decoder_ln_weight,
+ const paddle::Tensor& decoder_ln_bias,
+ const paddle::Tensor& emb_weight,
+ const paddle::Tensor& emb_bias,
+ paddle::Tensor& output_ids,
+ const int topk,
+ const float topp,
+ const int max_len,
+ const int n_head,
+ const int size_per_head,
+ const int num_layer,
+ const int bos_id,
+ const int eos_id,
+ const float temperature,
+ const int rotary_embedding_dim,
+ const float repetition_penalty,
+ const int min_length,
+ cudaStream_t stream,
+ const int tensor_para_size = 1,
+ const int layer_para_size = 1,
+ const int layer_para_batch_size = 1) {
+ auto input_dims = input.shape();
+ int batch_size_ = input_dims[0];
+ int start_len = input_dims[1];
+ const int vocab_size = emb_bias.shape()[0];
+
+ typedef PDTraits<D> traits_;
+ typedef typename traits_::DataType DataType_;
+ typedef typename traits_::data_t data_t_;
+
+ DecodingInitParam<DataType_> decoding_params;
+ decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+ decoding_params.cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;
+
+#ifdef PADDLE_NEW_ALLOCATOR
+ // For PaddlePaddle>=2.3.0
+ decoding_params.output_ids = output_ids.data<int>();
+#else
+ decoding_params.output_ids = output_ids.mutable_data<int>(word_emb.place());
+#endif
+
+ typedef DecoderTransformerTraits<traits_::OpType> DecodingTraits_;
+ decoding_params.stream = stream;
+ fastertransformer::Allocator<AllocatorType::PD> allocator_(stream);
+
+ const int hidden_unit = size_per_head * n_head;
+
+#ifdef BUILD_GPT
+ auto* model_para_desc = ModelParaDescFactory::CreateModelParaDesc(
+ n_head,
+ size_per_head,
+ num_layer,
+ tensor_para_size,
+ layer_para_size,
+ layer_para_batch_size,
+ const_cast<data_t_*>(word_emb.data<data_t_>()));
+ auto& tensor_parallel_param = model_para_desc->tensor_parallel_param;
+ auto& layer_parallel_param = model_para_desc->layer_parallel_param;
+ auto seed = model_para_desc->dist(model_para_desc->gen);
+#else
+ TensorParallelParam tensor_parallel_param;
+ LayerParallelParam layer_parallel_param;
+ tensor_parallel_param.rank = 0;
+ tensor_parallel_param.world_size = 1;
+ tensor_parallel_param.local_head_num_ = n_head;
+ tensor_parallel_param.local_hidden_units_ = hidden_unit;
+
+ layer_parallel_param.rank = 0;
+ layer_parallel_param.world_size = 1;
+ layer_parallel_param.layers_per_group = num_layer;
+ layer_parallel_param.local_batch_size = batch_size_;
+ int seed = -1;
+#endif
+
+  DecodingGptJ<DecodingTraits_::OpType>* gptj_decoding;
+
+ decoding_params.request_batch_size = batch_size_;
+ decoding_params.max_input_len = start_len;
+ decoding_params.request_input_len = start_len;
+ decoding_params.request_output_len = max_len - start_len;
+
+  decoding_params.d_start_ids = const_cast<int*>(input.data<int>());
+  decoding_params.d_attn_mask =
+      reinterpret_cast<DataType_*>(const_cast<data_t_*>(attn_mask.data<data_t_>()));
+  decoding_params.d_start_lengths = start_length.data<int>();
+
+ gptj_decoding =
+      new DecodingGptJ<DecodingTraits_::OpType>(allocator_,
+ batch_size_,
+ max_len,
+ n_head,
+ size_per_head,
+ vocab_size,
+ num_layer,
+ bos_id,
+ eos_id,
+ topk,
+ topp,
+ temperature,
+ tensor_para_size,
+ layer_para_size,
+ true, /*is_fuse_QKV*/
+ repetition_penalty, /*repetition_penalty*/
+ seed,
+ rotary_embedding_dim,
+ min_length);
+
+ gptj_decoding->set_tensor_parallel_param(tensor_parallel_param);
+ gptj_decoding->set_layer_parallel_param(layer_parallel_param);
+
+  DecoderInitParam<DataType_>* params =
+      new DecoderInitParam<DataType_>[num_layer];
+
+ for (int i = 0; i < self_ln_weight.size(); ++i) {
+ // Allow python passing weights of all layers or only passing the
+ // corresponding layers to save memory.
+ int layer_idx = self_ln_weight.size() != num_layer
+ ? layer_parallel_param.rank *
+ layer_parallel_param.layers_per_group +
+ i
+ : i;
+
+ params[layer_idx].stream = stream;
+ params[layer_idx].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+ params[layer_idx].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;
+
+ params[layer_idx].request_batch_size = batch_size_;
+ params[layer_idx].request_max_mem_seq_len = start_len;
+
+    params[layer_idx].self_layernorm.gamma =
+        reinterpret_cast<const DataType_*>(self_ln_weight[i].data<data_t_>());
+    params[layer_idx].self_layernorm.beta =
+        reinterpret_cast<const DataType_*>(self_ln_bias[i].data<data_t_>());
+
+    params[layer_idx].self_attention.query_weight.kernel =
+        reinterpret_cast<const DataType_*>(self_q_weight[i].data<data_t_>());
+    params[layer_idx].self_attention.query_weight.bias = nullptr;
+
+    params[layer_idx].self_attention.attention_output_weight.kernel =
+        reinterpret_cast<const DataType_*>(self_out_weight[i].data<data_t_>());
+    params[layer_idx].self_attention.attention_output_weight.bias = nullptr;
+
+    params[layer_idx].ffn.intermediate_weight.kernel =
+        reinterpret_cast<const DataType_*>(ffn_inter_weight[i].data<data_t_>());
+    params[layer_idx].ffn.intermediate_weight.bias =
+        reinterpret_cast<const DataType_*>(ffn_inter_bias[i].data<data_t_>());
+    params[layer_idx].ffn.output_weight.kernel =
+        reinterpret_cast<const DataType_*>(ffn_out_weight[i].data<data_t_>());
+    params[layer_idx].ffn.output_weight.bias =
+        reinterpret_cast<const DataType_*>(ffn_out_bias[i].data<data_t_>());
+ }
+
+  decoding_params.layernorm.gamma =
+      reinterpret_cast<const DataType_*>(decoder_ln_weight.data<data_t_>());
+  decoding_params.layernorm.beta =
+      reinterpret_cast<const DataType_*>(decoder_ln_bias.data<data_t_>());
+  decoding_params.embedding_table =
+      reinterpret_cast<const DataType_*>(word_emb.data<data_t_>());
+  decoding_params.embedding_kernel =
+      reinterpret_cast<const DataType_*>(emb_weight.data<data_t_>());
+  decoding_params.embedding_bias =
+      reinterpret_cast<const DataType_*>(emb_bias.data<data_t_>());
+
+ gptj_decoding->forward_context(params, decoding_params);
+ gptj_decoding->forward(params, decoding_params);
+
+ delete gptj_decoding;
+ delete[] params;
+
+ return {output_ids};
+}
+
+std::vector<paddle::Tensor> GPTJCUDAForward(
+ const paddle::Tensor& input,
+ const paddle::Tensor& attn_mask,
+ const paddle::Tensor& start_length,
+ const paddle::Tensor& word_embedding,
+    const std::vector<paddle::Tensor>& self_ln_weight,
+    const std::vector<paddle::Tensor>& self_ln_bias,
+    const std::vector<paddle::Tensor>& self_q_weight,
+    const std::vector<paddle::Tensor>& self_out_weight,
+    const std::vector<paddle::Tensor>& ffn_inter_weight,
+    const std::vector<paddle::Tensor>& ffn_inter_bias,
+    const std::vector<paddle::Tensor>& ffn_out_weight,
+    const std::vector<paddle::Tensor>& ffn_out_bias,
+ const paddle::Tensor& decoder_ln_weight,
+ const paddle::Tensor& decoder_ln_bias,
+ const paddle::Tensor& emb_weight,
+ const paddle::Tensor& emb_bias,
+ paddle::Tensor& output_ids,
+ const int topk,
+ const float topp,
+ const int max_len,
+ const int n_head,
+ const int size_per_head,
+ const int num_layer,
+ const int bos_id,
+ const int eos_id,
+ const float temperature,
+ const int rotary_embedding_dim,
+ const float repetition_penalty,
+ const int min_length,
+ const bool use_fp16 = false,
+ const int tensor_para_size = 1,
+ const int layer_para_size = 1,
+ const int layer_para_batch_size = 1) {
+
+ auto stream = word_embedding.stream();
+ cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);
+
+ if (use_fp16) {
+    return gptj_kernel<paddle::DataType::FLOAT16>(input,
+ attn_mask,
+ start_length,
+ word_embedding,
+ self_ln_weight,
+ self_ln_bias,
+ self_q_weight,
+ self_out_weight,
+ ffn_inter_weight,
+ ffn_inter_bias,
+ ffn_out_weight,
+ ffn_out_bias,
+ decoder_ln_weight,
+ decoder_ln_bias,
+ emb_weight,
+ emb_bias,
+ output_ids,
+ topk,
+ topp,
+ max_len,
+ n_head,
+ size_per_head,
+ num_layer,
+ bos_id,
+ eos_id,
+ temperature,
+ rotary_embedding_dim,
+ repetition_penalty,
+ min_length,
+ stream,
+ tensor_para_size,
+ layer_para_size,
+ layer_para_batch_size);
+ } else {
+    return gptj_kernel<paddle::DataType::FLOAT32>(input,
+ attn_mask,
+ start_length,
+ word_embedding,
+ self_ln_weight,
+ self_ln_bias,
+ self_q_weight,
+ self_out_weight,
+ ffn_inter_weight,
+ ffn_inter_bias,
+ ffn_out_weight,
+ ffn_out_bias,
+ decoder_ln_weight,
+ decoder_ln_bias,
+ emb_weight,
+ emb_bias,
+ output_ids,
+ topk,
+ topp,
+ max_len,
+ n_head,
+ size_per_head,
+ num_layer,
+ bos_id,
+ eos_id,
+ temperature,
+ rotary_embedding_dim,
+ repetition_penalty,
+ min_length,
+ stream,
+ tensor_para_size,
+ layer_para_size,
+ layer_para_batch_size);
+ }
+}
diff --git a/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.h b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.h
new file mode 100644
index 000000000000..0470de164433
--- /dev/null
+++ b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "fastertransformer/gptj.h"
+#include "fastertransformer/open_decoder.h"
+#include "fastertransformer/utils/common.h"
+
+#ifdef PADDLE_ON_INFERENCE
+#include "paddle/include/experimental/ext_all.h"
+#else
+#include "paddle/extension.h"
+#endif
+
+
+std::vector<paddle::Tensor> GPTJCUDAForward(
+ const paddle::Tensor& input,
+ const paddle::Tensor& attn_mask,
+ const paddle::Tensor& start_length,
+ const paddle::Tensor& word_embedding,
+    const std::vector<paddle::Tensor>& self_ln_weight,
+    const std::vector<paddle::Tensor>& self_ln_bias,
+    const std::vector<paddle::Tensor>& self_q_weight,
+    const std::vector<paddle::Tensor>& self_out_weight,
+    const std::vector<paddle::Tensor>& ffn_inter_weight,
+    const std::vector<paddle::Tensor>& ffn_inter_bias,
+    const std::vector<paddle::Tensor>& ffn_out_weight,
+    const std::vector<paddle::Tensor>& ffn_out_bias,
+ const paddle::Tensor& decoder_ln_weight,
+ const paddle::Tensor& decoder_ln_bias,
+ const paddle::Tensor& emb_weight,
+ const paddle::Tensor& emb_bias,
+ paddle::Tensor& output_ids,
+ const int topk,
+ const float topp,
+ const int max_len,
+ const int n_head,
+ const int size_per_head,
+ const int num_layer,
+ const int bos_id,
+ const int eos_id,
+ const float temperature,
+ const int rotary_embedding_dim,
+ const float repetition_penalty,
+ const int min_length,
+ const bool use_fp16,
+ const int tensor_para_size,
+ const int layer_para_size,
+ const int layer_para_batch_size);
diff --git a/paddlenlp/ops/faster_transformer/transformer/decoding.py b/paddlenlp/ops/faster_transformer/transformer/decoding.py
index a6cf4091f731..3e53ea89680d 100644
--- a/paddlenlp/ops/faster_transformer/transformer/decoding.py
+++ b/paddlenlp/ops/faster_transformer/transformer/decoding.py
@@ -622,6 +622,68 @@ def infer_mbart_decoding(
attrs_names, attrs_val, outputs_names, outputs_dtype)
+def infer_gptj_decoding(input, attn_mask, mem_seq_len, word_emb, slf_ln_weight,
+ slf_ln_bias, slf_q_weight, slf_out_weight,
+ ffn_inter_weight, ffn_inter_bias, ffn_out_weight,
+ ffn_out_bias, decoder_ln_weight, decoder_ln_bias,
+ linear_weight, linear_bias, topk, topp, max_out_len,
+ head_num, size_per_head, num_layer, bos_id, eos_id,
+ temperature, rotary_embedding_dim, repetition_penalty,
+ min_length, use_fp16_decoding):
+ tensor_para_size = get_ft_para_conf().tensor_para_size
+ layer_para_size = get_ft_para_conf().layer_para_size
+ layer_para_batch_size = get_ft_para_conf().layer_para_batch_size
+
+ inputs = {
+ "Input": input,
+ "AttentionMask": attn_mask,
+ "StartLength": mem_seq_len,
+ "WordEmbedding": word_emb,
+ "SelfLayernormWeight@VECTOR": slf_ln_weight,
+ "SelfLayernormBias@VECTOR": slf_ln_bias,
+ "SelfQueryWeight@VECTOR": slf_q_weight,
+ "SelfOutWeight@VECTOR": slf_out_weight,
+ "FFNInterWeight@VECTOR": ffn_inter_weight,
+ "FFNInterBias@VECTOR": ffn_inter_bias,
+ "FFNOutWeight@VECTOR": ffn_out_weight,
+ "FFNOutBias@VECTOR": ffn_out_bias,
+ "DecoderLayernormWeight": decoder_ln_weight,
+ "DecoderLayernormBias": decoder_ln_bias,
+ "EmbWeight": linear_weight,
+ "EmbBias": linear_bias
+ }
+
+ attrs = {
+ "topk": topk,
+ "topp": topp,
+ "max_len": max_out_len,
+ "n_head": head_num,
+ "size_per_head": size_per_head,
+ "num_layer": num_layer,
+ "bos_id": bos_id,
+ "eos_id": eos_id,
+ "temperature": temperature,
+ "rotary_embedding_dim": rotary_embedding_dim,
+ "repetition_penalty": repetition_penalty,
+ "min_length": min_length,
+ "use_fp16": use_fp16_decoding,
+ "tensor_para_size": tensor_para_size,
+ "layer_para_size": layer_para_size,
+ "layer_para_batch_size": layer_para_batch_size
+ }
+
+ outputs_names = ["OutputIds"]
+ outputs_dtype = ["int32"]
+
+ return run_custom(op_name="fusion_gptj",
+ inputs_names=inputs.keys(),
+ inputs_var=inputs.values(),
+ attrs_names=attrs.keys(),
+ attrs_val=attrs.values(),
+ outputs_names=outputs_names,
+ outputs_dtype=outputs_dtype)
+
+
def finalize(beam_size,
output_ids,
parent_ids,
@@ -2516,3 +2578,265 @@ def forward(self,
sequence_length,
decoding_strategy=decoding_strategy)
return ids
+
+
+def convert_gptj_params(faster_model,
+ model,
+ fuse_qkv=1,
+ use_fp16=False,
+ restore_data=False,
+ permutation=None):
+ r"""
+ Convert parameters included in Transformer layer from original models
+ to the format of faster models.
+
+ Args:
+ faster_model (Layer): The faster model object.
+ model (Layer): The Transformer layer.
+        fuse_qkv (int): 0 for no fuse, 1 for fuse, 2 for fuse and delete the
+            unfused parameters. If the environment variable `PPFG_QKV_MEM_OPT`
+            is set and the q/k/v weights are fused, the original unfused
+            weights will be deleted. Note that rolling back to the original
+            model is no longer guaranteed if the faster model fails after the
+            original weights have been deleted. Default to 1.
+ use_fp16 (bool): Whether to use float16. Maybe we should use the default
+ dtype as the highest priority later. Default to `False`.
+ restore_data (bool): If `False`, need to reload the weight values. It
+ should be `True` for weight loaded models. Default to `False`.
+
+ Returns:
+        defaultdict: Each value is a list of converted parameters from all
+            layers. For parameters outside the Transformer module that also
+            need conversion, such as embeddings, append them to the returned
+            dict directly, e.g. `params['word_emb'].append(...)`, which handles
+            CPU/GPU placement and fp32/fp16 transfer automatically.
+ """
+ if fuse_qkv == 1:
+ fuse_qkv = 2 if os.getenv("PPFG_QKV_MEM_OPT", "0") == "1" else 1
+ ft_para_conf = get_ft_para_conf()
+
+ class _list(list):
+
+ def append(self, item):
+ if isinstance(item[0], nn.Layer):
+ # Axis is used for tensor slice in tensor parallel.
+ # Use None to make no slice on the tensor.
+ if len(item) == 2:
+ layer, attr = item
+ axis = None
+ else:
+ layer, attr, axis = item
+ param = getattr(layer, attr)
+ if axis is not None and isinstance(layer, nn.Linear):
+ param = ft_para_conf.slice_weight(param, axis)
+ param = transfer_param(
+ param,
+ is_bias=attr.endswith("bias"),
+ dtype="float16" if use_fp16 else "float32",
+ restore_data=restore_data)
+                # NOTE: Assignment to the 'weight' attribute must be a
+                # Parameter or None, thus delete the attribute first in case
+                # `param` is a plain Tensor.
+ # TODO(guosheng): Make slice_weight use `output_param=True`
+ # and remove delattr. Currently, if `param` is Tensor rather
+ # than Parameter, it would not be in state_dict.
+ delattr(layer, attr)
+ setattr(layer, attr, param)
+ else:
+                # NOTE: Unlike the branch above, no layer attribute refers to
+                # the transferred param, so it must be set as a layer attribute
+                # to support conversion to static graph. Additionally, tensor
+                # parallel slicing is assumed unnecessary here since the param
+                # passed in may already have been processed.
+ if len(item) == 2:
+ param, is_bias = item
+ attr_handle = lambda x: x
+ else:
+ param, is_bias, attr_handle = item
+ param = transfer_param(
+ param,
+ is_bias=is_bias,
+ dtype="float16" if use_fp16 else "float32",
+ restore_data=restore_data)
+ attr_handle(param)
+ return super().append(param)
+
+ params = defaultdict(_list)
+
+ def _convert(module):
+ num_layer = len(module)
+ for i, layer in enumerate(module):
+ if not ft_para_conf.is_load(i, num_layer):
+ continue
+            # TODO(guosheng): A tensor with size 0 might fail on paddle
+            # develop, thus use a tensor with size 1 instead temporarily.
+            # Besides, a 2D tensor is used since the jit log requires that
+            # for linear weights, while size 0 seems fine in
+            # jit.to_static/jit.save.
+ dummy_tensor = paddle.zeros([1, 1])
+ if permutation is not None:
+ qkv = layer.attn.qkv_proj.weight.numpy()
+ qkv = qkv[:, permutation]
+ if fuse_qkv == 2:
+ del layer.attn.qkv_proj.weight
+ setattr(layer.attn.qkv_proj, "weight", dummy_tensor)
+ w = paddle.to_tensor(qkv)
+ else:
+ w = _convert_qkv(layer.attn.q_proj,
+ layer.attn.k_proj,
+ layer.attn.v_proj,
+ attr="weight",
+ use_numpy=fuse_qkv == 2,
+ del_param=fuse_qkv == 2,
+ dummy_tensor=dummy_tensor)
+ params["slf_q_weight"].append((w, False))
+            # NOTE: Use `params["slf_q_weight"][-1]` rather than `w`,
+            # since the appended tensor might be a newly transferred tensor.
+            # Besides, to allow convert_params to be called more than once,
+            # find an attr name that does not exist yet to avoid overwriting
+            # an existing attr.
+ attr = "slf_q_weight_" + str(i)
+ while hasattr(faster_model, attr):
+ attr += "_"
+ setattr(faster_model, attr, params["slf_q_weight"][-1])
+
+ params["slf_out_weight"].append((layer.attn.out_proj, "weight", 0))
+ params["slf_ln_weight"].append((layer.ln_1, "weight"))
+ params["slf_ln_bias"].append((layer.ln_1, "bias"))
+ # Slice tensor when append according to axis(1 or 0) if parallel
+ # is enable.
+ params["ffn_inter_weight"].append((layer.mlp.fc_in, "weight", 1))
+ params["ffn_inter_bias"].append((layer.mlp.fc_in, "bias", 1))
+ params["ffn_out_weight"].append((layer.mlp.fc_out, "weight", 0))
+ params["ffn_out_bias"].append((layer.mlp.fc_out, "bias"))
+
+ _convert(model)
+ return params
+
+
+class InferGptJDecoding(nn.Layer):
+
+ def __init__(self,
+ model,
+ decoding_lib=None,
+ use_fp16_decoding=False,
+ transpose_qkv=False):
+ if decoding_lib is not None and os.path.isfile(decoding_lib):
+ if "FasterTransformer" not in LOADED_EXT.keys():
+ ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
+ decoding_lib)
+ LOADED_EXT["FasterTransformer"] = ops
+ else:
+ if decoding_lib is not None:
+ logger.warning(
+ "The specified decoding_lib does not exist, and it will be built automatically."
+ )
+ load("FasterTransformer"
+ if get_ft_para_conf().no_para else "FasterTransformerParallel",
+ verbose=True,
+ need_parallel=not get_ft_para_conf().no_para)
+
+ super(InferGptJDecoding, self).__init__()
+
+ self.use_fp16_decoding = use_fp16_decoding
+ self.model = model
+ self.head_num = self.model.transformer.config['n_head']
+ self.size_per_head = int(self.model.transformer.config['n_embd'] /
+ self.head_num)
+ self.num_layer = self.model.transformer.config['n_layer']
+ self.rotary_embedding_dim = self.model.transformer.config['rotary_dim']
+        logger.info("Converting model weights, this may take a few seconds...")
+ permutation = None
+ if transpose_qkv:
+            # GPT-J differs from CodeGen in the attention projection layer.
+ local_dim = self.model.transformer.config['n_embd'] // 4
+ base_permutation = [0, 3, 6, 9, 2, 5, 8, 11, 1, 4, 7, 10]
+ permutation = np.concatenate([
+ np.arange(i * local_dim, (i + 1) * local_dim)
+ for i in base_permutation
+ ])
+ params = convert_gptj_params(self,
+ model.transformer.h,
+ fuse_qkv=2,
+ use_fp16=use_fp16_decoding,
+ restore_data=True,
+ permutation=permutation)
+
+ params["word_emb"].append((self.model.transformer.wte, "weight"))
+ params["decoder_ln_weight"].append(
+ (self.model.transformer.ln_f, "weight"))
+ params["decoder_ln_bias"].append((self.model.transformer.ln_f, "bias"))
+ params["linear_weight"].append((self.model.lm_head.weight.t(),
+ partial(setattr, self,
+ "linear_weight_out")))
+ params["linear_bias"].append((self.model.lm_head, "bias"))
+
+ for k, v in params.items():
+ setattr(self, k, v)
+ logger.info("Already converted model weights.")
+
+ def forward(self,
+ input_ids,
+ mem_seq_len,
+ attention_mask=None,
+ topk=4,
+ topp=0.0,
+ bos_token_id=None,
+ eos_token_id=None,
+ pad_token_id=None,
+ forced_eos_token_id=None,
+ max_out_len=256,
+ temperature=1,
+ repetition_penalty=1.0,
+ min_length=0):
+ if attention_mask is None:
+ batch_size, input_length = paddle.shape(input_ids)
+ attention_mask = paddle.unsqueeze(
+ (input_ids != pad_token_id).astype("float32"), axis=[1])
+ causal_mask = paddle.tril(
+ paddle.ones([batch_size, input_length, input_length],
+ dtype="float32"))
+ attention_mask = paddle.logical_and(attention_mask, causal_mask)
+ if not self.use_fp16_decoding:
+ attention_mask = paddle.cast(attention_mask, dtype="float32")
+ else:
+ attention_mask = paddle.cast(attention_mask, dtype="float16")
+
+ if self.use_fp16_decoding and attention_mask.dtype == paddle.float32:
+ attention_mask = paddle.cast(attention_mask, dtype="float16")
+
+ output_ids, = infer_gptj_decoding(
+ input=[input_ids],
+ attn_mask=[attention_mask],
+ mem_seq_len=[mem_seq_len],
+ word_emb=self.word_emb,
+ slf_ln_weight=self.slf_ln_weight,
+ slf_ln_bias=self.slf_ln_bias,
+ slf_q_weight=self.slf_q_weight,
+ slf_out_weight=self.slf_out_weight,
+ ffn_inter_weight=self.ffn_inter_weight,
+ ffn_inter_bias=self.ffn_inter_bias,
+ ffn_out_weight=self.ffn_out_weight,
+ ffn_out_bias=self.ffn_out_bias,
+ decoder_ln_weight=self.decoder_ln_weight,
+ decoder_ln_bias=self.decoder_ln_bias,
+ linear_weight=self.linear_weight,
+ linear_bias=self.linear_bias,
+ topk=topk,
+ topp=topp,
+ max_out_len=max_out_len,
+ head_num=self.head_num,
+ size_per_head=self.size_per_head,
+ num_layer=self.num_layer,
+ bos_id=bos_token_id,
+ eos_id=eos_token_id,
+ temperature=temperature,
+ rotary_embedding_dim=self.rotary_embedding_dim,
+ repetition_penalty=repetition_penalty,
+ min_length=min_length,
+ use_fp16_decoding=self.use_fp16_decoding)
+
+ output_ids = output_ids[paddle.shape(input_ids)[-1]:, :]
+ if forced_eos_token_id is not None:
+ output_ids[:, -1] = forced_eos_token_id
+ return output_ids
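The `base_permutation` used above for CodeGen regroups the fused `qkv_proj` weight into the contiguous `[Q | K | V]` layout the fused GPT-J attention kernel consumes. A minimal numpy sketch of the same index construction is given below; it is illustrative only and assumes CodeGen's fused projection stores its four blocks in (query, value, key) order, with `n_embd` set to a toy value:

```python
import numpy as np

n_embd = 16                    # toy hidden size; real models use e.g. 4096
local_dim = n_embd // 4        # one block holds n_embd // 4 columns each of q, v and k
base_permutation = [0, 3, 6, 9, 2, 5, 8, 11, 1, 4, 7, 10]

# Expand the 12 block indices into column indices, exactly as in the diff above.
permutation = np.concatenate(
    [np.arange(i * local_dim, (i + 1) * local_dim) for i in base_permutation])

# Assumed fused layout of qkv_proj's 3 * n_embd output columns:
# [q0, v0, k0, q1, v1, k1, q2, v2, k2, q3, v3, k3]
labels = np.array([f"{name}{blk}" for blk in range(4) for name in ("q", "v", "k")])
print(labels[base_permutation])
# -> ['q0' 'q1' 'q2' 'q3' 'k0' 'k1' 'k2' 'k3' 'v0' 'v1' 'v2' 'v3']

# Applying `permutation` to the weight columns therefore yields the
# contiguous [Q | K | V] layout expected by the GPT-J kernel.
qkv_weight = np.random.rand(n_embd, 3 * n_embd).astype("float32")
qkv_for_gptj = qkv_weight[:, permutation]
```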
diff --git a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
index 66dce6a0f195..5b9d7fae1f1d 100644
--- a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
+++ b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
@@ -24,7 +24,8 @@
InferTransformerModel, GPTModel)
from paddlenlp.ops import (InferTransformerDecoding, InferGptDecoding,
InferUnifiedDecoding, InferBartDecoding,
- InferMBartDecoding, InferOptDecoding)
+ InferMBartDecoding, InferOptDecoding,
+ InferGptJDecoding)
from .encoder import enable_faster_encoder, disable_faster_encoder
from paddlenlp.ops.ext_utils import load
@@ -33,7 +34,8 @@
UnifiedTransformerPretrainedModel,
UNIMOPretrainedModel, BartPretrainedModel,
GPTPretrainedModel, MBartPretrainedModel,
- OPTPretrainedModel)
+ OPTPretrainedModel, GPTJPretrainedModel,
+ CodeGenPreTrainedModel)
class FasterTransformer(TransformerModel):
@@ -1481,3 +1483,140 @@ def forward(self,
early_stopping=early_stopping)
generate = forward
+
+
+class FasterGPTJ(GPTJPretrainedModel):
+
+ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+ super(FasterGPTJ, self).__init__()
+ self._model = model
+ self.use_fp16_decoding = use_fp16_decoding
+ self.decoding = InferGptJDecoding(model=model,
+ decoding_lib=decoding_lib,
+ use_fp16_decoding=use_fp16_decoding)
+
+ def forward(self,
+ input_ids,
+ seq_len=None,
+ attention_mask=None,
+ top_k=4,
+ top_p=0.0,
+ min_length=0,
+ max_length=256,
+ bos_token_id=None,
+ eos_token_id=None,
+ pad_token_id=None,
+ forced_eos_token_id=None,
+ temperature=0,
+ repetition_penalty=1.0,
+ decode_strategy="sampling",
+ num_return_sequences=1,
+ **model_kwargs):
+ if input_ids.dtype == paddle.int64:
+ input_ids = paddle.cast(input_ids, "int32")
+
+ # change top_p to zero if not using top_p sampling for FT
+ if decode_strategy == "greedy_search":
+ top_p = 0.0
+ top_k = 1
+ if top_p == 1.0:
+ top_p = 0.0
+ if seq_len is None:
+ seq_len = paddle.sum(paddle.cast(input_ids != pad_token_id,
+ dtype="int32"),
+ axis=-1,
+ dtype="int32")
+
+ if num_return_sequences > 1:
+ input_ids, model_kwargs = self.expand_inputs_for_generation(
+ input_ids,
+ expand_size=num_return_sequences,
+ seq_len=seq_len,
+ attention_mask=attention_mask)
+ seq_len = model_kwargs["seq_len"]
+ attention_mask = model_kwargs.get("attention_mask", None)
+
+ return self.decoding(input_ids,
+ mem_seq_len=seq_len,
+ attention_mask=attention_mask,
+ topk=top_k,
+ topp=top_p,
+ max_out_len=max_length,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ min_length=min_length)
+
+ generate = forward
+
+
+class FasterCodeGen(CodeGenPreTrainedModel):
+
+ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+ super(FasterCodeGen, self).__init__()
+ self._model = model
+ self.use_fp16_decoding = use_fp16_decoding
+ self.decoding = InferGptJDecoding(model=model,
+ decoding_lib=decoding_lib,
+ use_fp16_decoding=use_fp16_decoding,
+ transpose_qkv=True)
+
+ def forward(self,
+ input_ids,
+ seq_len=None,
+ attention_mask=None,
+ top_k=4,
+ top_p=0.0,
+ min_length=0,
+ max_length=256,
+ bos_token_id=None,
+ eos_token_id=None,
+ pad_token_id=None,
+ forced_eos_token_id=None,
+ temperature=0,
+ repetition_penalty=1.0,
+ decode_strategy="sampling",
+ num_return_sequences=1,
+ **model_kwargs):
+ if input_ids.dtype == paddle.int64:
+ input_ids = paddle.cast(input_ids, "int32")
+
+ # change top_p to zero if not using top_p sampling for FT
+ if decode_strategy == "greedy_search":
+ top_p = 0.0
+ top_k = 1
+ if top_p == 1.0:
+ top_p = 0.0
+ if seq_len is None:
+ seq_len = paddle.sum(paddle.cast(input_ids != pad_token_id,
+ dtype="int32"),
+ axis=-1,
+ dtype="int32")
+
+ if num_return_sequences > 1:
+ input_ids, model_kwargs = self.expand_inputs_for_generation(
+ input_ids,
+ expand_size=num_return_sequences,
+ seq_len=seq_len,
+ attention_mask=attention_mask)
+ seq_len = model_kwargs["seq_len"]
+ attention_mask = model_kwargs.get("attention_mask", None)
+
+ return self.decoding(input_ids,
+ mem_seq_len=seq_len,
+ attention_mask=attention_mask,
+ topk=top_k,
+ topp=top_p,
+ max_out_len=max_length,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ min_length=min_length)
+
+ generate = forward
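A minimal usage sketch for the `FasterCodeGen` class added above (not part of the diff): it assumes the class is re-exported from `paddlenlp.ops` like the other `Faster*` wrappers, that the custom op library is available or can be JIT-built, and it uses `Salesforce/codegen-350M-mono` purely as an example checkpoint.

```python
import paddle
from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer
from paddlenlp.ops import FasterCodeGen  # assumed re-export, as for FasterGPT

tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
model.eval()

faster_model = FasterCodeGen(model, use_fp16_decoding=True)

input_ids = paddle.to_tensor(
    [tokenizer("def hello_world():")["input_ids"]], dtype="int64")

output_ids = faster_model(
    input_ids,
    top_k=4,
    top_p=0.0,  # 0.0 disables top-p sampling for FasterTransformer
    max_length=64,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    temperature=0.6)

# output_ids is laid out as [generated_length, batch_size]
print(tokenizer.decode(output_ids[:, 0].numpy().tolist(),
                       skip_special_tokens=True))
```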
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu
new file mode 100644
index 000000000000..32c824ef01b0
--- /dev/null
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fastertransformer/cuda/masked_multihead_attention_utils.h"
+namespace fastertransformer
+{
+
+template<typename T>
+struct Vec_t {};
+template<>
+struct Vec_t<float> {
+    using Type = float2;
+};
+template<>
+struct Vec_t<half> {
+    using Type = uint32_t;
+};
+
+#ifdef ENABLE_BF16
+template<>
+struct Vec_t<__nv_bfloat16> {
+ using Type = __nv_bfloat162;
+};
+#endif
+
+
+template<typename T>
+__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
+ T* k_buf,
+ T* v_buf,
+ const T* __restrict QKV,
+ const T* __restrict qkv_bias,
+ const int batch_size,
+ const int seq_len,
+ const int head_num,
+ const int size_per_head,
+ const int rotary_embedding_dim)
+{
+    using Vec_t = typename Vec_t<T>::Type;
+ const int batch_idx = blockIdx.z;
+ const int head_idx = blockIdx.y;
+ const int seq_idx = blockIdx.x;
+ const int tidx = threadIdx.x;
+ if (tidx * 2 >= size_per_head) {
+ return;
+ }
+
+ const int batch_time_idx = seq_len * batch_idx + seq_idx;
+ const int hidden_idx = head_idx * size_per_head + tidx * 2;
+ const int n = head_num * size_per_head;
+
+ // src QKV: [batch, time, 3, head, hidden]
+ const int q_idx = batch_time_idx * 3 * n + hidden_idx;
+ const int k_idx = batch_time_idx * 3 * n + hidden_idx + n;
+ const int v_idx = batch_time_idx * 3 * n + hidden_idx + 2 * n;
+
+    Vec_t q = *reinterpret_cast<const Vec_t*>(&QKV[q_idx]);
+    Vec_t k = *reinterpret_cast<const Vec_t*>(&QKV[k_idx]);
+    Vec_t v = *reinterpret_cast<const Vec_t*>(&QKV[v_idx]);
+
+ if(qkv_bias != nullptr){
+ // qkv_bias: [3, head, hidden]
+        Vec_t q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
+        Vec_t k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n]);
+        Vec_t v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + 2 * n]);
+
+ q = mmha::add(q, q_bias);
+ k = mmha::add(k, k_bias);
+ v = mmha::add(v, v_bias);
+ }
+
+ mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, seq_idx);
+
+ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head]
+ const int dest_idx = size_per_head * seq_len * head_num * batch_idx + size_per_head * seq_len * head_idx
+ + size_per_head * seq_idx + tidx * 2;
+
+    *reinterpret_cast<Vec_t*>(&q_buf[dest_idx]) = q;
+    *reinterpret_cast<Vec_t*>(&k_buf[dest_idx]) = k;
+    *reinterpret_cast<Vec_t*>(&v_buf[dest_idx]) = v;
+}
+
+template<typename T>
+void add_fusedQKV_bias_transpose_kernelLauncher(
+ T* q_buf,
+ T* k_buf,
+ T* v_buf,
+ T* QKV,
+ const T* qkv_bias,
+ const int batch_size,
+ const int seq_len,
+ const int head_num,
+ const int size_per_head,
+ const int rotary_embedding_dim,
+ cudaStream_t stream)
+{
+ if (rotary_embedding_dim == 0) {
+ const int m = batch_size * seq_len;
+ const int n = head_num * size_per_head;
+ dim3 block(384);
+ dim3 grid((int)(ceil(1.0 * m * n / 384)));
+        add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(
+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head);
+ }
+ else {
+ // To implement rotary embeddings, each thread processes two QKV elems:
+ dim3 block((size_per_head / 2 + 31) / 32 * 32);
+ dim3 grid(seq_len, head_num, batch_size);
+        add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(
+ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim);
+ }
+}
+
+template void add_fusedQKV_bias_transpose_kernelLauncher(
+ float* q_buf,
+ float* k_buf,
+ float* v_buf,
+ float* QKV,
+ const float* qkv_bias,
+ const int batch_size,
+ const int seq_len,
+ const int head_num,
+ const int size_per_head,
+ const int rotary_embedding_dim,
+ cudaStream_t stream);
+
+template void add_fusedQKV_bias_transpose_kernelLauncher(
+ half* q_buf,
+ half* k_buf,
+ half* v_buf,
+ half* QKV,
+ const half* qkv_bias,
+ const int batch_size,
+ const int seq_len,
+ const int head_num,
+ const int size_per_head,
+ const int rotary_embedding_dim,
+ cudaStream_t stream);
+
+} // namespace fastertransformer
\ No newline at end of file
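To make the indexing in `add_fusedQKV_bias_transpose_kernel` easier to follow, here is a small numpy rendering (illustrative only, not the CUDA code) of what one thread does for its two q/k elements: split one fused `[q | k | v]` row and apply GPT-J style rotary embedding to the interleaved (even, odd) pairs of q and k. The rotation formula follows the `mmha::apply_rotary_embedding` convention with base 10000; sizes are toy values.

```python
import numpy as np

def apply_rotary(vec, position, rotary_dim, base=10000.0):
    """Rotate interleaved (even, odd) pairs of the first `rotary_dim` dims."""
    out = vec.copy()
    for i in range(0, rotary_dim, 2):          # pair index handled by one thread (tidx = i // 2)
        angle = position / (base ** (i / rotary_dim))
        cos, sin = np.cos(angle), np.sin(angle)
        x, y = vec[i], vec[i + 1]
        out[i] = x * cos - y * sin
        out[i + 1] = x * sin + y * cos
    return out

size_per_head, rotary_dim, seq_idx = 8, 4, 3   # toy sizes
qkv_row = np.random.rand(3, size_per_head)     # [q | k | v] slice for one (batch, time, head)
q, k, v = qkv_row                              # same split the kernel does via q_idx/k_idx/v_idx

q = apply_rotary(q, seq_idx, rotary_dim)       # rotary is applied to q and k only
k = apply_rotary(k, seq_idx, rotary_dim)
# q, k, v are then scattered into [batch, head_num, seq_len, size_per_head] buffers.
```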
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh
new file mode 100644
index 000000000000..0c1fd232994a
--- /dev/null
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+namespace fastertransformer
+{
+
+template<typename T>
+void add_fusedQKV_bias_transpose_kernelLauncher(
+ T* q_buf,
+ T* k_buf,
+ T* v_buf,
+ T* QKV,
+ const T* qkv_bias,
+ const int batch_size,
+ const int seq_len,
+ const int head_num,
+ const int size_per_head,
+ const int rotary_embedding_dim,
+ cudaStream_t stream);
+
+
+} // namespace fastertransformer
\ No newline at end of file
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h
index 409919775c3f..8149895a9e7f 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h
@@ -150,6 +150,31 @@ void apply_logits_mask_kernelLauncher(T* log_probs,
cudaStream_t stream,
const T* logits_mask = nullptr,
const bool min_penalty = false,
- const int end_id = -1);
+ const int end_id = -1,
+ const T* bias = nullptr);
+
+template <typename T>
+void gptj_start_id_embedding_lookups_kernel_launcher(T* from_tensor,
+ int* output_ids,
+ const T* embedding_table,
+ const int* word_ids,
+ const int length,
+ const int max_length,
+ const int batch_size,
+ const int hidden_units,
+ cudaStream_t stream);
+
+template <typename T>
+void gpj_embedding_lookups_kernel_launcher(T* from_tensor,
+ const T* embedding_table,
+ const int* word_ids,
+ const int local_batch_size,
+ const int batch_size,
+ const int hidden_units,
+ int step,
+ int ite,
+ int max_input_len,
+ const int* start_lengths,
+ cudaStream_t stream);
} // namespace fastertransformer
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu
index 48d32a44a8ae..23c3aaf3a160 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu
@@ -340,7 +340,8 @@ __global__ void apply_logits_mask_kernel(int vocab_size_padded,
const bool* finished,
const T* logits_mask = nullptr,
const bool min_penalty = false,
- const int end_id = -1) {
+ const int end_id = -1,
+ const T* bias = nullptr) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int bbid = blockIdx.y; // batch_size * beam_size: index
@@ -355,6 +356,8 @@ __global__ void apply_logits_mask_kernel(int vocab_size_padded,
log_probs[i + bbid * vocab_size_padded] += -MAX_T_VAL;
} else if (logits_mask) {
log_probs[i + bbid * vocab_size_padded] += logits_mask[i];
+ } else if (bias) {
+ log_probs[i + bbid * vocab_size_padded] += bias[i];
} else {
continue;
}
@@ -372,8 +375,9 @@ void apply_logits_mask_kernelLauncher(T* log_probs,
cudaStream_t stream,
const T* logits_mask,
const bool min_penalty,
- const int end_id) {
- if (logits_mask == nullptr && !min_penalty) return;
+ const int end_id,
+ const T* bias) {
+ if (logits_mask == nullptr && !min_penalty && bias == nullptr) return;
dim3 block(256);
dim3 grid((vocab_size_padded + block.x - 1) / block.x,
@@ -386,7 +390,122 @@ void apply_logits_mask_kernelLauncher(T* log_probs,
finished,
logits_mask,
min_penalty,
- end_id);
+ end_id,
+ bias);
+}
+
+
+ template <typename T> __launch_bounds__(1024, 1)
+ __global__ void gptj_start_id_embedding_lookups_kernel(T* from_tensor,
+ int* output_ids,
+ const T* embedding_table,
+ const int* word_ids,
+ const int length,
+ const int max_length,
+ const int batch_size,
+ const int hidden_units)
+ {
+ for(int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units; index += blockDim.x * gridDim.x)
+ {
+ // transpose the word_ids [batch, length] (part of [batch, max_length]) to output_ids [length, batch]
+ if(index < batch_size * max_length)
+ {
+ const int seq_id = index % max_length;
+ const int batch_id = index / max_length;
+ if(seq_id < length)
+ output_ids[seq_id * batch_size + batch_id] = word_ids[index];
+ // output_ids[index] = word_ids[index];
+ }
+
+ // embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate embedding [batch, length, hidden]
+ const int word_index = index / hidden_units;
+ const int word_index_row = word_index / length;
+ const int word_index_col = word_index % length;
+ const int real_word_index = word_index_row * max_length + word_index_col;
+ const int col_index = index % hidden_units;
+ from_tensor[index] = embedding_table[word_ids[real_word_index] * hidden_units + col_index];
+ }
+ }
+
+
+  template <typename T>
+ void gptj_start_id_embedding_lookups_kernel_launcher(T* from_tensor,
+ int *output_ids,
+ const T* embedding_table,
+ const int* word_ids,
+ const int length,
+ const int max_length,
+ const int batch_size,
+ const int hidden_units,
+ cudaStream_t stream)
+ {
+ dim3 grid(min(batch_size * length, 65536));
+ dim3 block(min(hidden_units, 1024));
+    gptj_start_id_embedding_lookups_kernel<T><<<grid, block, 0, stream>>>(from_tensor,
+ output_ids,
+ embedding_table,
+ word_ids,
+ length,
+ max_length,
+ batch_size,
+ hidden_units);
+ }
+
+
+ // TODO Add half2 implementation
+template <typename T>
+__global__ void gptj_embedding_lookups_kernel(
+ T* from_tensor,
+ const T* embedding_table,
+ const int* word_ids,
+ const int local_batch_size,
+ const int batch_size,
+ const int hidden_units,
+ int step,
+ int ite,
+ int max_input_len,
+ const int* start_lengths) {
+ int timestep = step - 1;
+ // if the input is padded in the batch, indices of the word_id
+ // should be shifted forward by the length of the padding.
+ int len_padding =
+ max_input_len - start_lengths[local_batch_size * ite + blockIdx.x];
+ int idx_word_id = (step == max_input_len) ? timestep - len_padding : timestep;
+
+ int* word_ids_buf =
+ (int*)word_ids + idx_word_id * batch_size + local_batch_size * ite;
+ T* from_tensor_buf = from_tensor + blockIdx.x * hidden_units;
+ for (int index = threadIdx.x; index < hidden_units; index += blockDim.x) {
+ from_tensor_buf[index] =
+ embedding_table[word_ids_buf[blockIdx.x] * hidden_units + index];
+ }
+}
+
+template <typename T>
+void gpj_embedding_lookups_kernel_launcher(T* from_tensor,
+ const T* embedding_table,
+ const int* word_ids,
+ const int local_batch_size,
+ const int batch_size,
+ const int hidden_units,
+ int step,
+ int ite,
+ int max_input_len,
+ const int* start_lengths,
+ cudaStream_t stream) {
+ dim3 grid(min(local_batch_size, 65536));
+ dim3 block(min(hidden_units, 1024));
+  gptj_embedding_lookups_kernel<T>
+      <<<grid, block, 0, stream>>>(from_tensor,
+ embedding_table,
+ word_ids,
+ local_batch_size,
+ batch_size,
+ hidden_units,
+ step,
+ ite,
+ max_input_len,
+ start_lengths);
}
template void init_kernelLauncher_v2(bool* finished,
@@ -527,7 +646,8 @@ template void apply_logits_mask_kernelLauncher(
cudaStream_t stream,
const float* logits_mask,
const bool min_penalty,
- const int end_id);
+ const int end_id,
+ const float* bias);
template void apply_logits_mask_kernelLauncher(
half* log_probs,
@@ -539,6 +659,55 @@ template void apply_logits_mask_kernelLauncher(
cudaStream_t stream,
const half* logits_mask,
const bool min_penalty,
- const int end_id);
+ const int end_id,
+ const half* bias);
+
+ template
+ void gptj_start_id_embedding_lookups_kernel_launcher(float* from_tensor,
+ int* output_ids,
+ const float* embedding_table,
+ const int* word_ids,
+ const int length,
+ const int max_length,
+ const int batch_size,
+ const int hidden_units,
+ cudaStream_t stream);
+
+ template
+ void gptj_start_id_embedding_lookups_kernel_launcher(half* from_tensor,
+ int* output_ids,
+ const half* embedding_table,
+ const int* word_ids,
+ const int length,
+ const int max_length,
+ const int batch_size,
+ const int hidden_units,
+ cudaStream_t stream);
+
+ template void gpj_embedding_lookups_kernel_launcher(
+ float* from_tensor,
+ const float* embedding_table,
+ const int* word_ids,
+ const int local_batch_size,
+ const int batch_size,
+ const int hidden_units,
+ int step,
+ int ite,
+ int max_input_len,
+ const int* start_lengths,
+ cudaStream_t stream);
+
+template void gpj_embedding_lookups_kernel_launcher(
+ half* from_tensor,
+ const half* embedding_table,
+ const int* word_ids,
+ const int local_batch_size,
+ const int batch_size,
+ const int hidden_units,
+ int step,
+ int ite,
+ int max_input_len,
+ const int* start_lengths,
+ cudaStream_t stream);
} // end of name space fastertransformer
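The context-pass kernel added above copies the prompt ids into the output buffer and gathers their embeddings in one pass. A small numpy rendering of that computation (illustrative only, toy sizes) is:

```python
import numpy as np

batch_size, length, max_length, hidden, vocab = 2, 3, 5, 4, 10
word_ids = np.random.randint(0, vocab, size=(batch_size, max_length))
embedding_table = np.random.rand(vocab, hidden).astype("float32")

# Transpose the first `length` (un-padded) ids from [batch, max_length]
# to output_ids laid out as [length, batch], as the kernel does.
output_ids = word_ids[:, :length].T

# Embedding lookup: from_tensor[b, t, :] = embedding_table[word_ids[b, t]].
from_tensor = embedding_table[word_ids[:, :length]]   # [batch, length, hidden]
assert from_tensor.shape == (batch_size, length, hidden)
```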
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu
index 00b15c2e5a6f..aa03d980d2f6 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu
@@ -1,4 +1,5 @@
/***************************************************************************************************
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are not permit-
@@ -80,9 +81,11 @@ template<typename T, int Dh> struct Qk_vec_ {};
template<> struct Qk_vec_<float,    32> { using Type = float;    };
template<> struct Qk_vec_<float,    64> { using Type = float2;   };
template<> struct Qk_vec_<float,   128> { using Type = float4;   };
+template<> struct Qk_vec_<float,   256> { using Type = float4;   };
template<> struct Qk_vec_<uint16_t, 32> { using Type = uint32_t; };
template<> struct Qk_vec_<uint16_t, 64> { using Type = uint32_t; };
template<> struct Qk_vec_<uint16_t, 128> { using Type = uint2;   };
+template<> struct Qk_vec_<uint16_t, 256> { using Type = uint4;   };
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -430,7 +433,7 @@ inline size_t smem_size_in_bytes(
// The number of partial rows to reduce in the final reduction.
// int rows_per_red = threads_per_block / threads_per_value;
// to solve `threads_per_block / threads_per_value` is not 2^n
- int rows_per_red = pad_active_groups;
+ int rows_per_red = params.rotary_embedding_dim>0 ? threads_per_block / threads_per_value: pad_active_groups;
// The amount of storage needed to finalize the outputs.
size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
@@ -438,6 +441,7 @@ inline size_t smem_size_in_bytes(
return max(softmax_sz, red_sz);
}
+
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ constexpr uint32_t shfl_mask(int threads) {
@@ -500,7 +504,7 @@ __global__ void masked_multihead_attention_kernel(Masked_multihead_attention_par
// The number of elements per vector.
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
// Make sure the hidden size per head is a multiple of the vector size.
- static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
+ static_assert(Dh % QK_VEC_SIZE == 0, "");
// The number of vectors per warp.
constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
@@ -837,6 +841,568 @@ __global__ void masked_multihead_attention_kernel(Masked_multihead_attention_par
}
}
+template<
+ // The type of the inputs. Supported types: float and half.
+ typename T,
+ // The hidden dimension per head.
+ int Dh,
+ int Dh_MAX,
+ // The number of threads per key.
+ int THREADS_PER_KEY,
+ // The number of threads per value.
+ int THREADS_PER_VALUE,
+ // The number of threads in a threadblock.
+ int THREADS_PER_BLOCK
+ >
+__global__ void gptj_masked_multihead_attention_kernel(Masked_multihead_attention_params<T> params, int pad_active_groups)
+{
+
+ // Make sure the hidden dimension per head is a multiple of the number of threads per key.
+ static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
+ // Make sure the hidden dimension per head is a multiple of the number of threads per value.
+ static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
+
+ // The size of a warp.
+ constexpr int WARP_SIZE = 32;
+ // The number of warps in a threadblock.
+ constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
+
+ // Use smem_size_in_bytes (above) to determine the amount of shared memory.
+ extern __shared__ char smem_[];
+
+ // The shared memory for the Q*K^T values and partial logits in softmax.
+    float* qk_smem = reinterpret_cast<float*>(smem_);
+
+ // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
+ char* logits_smem_ = smem_;
+
+ // DO_CROSS_ATTENTION = false
+ constexpr bool DO_CROSS_ATTENTION = false;
+
+#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
+    if (sizeof(T) != 4) {
+        // TODO - change to tlength
+        logits_smem_ +=
+            (DO_CROSS_ATTENTION) ? div_up(params.seq_length + 1, 4) * 16 : div_up(params.timestep + 1, 4) * 16;
+    }
+    T* logits_smem = reinterpret_cast<T*>(logits_smem_);
+#else
+    float* logits_smem = reinterpret_cast<float*>(logits_smem_);
+#endif
+
+ // The shared memory to do the final reduction for the output values. Reuse qk_smem.
+    T* out_smem = reinterpret_cast<T*>(smem_);
+
+ // The shared memory buffers for the block-wide reductions. One for max, one for sum.
+ __shared__ float red_smem[WARPS_PER_BLOCK * 2];
+
+ // A vector of Q or K elements for the current timestep.
+    using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+
+ // Use alignment for safely casting the shared buffers as Qk_vec.
+ // Shared memory to store Q inputs.
+ __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
+
+ // This is one of the reasons we should have a separate kernel for cross attention
+ __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
+
+ // A vector of Q or K elements for the current timestep.
+    using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+ // The number of elements per vector.
+ constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
+ // Make sure the hidden size per head is a multiple of the vector size.
+ static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
+ // We will use block wide reduction if needed
+ // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
+ // The number of vectors per warp.
+ constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
+
+ // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
+ // owns x elements, we have to decompose the linear index into chunks of x values and the posi-
+ // tion of the thread in that chunk.
+
+ // The number of elements in a chunk of 16B (that's the x in the above formula).
+ constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
+ // The number of K vectors in 16B.
+ constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
+
+ // The batch/beam idx
+ const int bi = blockIdx.y;
+ if (params.finished != nullptr && params.finished[bi] == true) {
+ return;
+ }
+ // The beam idx
+ const int beami = bi % params.beam_width;
+ // The "beam-aware" batch idx
+ const int bbi = bi / params.beam_width;
+ // The head.
+ const int hi = blockIdx.x;
+ // Combine the batch and the head indices.
+ const int bhi = bi * params.num_heads + hi;
+ // Combine the "beam-aware" batch idx and the head indices.
+ const int bbhi = bbi * params.beam_width * params.num_heads + hi;
+ // The thread in the block.
+ const int tidx = threadIdx.x;
+
+ // While doing the product Q*K^T for the different keys we track the max.
+ float qk_max = -FLT_MAX;
+
+ float qk = 0.0F;
+
+ int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
+
+ // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
+ int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
+ (params.length_per_sample == nullptr) ? params.timestep :
+ params.length_per_sample[bi];
+ // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
+ if (tidx < QK_VECS_PER_WARP) {
+
+ // The offset in the Q and K buffer also accounts for the batch.
+ int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
+ // The offset in the bias buffer.
+ int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+
+ // Trigger the loads from the Q and K buffers.
+ Qk_vec q;
+ zero(q);
+        q = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]) : q;
+ Qk_vec k;
+ zero(k);
+ if (DO_CROSS_ATTENTION) {
+ // The 16B chunk written by the thread.
+ int co = tidx / QK_VECS_IN_16B;
+ // The position of the thread in that 16B chunk.
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
+
+ // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
+ int offset = bhi * params.seq_length * Dh + co * params.seq_length * QK_ELTS_IN_16B +
+ // params.timestep*QK_ELTS_IN_16B +
+ tlength * QK_ELTS_IN_16B + ci;
+            k = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
+ k;
+ }
+ else {
+            k = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]) : k;
+ }
+
+ // Trigger the loads from the Q and K bias buffers.
+ Qk_vec q_bias;
+ zero(q_bias);
+ q_bias = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
+                     *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
+ q_bias;
+ Qk_vec k_bias;
+ zero(k_bias);
+
+ if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
+ k_bias = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
+                         *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
+ k_bias;
+ }
+
+ // Computes the Q/K values with bias.
+ q = add(q, q_bias);
+ if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
+ k = add(k, k_bias);
+ if (params.rotary_embedding_dim > 0) {
+ apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep);
+ }
+ }
+ else {
+ if (params.rotary_embedding_dim > 0) {
+ apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep);
+ }
+ }
+
+ // Store the Q values to shared memory.
+        *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
+
+ // Store Dh values of k_bias into smem, since will need to add later
+ // if params.timestep == 0
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
+            *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
+ }
+
+ // Write the K values to the global memory cache.
+ //
+ // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
+ // system. We designed it this way as it allows much better memory loads (and there are many
+ // more loads) + the stores are really "write and forget" since we won't need the ack before
+ // the end of the kernel. There's plenty of time for the transactions to complete.
+
+ // The 16B chunk written by the thread.
+ int co = tidx / QK_VECS_IN_16B;
+ // The position of the thread in that 16B chunk.
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
+
+ // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
+ int offset = bhi * params.seq_length * Dh + co * params.seq_length * QK_ELTS_IN_16B +
+ // params.timestep*QK_ELTS_IN_16B +
+ tlength * QK_ELTS_IN_16B + ci;
+
+ if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
+ // Trigger the stores to global memory.
+ if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
+                *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
+ }
+ }
+
+ // Compute \sum_i Q[i] * K^T[i] for the current timestep.
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
+        using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
+#else
+ using Qk_vec_acum = Qk_vec;
+#endif
+        qk = dot<Qk_vec_acum, Qk_vec>(q, k);
+ if (QK_VECS_PER_WARP <= WARP_SIZE) {
+#pragma unroll
+ for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
+ qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+ }
+ }
+ }
+
+ if (QK_VECS_PER_WARP > WARP_SIZE) {
+ constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
+        qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
+ }
+
+ // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
+ if (tidx == 0) {
+ // Normalize qk.
+ qk *= params.inv_sqrt_dh;
+
+ if (params.relative_attention_bias_float != nullptr) {
+ qk = qk
+ + params.relative_attention_bias_float[hi * params.relative_attention_bias_stride
+ * params.relative_attention_bias_stride
+ + tlength * params.relative_attention_bias_stride + tlength];
+ }
+ else if (params.relative_attention_bias_half != nullptr) {
+ qk = qk
+ + (float)
+ params.relative_attention_bias_half[hi * params.relative_attention_bias_stride
+ * params.relative_attention_bias_stride
+ + tlength * params.relative_attention_bias_stride + tlength];
+ }
+ qk_max = qk;
+ qk_smem[tlength] = qk;
+ // qk_smem[params.timestep] = qk;
+ }
+
+ // Make sure the data is in shared memory.
+ __syncthreads();
+
+ // The type of queries and keys for the math in the Q*K^T product.
+    using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
+ // The number of elements per vector.
+ constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
+ // Make sure the hidden size per head is a multiple of the vector size.
+ static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
+ // The number of elements per thread.
+ constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
+ // The number of vectors per thread.
+ constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
+
+ // The position the first key loaded by each thread from the cache buffer (for this B * H).
+ int ko = tidx / THREADS_PER_KEY;
+ // The position of the thread in the chunk of keys.
+ int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
+
+ static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
+
+ // Load the Q values from shared memory. The values are reused during the loop on K.
+ K_vec q[K_VECS_PER_THREAD];
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+        q[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
+ }
+
+ K_vec k_bias[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+            k_bias[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
+ }
+ }
+
+ // The number of timesteps loaded per iteration.
+ constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
+ // The number of keys per warp.
+ constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
+
+ // The base pointer for the key in the cache buffer.
+ T* k_cache = ¶ms.k_cache[bhi * params.seq_length * Dh + ki];
+ // Base pointer for the beam's batch, before offsetting with indirection buffer
+ T* k_cache_batch = ¶ms.k_cache[bbhi * params.seq_length * Dh + ki];
+
+ // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
+ // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
+ int ti_end = div_up(tlength, K_PER_WARP) * K_PER_WARP;
+
+ // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
+ for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
+
+ // The keys loaded from the key cache.
+ K_vec k[K_VECS_PER_THREAD];
+ K_vec k_vec_zero;
+ zero(k_vec_zero);
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.seq_length + ti;
+ // if( ti < params.timestep ) {
+ if (ti < tlength) {
+ const int beam_src =
+ (params.cache_indir != nullptr) ?
+ params.cache_indir[(bbi * params.beam_width + beami) * params.seq_length + ti] :
+ 0;
+ const int beam_offset = beam_src * params.num_heads * params.seq_length * Dh;
+ k[ii] = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.seq_length) ?
+                            *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) :
+ k_vec_zero;
+ // add bias and update k_cache
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
+ k[ii] = add(k[ii], k_bias[ii]);
+ if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.seq_length) {
+                        *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
+ }
+ }
+ }
+ }
+
+ // Perform the dot product and normalize qk.
+ //
+ // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
+        float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k) * params.inv_sqrt_dh;
+ bool is_mask = (params.input_lengths != nullptr && ti >= params.input_lengths[bi] && ti < params.max_input_len);
+
+ // Store the product to shared memory. There's one qk value per timestep. Update the max.
+ // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
+ if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
+ if (params.relative_attention_bias_float != nullptr) {
+ qk = qk
+ + params.relative_attention_bias_float[hi * params.relative_attention_bias_stride
+ * params.relative_attention_bias_stride
+ + tlength * params.relative_attention_bias_stride + ti];
+ }
+ else if (params.relative_attention_bias_half != nullptr) {
+ qk = qk
+ + (float)
+ params.relative_attention_bias_half[hi * params.relative_attention_bias_stride
+ * params.relative_attention_bias_stride
+ + tlength * params.relative_attention_bias_stride + ti];
+ }
+ qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
+ qk_smem[ti] = qk;
+ }
+ }
+
+// Perform the final reduction to compute the max inside each warp.
+//
+// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
+// group so it's not needed to run the reduction inside the group (again).
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
+ qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+ }
+
+ // Decompose the thread index into warp and lane.
+ const int warp = tidx / WARP_SIZE;
+ const int lane = tidx % WARP_SIZE;
+
+ // The warp leader writes the max to shared memory.
+ if (lane == 0) {
+ red_smem[warp] = qk_max;
+ }
+
+ // Make sure the products are in shared memory.
+ __syncthreads();
+
+ // The warps finalize the reduction.
+ qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+ for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
+ qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+ }
+
+ // Broadcast to all the threads in the warp.
+ qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+
+ // Compute the logits and start the sum.
+ float sum = 0.f;
+ // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
+ for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
+ bool is_mask = (params.input_lengths != nullptr && ti >= params.input_lengths[bi] && ti < params.max_input_len);
+ float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max);
+ sum += logit;
+ qk_smem[ti] = logit;
+ }
+
+ // Compute the sum.
+    sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
+
+ // Normalize the logits.
+ float inv_sum = __fdividef(1.f, sum + 1.e-6f);
+ // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
+ for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
+ convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum);
+ }
+
+ // Put Values part below so we leverage __syncthreads
+ // from the previous step
+
+ // The number of elements per vector.
+ constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
+ // A vector of V elements for the current timestep.
+ using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
+
+ // The value computed by this thread.
+ int vo = tidx / THREADS_PER_VALUE;
+ // The hidden dimensions computed by this particular thread.
+ int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
+
+ // The base pointer for the value in the cache buffer.
+ T* v_cache = &params.v_cache[bhi * params.seq_length * Dh + vi];
+ // Base pointer for the beam's batch, before offsetting with indirection buffer
+ T* v_cache_batch = &params.v_cache[bbhi * params.seq_length * Dh + vi];
+
+ // The number of values processed per iteration of the loop.
+ constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
+
+ // One group of threads computes the product(s) for the current timestep.
+ V_vec v_bias;
+ zero(v_bias);
+ // if( vo == params.timestep % V_PER_ITER ) {
+ if (Dh == Dh_MAX || vi < Dh) {
+ if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
+ if (vo == tlength % V_PER_ITER) {
+ // Trigger the loads from the V bias buffer.
+ if (params.v_bias != nullptr) {
+ v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
+ }
+ if (DO_CROSS_ATTENTION) {
+ *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
+ }
+ }
+ }
+ }
+
+ // Carried over from the previous (pre-values) step.
+ // Also make sure the logits are in shared memory.
+ __syncthreads();
+
+ // Values continued
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+ using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
+#else
+ using V_vec_acum = V_vec;
+#endif
+ // The partial outputs computed by each thread.
+ V_vec_acum out;
+ zero(out);
+
+ // Loop over the timesteps to compute the partial outputs.
+ // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
+ if (Dh == Dh_MAX || vi < Dh) {
+ for (int ti = vo; ti < tlength; ti += V_PER_ITER) {
+
+ // Fetch offset based on cache_indir when beam sampling
+ const int beam_src = (params.cache_indir != nullptr) ?
+ params.cache_indir[(bbi * params.beam_width + beami) * params.seq_length + ti] :
+ 0;
+ const int beam_offset = beam_src * params.num_heads * params.seq_length * Dh;
+ // Load the values from the cache.
+ V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti * Dh]);
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
+ v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
+ *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
+ }
+ // Load the logits from shared memory.
+#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
+ float logit = logits_smem[ti];
+ out = fma(logit, cast_to_float(v), out);
+#else
+ T logit = logits_smem[ti];
+
+ // Update the partial sums.
+ out = fma(logit, v, out);
+#endif
+ }
+ }
+
+ // One group of threads computes the product(s) for the current timestep.
+ // if( vo == params.timestep % V_PER_ITER ) {
+ if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
+
+ V_vec v;
+ if (DO_CROSS_ATTENTION) {
+ v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
+ }
+ else {
+ // Trigger the loads from the V buffer.
+ v = *reinterpret_cast<const V_vec*>(&params.v[qkv_base_offset + vi]);
+ // Trigger the loads from the V bias buffer.
+ // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
+ }
+
+ // Compute the V values with bias.
+ if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) {
+ v = add(v, v_bias);
+
+ // Store the values with bias back to global memory in the cache for V.
+ //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
+ *reinterpret_cast<V_vec*>(&v_cache[tlength * Dh]) = v;
+ }
+
+ // Initialize the output value with the current timestep.
+#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
+ // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
+ out = fma(logits_smem[tlength], cast_to_float(v), out);
+#else
+ // out = fma(logits_smem[params.timestep], v, out);
+ out = fma(logits_smem[tlength], v, out);
+#endif
+ }
+
+ // Make sure we can start writing to shared memory.
+ __syncthreads();
+
+ // Run the final reduction amongst the different groups computing different partial outputs.
+ if (Dh == Dh_MAX || vi < Dh)
+#pragma unroll
+ for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
+
+ // The midpoint in the number of active groups.
+ int midpoint = active_groups / 2;
+
+ // The upper part of active threads store to shared memory.
+ if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+ convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
+#else
+ *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
+#endif
+ }
+ __syncthreads();
+
+ // The bottom warps update their values.
+ if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
+ out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
+ }
+ __syncthreads();
+ }
+
+ // Output the final values.
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+ convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
+#else
+ *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
+#endif
+ }
+}
+
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace mmha
@@ -846,21 +1412,40 @@ __global__ void masked_multihead_attention_kernel(Masked_multihead_attention_par
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
int pad_active_groups = 1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK, pad_active_groups); \
- dim3 grid(params.num_heads, params.batch_size); \
+ dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
+
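+// Same tiling and launch configuration as MMHA_LAUNCH_KERNEL, but dispatches the GPT-J/CodeGen
+// variant of the kernel, which applies rotary position embeddings to Q/K inside the
+// single-step attention.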
+#define GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \
+ int pad_active_groups = 1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
+ size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK, pad_active_groups); \
+ dim3 grid(params.num_heads, params.batch_size); \
+ mmha::gptj_masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK> \
+ <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
+
////////////////////////////////////////////////////////////////////////////////////////////////////
template < typename T, int Dh, int Dh_MAX>
void mmha_launch_kernel(const Masked_multihead_attention_params<T> &params, const cudaStream_t &stream) {
- constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
- if( params.timestep < 32 ) {
- MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
- } else if( params.timestep < 2048 ) {
- MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
- } else {
- MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
+ if(params.rotary_embedding_dim>0){
+ constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
+ if( params.timestep < 32 ) {
+ GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
+ } else if( params.timestep < 2048 ) {
+ GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
+ } else {
+ GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
+ }
+ }else{
+ constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
+ if( params.timestep < 32 ) {
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
+ } else if( params.timestep < 2048 ) {
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
+ } else {
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
+ }
}
}
@@ -884,6 +1469,18 @@ void masked_multihead_attention_(const Masked_multihead_attention_params &par
case 128:
mmha_launch_kernel<T, 128, 128>(params, stream);
break;
+ case 160:
+ mmha_launch_kernel<T, 160, 256>(params, stream);
+ break;
+ case 192:
+ mmha_launch_kernel<T, 192, 256>(params, stream);
+ break;
+ case 224:
+ mmha_launch_kernel<T, 224, 256>(params, stream);
+ break;
+ case 256: // GPTJ/CodeGen
+ mmha_launch_kernel<T, 256, 256>(params, stream);
+ break;
default:
// assert(false);
throw std::runtime_error("Unsupported model size.");
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h
new file mode 100644
index 000000000000..992a598536ee
--- /dev/null
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define CHECK_CUDA(call) do { \
+ cudaError_t status_ = call; \
+ if( status_ != cudaSuccess ) { \
+ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
+ exit(1); \
+ } \
+} while(0)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// The structure of parameters for the masked multihead attention kernel.
+//
+// We use the following terminology to describe the different dimensions.
+//
+// B: Batch size (number of sequences),
+// L: Sequence length,
+// D: Hidden dimension,
+// H: Number of heads,
+// Dh: Hidden dimension per head - Dh = D / H.
+
+template< typename T >
+struct Masked_multihead_attention_params {
+
+ // The output buffer. Dimensions B x D.
+ T *out;
+
+ // The input Qs and the associated bias. Dimensions B x D and D, resp.
+ const T *q, *q_bias;
+ // The input Ks and the associated bias. Dimensions B x D and D, resp.
+ const T *k, *k_bias;
+ // The input Vs and the associated bias. Dimensions B x D and D, resp.
+ const T *v, *v_bias;
+
+ // The cache for the Ks. The size must be at least B x L x D.
+ T *k_cache;
+ // The cache for the Vs. The size must be at least B x L x D.
+ T *v_cache;
+
+ // The indirections to use for cache when beam sampling.
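+ // Layout: [batch, beam_width, seq_length]; entry (b, beam, t) gives the source beam
+ // whose K/V cache entry should be read at timestep t after beam-search reordering.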
+ const int* cache_indir = nullptr;
+
+ // Allows the attention to exit early.
+ bool *finished;
+
+ // Stride to handle the case when KQV is a single buffer
+ int stride;
+
+ // The batch size.
+ int batch_size;
+ // The sequence length.
+ int seq_length;
+ // The number of heads (H).
+ int num_heads;
+ // The hidden dimension per head (Dh).
+ int hidden_size_per_head;
+ // The current timestep.
+ int timestep;
+
+ // The per-head latent space reserved for rotary embeddings.
+ int rotary_embedding_dim = 0;
+
+ // The 1.f / sqrt(Dh). Computed on the host.
+ float inv_sqrt_dh;
+
+ // params for masking.
+ bool is_mask;
+ const int *input_lengths = nullptr;
+ int max_input_len = 0;
+
+ const float* relative_attention_bias_float = nullptr;
+ const half* relative_attention_bias_half = nullptr;
+ int relative_attention_bias_stride;
+ // The beam width
+ int beam_width = 1;
+ // required in case of cross attention
+ int* memory_length_per_sample = nullptr;
+ // required in case of masked attention with different length
+ const int* length_per_sample = nullptr;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_multihead_attention(const Masked_multihead_attention_params<float> &params, const cudaStream_t &stream);
+void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t> &params, const cudaStream_t &stream);
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h
new file mode 100644
index 000000000000..cdcc47a06180
--- /dev/null
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+namespace mmha {
+
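+// Rotary position embedding (RoPE) helpers, as used by GPT-J/CodeGen: channel pair
+// (zid, zid + 1) of q/k is rotated by the angle t_step / 10000^(zid / rot_embed_dim),
+// returned below as its (cos, sin) coefficients.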
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
+{
+ const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
+ return {cos(inv_freq), sin(inv_freq)};
+}
+
+inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
+{
+ float2 rot_v;
+ rot_v.x = coef.x * v.x - coef.y * v.y;
+ rot_v.y = coef.x * v.y + coef.y * v.x;
+ return rot_v;
+}
+
+inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
+{
+ float2 fv = half2_to_float2(v);
+ float2 rot_fv = rotary_embedding_transform(fv, coef);
+ return float2_to_half2(rot_fv);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
+{
+ float2 fv = bf1622float2(v);
+ float2 rot_fv = rotary_embedding_transform(fv, coef);
+ return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
+}
+#endif
+
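+// apply_rotary_embedding overloads: one per vector width (float/float2/float4,
+// uint32_t/uint2/uint4, and the bf16 variants). `tid` indexes the thread's slice of the
+// head, so the early return when N * tid >= rot_embed_dim leaves channels beyond the
+// rotary dimension untouched.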
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
+{
+ return;
+}
+
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
+{
+ return;
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (2 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+ q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (2 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+ q = rotary_embedding_transform(q, coef);
+ k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (4 * tid >= rot_embed_dim) {
+ return;
+ }
+
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+ q_.x = rotary_embedding_transform(q_.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+ q_.y = rotary_embedding_transform(q_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (4 * tid >= rot_embed_dim) {
+ return;
+ }
+
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+ Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+ q_.x = rotary_embedding_transform(q_.x, coef0);
+ k_.x = rotary_embedding_transform(k_.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+ q_.y = rotary_embedding_transform(q_.y, coef1);
+ k_.y = rotary_embedding_transform(k_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (2 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+ q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (2 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+ q = rotary_embedding_transform(q, coef);
+ k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (4 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (4 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ k.x = rotary_embedding_transform(k.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+ k.y = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (8 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+ q.z = rotary_embedding_transform(q.z, coef2);
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+ q.w = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (8 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ k.x = rotary_embedding_transform(k.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+ k.y = rotary_embedding_transform(k.y, coef1);
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+ q.z = rotary_embedding_transform(q.z, coef2);
+ k.z = rotary_embedding_transform(k.z, coef2);
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+ q.w = rotary_embedding_transform(q.w, coef3);
+ k.w = rotary_embedding_transform(k.w, coef3);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (2 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+ q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (2 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+ q = rotary_embedding_transform(q, coef);
+ k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (4 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (4 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ k.x = rotary_embedding_transform(k.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+ k.y = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+{
+ if (8 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+ q.z = rotary_embedding_transform(q.z, coef2);
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+ q.w = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+{
+ if (8 * tid >= rot_embed_dim) {
+ return;
+ }
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+ q.x = rotary_embedding_transform(q.x, coef0);
+ k.x = rotary_embedding_transform(k.x, coef0);
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+ q.y = rotary_embedding_transform(q.y, coef1);
+ k.y = rotary_embedding_transform(k.y, coef1);
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+ q.z = rotary_embedding_transform(q.z, coef2);
+ k.z = rotary_embedding_transform(k.z, coef2);
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+ q.w = rotary_embedding_transform(q.w, coef3);
+ k.w = rotary_embedding_transform(k.w, coef3);
+}
+#endif // ENABLE_BF16
+
+} // namespace mmha
\ No newline at end of file
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h
index b2d657a18610..6702087e9b11 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h
@@ -1,4 +1,5 @@
/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu
index 1d2aa51410e0..6245cdd3e46f 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu
@@ -159,4 +159,96 @@ template void transpose_general_kernelLauncher(half* dst,
const int head_num,
const int size_per_head,
cudaStream_t stream);
+
+
+
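+// Single-step masked attention on a fused QKV buffer: qkv_buf holds the Q, K and V
+// projections of the current token back to back (stride 3 * head_num * size_per_head),
+// and the optional qkv_bias follows the same layout. This wrapper only fills the
+// Masked_multihead_attention_params and forwards them to masked_multihead_attention().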
+template <typename T>
+void fusedQKV_masked_attention_dispatch_v2(
+ const T* qkv_buf, const T* qkv_bias,
+ T* key_cache, T* value_cache,
+ T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size,
+ int head_num, int size_per_head, const int step, const int max_seq_len,
+ const int max_input_len, const int* input_lengths, const int rotary_embedding_dim, cudaStream_t stream)
+{
+ using DataType = typename std::conditional<sizeof(T) == 4, float, uint16_t>::type;
+ // Prepare the parameters.
+ Masked_multihead_attention_params<DataType> params;
+ memset(¶ms, 0, sizeof(params));
+ int hidden_units = head_num * size_per_head;
+ if (qkv_bias != nullptr) {
+ params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
+ params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + hidden_units;
+ params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + 2 * hidden_units;
+ }
+ else {
+ // gptj/codegen no bias
+ params.q_bias = nullptr;
+ params.k_bias = nullptr;
+ params.v_bias = nullptr;
+ }
+
+ // Set the output buffer.
+ params.out = reinterpret_cast<DataType*>(context_buf);
+
+ // Set the input buffers.
+ params.q = reinterpret_cast<const DataType*>(qkv_buf);
+ params.k = reinterpret_cast<const DataType*>(qkv_buf) + hidden_units;
+ params.v = reinterpret_cast<const DataType*>(qkv_buf) + 2 * hidden_units;
+ params.stride = 3 * hidden_units;
+ params.finished = const_cast(finished);
+
+ params.k_cache = reinterpret_cast<DataType*>(key_cache);
+ params.v_cache = reinterpret_cast<DataType*>(value_cache);
+ params.batch_size = inference_batch_size;
+ params.seq_length = max_seq_len;
+ params.timestep = step-1;
+ params.num_heads = head_num;
+ params.hidden_size_per_head = size_per_head;
+ // GptJ: rotary_embedding
+ params.rotary_embedding_dim = rotary_embedding_dim;
+ params.inv_sqrt_dh = 1.F / sqrtf((float) params.hidden_size_per_head);
+
+ params.is_mask = true;
+ params.input_lengths = input_lengths;
+ params.max_input_len = max_input_len;
+
+ masked_multihead_attention(params, stream);
+}
+
+template void fusedQKV_masked_attention_dispatch_v2(
+ const float* qkv_buf,
+ const float* qkv_bias,
+ float* key_cache,
+ float* value_cache,
+ float* context_buf,
+ const bool* finished,
+ int max_batch_size,
+ int inference_batch_size,
+ int head_num,
+ int size_per_head,
+ const int step,
+ const int max_seq_len,
+ const int max_input_len,
+ const int* input_lengths,
+ const int rotary_embedding_dim,
+ cudaStream_t stream);
+
+template void fusedQKV_masked_attention_dispatch_v2(
+ const half* qkv_buf,
+ const half* qkv_bias,
+ half* key_cache,
+ half* value_cache,
+ half* context_buf,
+ const bool* finished,
+ int max_batch_size,
+ int inference_batch_size,
+ int head_num,
+ int size_per_head,
+ const int step,
+ const int max_seq_len,
+ const int max_input_len,
+ const int* input_lengths,
+ const int rotary_embedding_dim,
+ cudaStream_t stream);
+
}
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh
index cabc89591971..624f36235c4b 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh
@@ -46,4 +46,12 @@ __global__ void transpose(T* src,
const int seq_len,
const int head_num,
const int size_per_head);
+
+template <typename T>
+void fusedQKV_masked_attention_dispatch_v2(
+ const T* qkv_buf, const T* qkv_bias,
+ T* key_cache, T* value_cache,
+ T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size,
+ int head_num, int size_per_head, const int step, const int max_seq_len,
+ const int max_input_len, const int* input_lengths, const int rotary_embedding_dim, cudaStream_t stream);
}
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu
index bfd680a6b24f..2f364f9c880d 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu
@@ -35,13 +35,28 @@ __global__ void ker_curand_setup(curandState_t* state,
&state[blockIdx.x * blockDim.x + threadIdx.x]);
}
+__global__ void ker_curand_setup_bsz_one(curandState_t* state,
+ const int size,
+ const int seed) {
+ if (threadIdx.x + blockIdx.x * blockDim.x < size)
+ curand_init(seed,
+ 0,
+ seed,
+ &state[blockIdx.x * blockDim.x + threadIdx.x]);
+}
+
void ker_curand_setupLauncher(curandState_t* state,
DecodingSamplingArguments args,
cudaStream_t stream) {
dim3 block(256);
dim3 grid((int)(ceil(args.batch_size_ * 1.0 / 256)));
int seed = args.seed_ != -1 ? args.seed_ : clock() % INT_MAX;
- ker_curand_setup<<<grid, block, 0, stream>>>(state, args.batch_size_, seed);
+ if(args.batch_size_ != 1)
+ ker_curand_setup<<<grid, block, 0, stream>>>(state, args.batch_size_, seed);
+ else
+ // Reduce the large GPU memory footprint caused by curand_init when bsz = 1.
+ // TODO(gongenlei): Solve above problem when bsz > 1.
+ ker_curand_setup_bsz_one<<<grid, block, 0, stream>>>(state, args.batch_size_, seed);
}
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu
index 46a5efc9f57d..365b2fc0b388 100644
--- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu
@@ -404,11 +404,17 @@ __global__ void add_bias_input_layernorm_2(const T* __restrict input,
for (int i = tid; i < n; i += blockDim.x) {
float local_out = (float)(__ldg(&input[blockIdx.x * n + i]));
local_out += (float)(output[blockIdx.x * n + i]);
- local_out += (float)(__ldg(&bias[i]));
+ if(bias != nullptr){
+ local_out += (float)(__ldg(&bias[i]));
+ }
output[blockIdx.x * n + i] = (T)local_out;
local_sum += local_out;
}
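+ // When no gamma/beta are given, only the (optional) bias + residual add above is applied
+ // and the LayerNorm part of the kernel is skipped.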
+ if (gamma == nullptr || beta == nullptr){
+ return;
+ }
+
mean = blockReduceSum<float>(local_sum);
if (threadIdx.x == 0) s_mean = mean / n;
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/gptj.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/gptj.h
new file mode 100644
index 000000000000..208fb4663b7f
--- /dev/null
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/gptj.h
@@ -0,0 +1,946 @@
+/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Decoder transformer
+ **/
+
+#pragma once
+
+#include "fastertransformer/utils/common.h"
+#include "fastertransformer/utils/functions.h"
+#include "fastertransformer/utils/allocator.h"
+#include "fastertransformer/utils/arguments.h"
+#include "fastertransformer/cuda/cuda_kernels.h"
+#include "fastertransformer/open_decoder.h"
+#include <cuda_runtime.h>
+#include <stdlib.h>
+#include "fastertransformer/utils/nvtx_utils.h"
+
+namespace fastertransformer
+{
+
+template <OperationType OpType_>
+class DecodingGptJ
+{
+private:
+ typedef DecoderTransformerTraits<OpType_> Traits_;
+ typedef typename Traits_::DataType DataType_;
+ const IAllocator &allocator_;
+ struct GptJArguments args_;
+ TensorParallelParam t_parallel_param_;
+ LayerParallelParam l_parallel_param_;
+
+ const cudaDataType_t computeType_ = Traits_::computeType;
+ const cudaDataType_t AType_ = Traits_::AType;
+ const cudaDataType_t BType_ = Traits_::BType;
+ const cudaDataType_t CType_ = Traits_::CType;
+ std::map<std::string, cublasLtMatmulAlgo_info> cublasAlgoMap_;
+
+ DataType_ *embedding_kernel_padded_;
+ DataType_ *embedding_bias_padded_;
+
+ OpenDecoder<OpType_> *decoder_;
+ DataType_ **K_cache_;
+ DataType_ **V_cache_;
+ DataType_ *from_tensor_[2];
+ DataType_ *decoder_buf_;
+ DataType_ *decoder_normed_result_buf_;
+ DataType_ *logits_buf_;
+ void *buf_;
+
+ void *topk_workspace_ = nullptr;
+ size_t topk_workspace_size_ = 0;
+ void *topp_workspace_ = nullptr;
+ size_t topp_workspace_size_ = 0;
+ void *topk_topp_workspace_ = nullptr;
+ size_t topk_topp_workspace_size_ = 0;
+ void *cublas_workspace_ = nullptr;
+ int *topp_id_vals_buf_;
+ int *topp_offset_buf_;
+ curandState_t *curandstate_buf_;
+ int *begin_topp_offset_buf_;
+
+ size_t nccl_buf_size_;
+ DataType_ *nccl_logits_buf_;
+
+ bool *finished_buf_;
+ bool *h_finished_buf_;
+
+public:
+ DecodingGptJ(const IAllocator &allocator,
+ const int batch_size,
+ const int seq_len,
+ const int head_num,
+ const int size_per_head,
+ const int vocab_size,
+ const int decoder_layers,
+ const int start_id,
+ const int end_id,
+ const int candidate_num = 1,
+ const float probability_threshold = 0.0,
+ const float temperature = 1.0,
+ const int tensor_para_size = 1,
+ const int layer_para_size = 1,
+ const bool is_fuse_QKV = true,
+ const float repetition_penalty = 1.0,
+ const int seed = -1,
+ const int rotary_embedding_dim = 0,
+ const int min_length = 0) : allocator_(allocator)
+ {
+#ifndef NDEBUG
+ PRINT_FUNC_NAME_();
+#endif
+ assert(temperature != 0.0);
+ assert(repetition_penalty > 0.0);
+ assert(candidate_num > 0 || probability_threshold > 0.0);
+ assert(decoder_layers % layer_para_size == 0);
+
+ args_.batch_size_ = batch_size;
+ args_.seq_len_ = seq_len;
+ args_.head_num_ = head_num;
+ args_.size_per_head_ = size_per_head;
+ args_.hidden_units_ = head_num * size_per_head;
+ args_.decoder_layers_ = decoder_layers;
+ args_.vocab_size_ = vocab_size;
+ args_.start_id_ = start_id;
+ args_.end_id_ = end_id;
+ args_.candidate_num_ = candidate_num;
+ args_.probability_threshold_ = probability_threshold;
+ args_.temperature_ = temperature;
+ args_.repetition_penalty_ = repetition_penalty;
+ /***** newly added by PaddleNLP *****/
+ args_.seed_ = seed;
+ args_.rotary_embedding_dim_ = rotary_embedding_dim;
+ args_.min_length_ = min_length;
+
+ K_cache_ = new DataType_ *[1];
+ V_cache_ = new DataType_ *[1];
+
+ decoder_ = new OpenDecoder<OpType_>(args_.head_num_, size_per_head, 0 /* memory_hidden_units */, is_fuse_QKV);
+ decoder_->set_max_batch_size(args_.batch_size_);
+
+
+ // args_.vocab_size_padded_ = div_up(args_.vocab_size_, 64) * 64;
+ if(std::is_same<float, DataType_>::value)
+ {
+ args_.vocab_size_padded_ = args_.vocab_size_;
+ }
+ else
+ {
+ args_.vocab_size_padded_ = div_up(args_.vocab_size_, 64) * 64;
+ }
+
+ size_t from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T
+ size_t decoder_workspace_size = (size_t)decoder_->getWorkspaceSize(); // type T
+ size_t decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T
+ // cache costs lots of memory, so we only store part of them when we use multi-gpu for inference
+ size_t cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_ / tensor_para_size; // type T
+ size_t logits_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type T
+
+ size_t topp_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type int
+ size_t topp_offset_buf_size = args_.batch_size_ + 1;
+ size_t begin_topp_offset_buf_size = topp_offset_buf_size;
+ size_t curandState_size = args_.batch_size_;
+ size_t finished_buf_size = args_.batch_size_;
+
+ const int MEM_C = 128;
+ size_t embedding_kernel_transposed_padded_size = args_.hidden_units_ * args_.vocab_size_padded_;
+ embedding_kernel_transposed_padded_size = div_up(embedding_kernel_transposed_padded_size, MEM_C) * MEM_C;
+
+
+ size_t padded_embedding_bias_size = args_.vocab_size_padded_;
+ if(std::is_same