diff --git a/examples/code_generation/codegen/README.md b/examples/code_generation/codegen/README.md index 0d83dc8589bc..72430d775ed9 100644 --- a/examples/code_generation/codegen/README.md +++ b/examples/code_generation/codegen/README.md @@ -1,74 +1,219 @@ -# CodeGen: A Conversational Paradigm for Program Synthesis +# 代码生成:写代码的AI助理 -## 模型简介 +**目录** +- [代码生成](#代码生成) + - [简介](#简介) + - [特色](#特色) + - [效果展示](#效果展示) + - [开箱即用](#开箱即用) + - [支持单条、批量预测](#支持单条批量预测) + - [可配置参数说明](#可配置参数说明) + - [训练定制](#训练定制) + - [环境依赖](#环境依赖) + - [代码结构说明](#代码结构说明) + - [数据准备](#数据准备) + - [从本地文件创建数据集](#从本地文件创建数据集) + - [Github Copilot插件配置](#GithubCopilot插件配置) + - [插件环境依赖](#插件环境依赖) + - [启动服务](#启动服务) + - [配置参数](#配置参数说明) + - [测试服务](#测试服务) + - [配置插件](#配置插件) + - [注意事项](#注意事项) + - [TaskFlow调用](#TaskFlow调用) + - [使用案例](#使用案例) + - [模型列表](#模型列表) + - [References](#references) -[CodeGen](https://arxiv.org/pdf/2203.13474.pdf) (A Conversational Paradigm for Program Synthesis)提出了一种通过大型语言模型进行对话式程序生成的方法,将编写规范和程序的过程转换为用户和系统之间的多回合对话。它把程序生成看作一个序列预测问题,用自然语言表达规范,并有条件地对所期望的程序进行抽样。同时,CodeGen(16B)在HumanEval benchmark上已经超过[OpenAI's Codex](https://arxiv.org/pdf/2107.03374.pdf)。 -本项目展示如何调用CodeGen来进行代码生成。 +## 简介 +代码生成是根据编程人员的输入,生成出编程人员想要的代码,能够帮助编程人员甚至独立生成代码,提高编程效率。 -## 快速开始 + +### 特色 + +本项目是基于预训练语言模型CodeGen的代码生成,具有以下优势: +- **效果领先**。CodeGen(16B)在HumanEval benchmark上评估指标已经超过[OpenAI's Codex](https://arxiv.org/pdf/2107.03374.pdf)。 +- **免费的Github Copilot**。支持通过Github Copilot调用该模型,让你免费体验代码AI助理。 +- **高性能**。基于[FasterGeneration](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/faster_generation)打造高性能推理,毫秒级响应。具体加速指标可参考[perf](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/faster_generation/README.md)。 +- **支持自定义数据集训练**。可增加自己的代码数据加以微调,让其更智能。 +- **开箱即用**。本项目提供TaskFlow接口,无需训练,仅需几行代码便可预测。 + + +## 效果展示 + +## 训练定制 ### 环境依赖 +- PaddleNLP >= 2.4.0 +- PaddlePaddle >= 2.3.1 + +### 代码结构说明 + +以下是本项目主要代码结构及说明: + +```text +codegen/ +├── requirements.txt # 环境依赖 +├── codegen_server.py # server启动脚本 +├── run_clm.py # 训练评估脚本 +├── run_clm.sh # 启动脚本 +└── README.md # 说明文档 +``` + +### 数据准备 + +#### 从本地文件创建数据集 + +在许多情况,我们需要使用本地数据集来训练我们的代码生成模型,本项目支持使用固定格式本地数据集文件进行训练。 + +本地数据集文件格式如下: +- train.json/test.json 文件格式: +每行为一个jsonline +```text +{ + "code": "from paddlenlp.transformers import CodeGenForCausalLM\n\n\nmodel = CodeGenForCausalLM.from_pretrained('Salesforce/codegen-2B-mono')\n" +} +``` + +更多数据集读取格式详见[数据集加载](https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_load.html#)和[自定义数据集](https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_self_defined.html)。 + + +### 模型训练 +运行如下命令即可在样例训练集上进行finetune,并在样例验证集上进行验证。 + +```shell +# GPU启动,参数`--gpus`指定训练所用的GPU卡号,可以是单卡,也可以多卡 +unset CUDA_VISIBLE_DEVICES + +python -m paddle.distributed.launch --gpus 0,1 run_clm.py \ + --model_name_or_path Salesforce/codegen-350M-mono \ + --block_size 1024 \ + --output_dir output \ + --train_file train.json \ + --validation_file test.json \ + --num_train_epochs 5 \ + --logging_steps 1 \ + --save_steps 10 \ + --train_batch_size 2 \ + --eval_batch_size 2 \ + --learning_rate 1e-4 \ + --warmup_proportion 0.1 \ + --device gpu +``` +使用多卡训练可以指定多个GPU卡号,例如 --gpus "0,1" + +关键参数释义如下: +- `gpus` 指示了训练所用的GPU卡号。 +- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型(详见[模型列表](#模型列表)),或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。 +- `block_size` 表示训练时候数据被拆分的块数。 +- `output_dir` 表示模型的保存路径。 +- `train_file` 本地训练数据地址,数据格式必须与`dataset_name`所指数据集格式相同。 +- `validation_file` 
本地测试数据地址,数据格式必须与`dataset_name`所指数据集格式相同。 +- `num_train_epochs` 表示训练轮数。 +- `logging_steps` 表示日志打印间隔。 +- `save_steps` 表示模型保存及评估间隔。 +- `train_batch_size` 表示训练时**每张卡**上的样本数目。 +- `eval_batch_size` 表示测试时**每张卡**上的样本数目。 +- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。 +- `warmup_propotion` 表示学习率逐渐升高到基础学习率(即上面配置的learning_rate)所需要的迭代数占总步数的比例,最早的使用可以参考[这篇论文](https://arxiv.org/pdf/1706.02677.pdf)。 +- `device` 表示使用的设备,从gpu和cpu中选择。 - - python >= 3.6 - - paddlepaddle >= 2.3.0 - - paddlenlp >= 2.3.4 +可通过`bash run_clm.sh`启动训练,更多参数详情和参数的默认值请参考`run_clm.py`。 -### 代码调用 +程序运行时将会自动进行训练和验证,训练过程中会自动保存模型在指定的`save_dir`中。 +如: +```text +./output/ +│── model_config.json +│── model_state.pdparams +│── tokenizer_config.json +│── special_tokens_map.json +│── added_tokens.json +│── vocab.json +│── merges.txt +└── ... +``` + +**NOTE:** 如需恢复模型训练,`model_name_or_path`配置本地模型的目录地址即可。 + +## GithubCopilot插件配置 +以下以VS Code的插件为例 +### 插件环境依赖 +- PaddleNLP >= 2.4.0 +- PaddlePaddle >= 2.3.1 + +其他依赖:`pip install -r requirements.txt` + + +### 启动服务 ```python -import re -import paddle -from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM +python codegen_server.py +``` + +##### 配置参数说明 +在codegen_server.py中配置如下参数: +- `model_name_or_path`:模型名,默认为 "Salesforce/codegen-2B-mono" +- `device`:运行设备,默认为"gpu" +- `temperature`:解码参数temperature,默认为0.5 +- `top_k`:解码参数top_k,默认为10 +- `top_p`:解码参数top_p,默认为1.0 +- `repetition_penalty`:解码重复惩罚项,默认为1.0 +- `min_length`:生成的最小长度,默认为0 +- `max_length`:生成的最大长度,默认为16 +- `decode_strategy`:解码策略,默认为"sampling" +- `load_state_as_np`:以numpy格式加载模型参数,可节省显存,默认为True +- `use_faster`:是否使用Fastergeneration,可加速推理,默认为True +- `use_fp16_decoding`:是否使用fp16推理,可节省显存和加速推理,默认为True + +### 测试服务 +`pip install --upgrade openai` + +```python +import openai +openai.api_key = 'dummy' +openai.api_base = 'http://127.0.0.1:8000/v1' +result = openai.Completion.create( + engine='codegen', prompt='def hello', max_tokens=16, temperature=0.1) +print(result) +''' + JSON: { + "id": "cmpl-dmhoeHmcw9DJ4NeqOJDQVKv3iivJ0", + "choices": [ + { + "text": "_world():\n print(\"Hello World!\")\n\n\n#", + "index": 0, + "finish_reason": "stop", + "logprobs": null, + } + ], + "usage": { + "completion_tokens": null, + "prompt_tokens": null, + "total_tokens": null + } +} +''' -# The supported models are shown in the following table -model_name = 'Salesforce/codegen-350M-mono' -# Init tokenizer -tokenizer = CodeGenTokenizer.from_pretrained(model_name) -# Init model -model = CodeGenForCausalLM.from_pretrained(model_name) -inputs = tokenizer(["def hello_world():"]) -inputs = {k: paddle.to_tensor(v) for (k, v) in inputs.items()} -# Generate -output, score = model.generate(inputs['input_ids'], - max_length=128, - decode_strategy='sampling', - top_k=5, - repetition_penalty=1.1, - temperature=0.6) -# Decode the result -print( - re.split( - "\nclass|\ndef|\n#|\n@|\nprint|\nif", - tokenizer.decode(output[0], - skip_special_tokens=True, - spaces_between_special_tokens=False))[0].rstrip()) ``` -其中参数释义如下: -- `max_length` 解码的最大长度,默认128。 -- `decode_strategy` 解码的策略,默认sampling。 -- `top_k` 解码参数top_k,默认5。 -- `repetition_penalty` 解码重复惩罚系数,默认1.1。 -- `temperature` 解码参数temperature,默认0.6。 +### 配置插件 +打开用户设置([settings.json](https://code.visualstudio.com/docs/getstarted/settings#_settings-file-locations)),增加一行配置 +```json + "github.copilot.advanced": { + "debug.overrideEngine": "codegen", + "debug.testOverrideProxyUrl": "http://127.0.0.1:8978", + "debug.overrideProxyUrl": "http://127.0.0.1:8978" + }, +``` -模型列表 -| 模型名称 | 说明 | -| :--------------------------------- | 
-------------------------------- | -| Salesforce/codegen-350M-mono | 基于Python数据集BIGPYTHON训练 | -| Salesforce/codegen-2B-mono | 基于Python数据集BIGPYTHON训练 | -| Salesforce/codegen-6B-mono | 基于Python数据集BIGPYTHON训练 | -| Salesforce/codegen-16B-mono | 基于Python数据集BIGPYTHON训练 | -| Salesforce/codegen-350M-nl | 基于自然语言数据集THEPILE训练 | -| Salesforce/codegen-2B-nl | 基于自然语言数据集THEPILE训练 | -| Salesforce/codegen-6B-nl | 基于自然语言数据集THEPILE训练 | -| Salesforce/codegen-16B-nl | 基于自然语言数据集THEPILE训练 | -| Salesforce/codegen-350M-multi | 基于多编程语言数据集BIGQUERY训练 | -| Salesforce/codegen-2B-multi | 基于多编程语言数据集BIGQUERY训练 | -| Salesforce/codegen-6B-multi | 基于多编程语言数据集BIGQUERY训练 | -| Salesforce/codegen-16B-multi | 基于多编程语言数据集BIGQUERY训练 | +接下来就可以愉快地使用了😊。 +#### 注意事项 +- 如果使用FasterGeneration,需要设置[codegen_server.py](#配置参数说明)中`use_faster=True`,第一次推理会涉及到编译,会耗费一些时间。FasterGeneration的环境依赖参考[这里](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/ops/README.md#%E4%BD%BF%E7%94%A8%E7%8E%AF%E5%A2%83%E8%AF%B4%E6%98%8E)。 +- 如果要使用自己训练好的模型,可以设置[codegen_server.py](#配置参数说明)中`model_name_or_path`为本地模型路径。 -### TaskFlow调用 +## TaskFlow调用 参考[TaskFlow文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/model_zoo/taskflow.md) ## 使用案例 @@ -157,4 +302,24 @@ def hello_world(): hello_world() ``` -其它更多趣味性的生成欢迎大家体验,同时也欢迎大家来开发代码补全的插件。 + +## 模型列表 +模型列表 +| 模型名称 | 说明 | +| :--------------------------------- | -------------------------------- | +| Salesforce/codegen-350M-mono | 基于Python数据集BIGPYTHON训练 | +| Salesforce/codegen-2B-mono | 基于Python数据集BIGPYTHON训练 | +| Salesforce/codegen-6B-mono | 基于Python数据集BIGPYTHON训练 | +| Salesforce/codegen-16B-mono | 基于Python数据集BIGPYTHON训练 | +| Salesforce/codegen-350M-nl | 基于自然语言数据集THEPILE训练 | +| Salesforce/codegen-2B-nl | 基于自然语言数据集THEPILE训练 | +| Salesforce/codegen-6B-nl | 基于自然语言数据集THEPILE训练 | +| Salesforce/codegen-16B-nl | 基于自然语言数据集THEPILE训练 | +| Salesforce/codegen-350M-multi | 基于多编程语言数据集BIGQUERY训练 | +| Salesforce/codegen-2B-multi | 基于多编程语言数据集BIGQUERY训练 | +| Salesforce/codegen-6B-multi | 基于多编程语言数据集BIGQUERY训练 | +| Salesforce/codegen-16B-multi | 基于多编程语言数据集BIGQUERY训练 | + +## References +- Nijkamp, Erik, et al. "A conversational paradigm for program synthesis." arXiv preprint arXiv:2203.13474 (2022). +- [https://github.com/features/copilot/](https://github.com/features/copilot/) diff --git a/examples/code_generation/codegen/codegen_server.py b/examples/code_generation/codegen/codegen_server.py new file mode 100644 index 000000000000..fd448749c7ac --- /dev/null +++ b/examples/code_generation/codegen/codegen_server.py @@ -0,0 +1,141 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
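+"""A minimal, OpenAI-style completion service backed by CodeGen.
+
+POST /v1/engines/codegen/completions accepts a prompt and returns generated
+code (optionally streamed via SSE), so clients such as the GitHub Copilot
+plugin or the `openai` Python SDK can talk to a locally hosted model.
+"""
+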
+import random +import string +import time +import uvicorn +import paddle +from paddlenlp.utils.log import logger +from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM +from sse_starlette.sse import EventSourceResponse +from fastapi import FastAPI, Response, status +from pydantic import BaseModel + + +class DefaultConfig: + model_name_or_path = "Salesforce/codegen-2B-mono" + device = "gpu" + temperature = 0.5 + top_k = 10 + top_p = 1.0 + repetition_penalty = 1.0 + min_length = 0 + max_length = 16 + decode_strategy = "sampling" + load_state_as_np = True + use_faster = True + use_fp16_decoding = True + default_dtype = "float16" if use_faster and use_fp16_decoding else "float32" + + +class Input(BaseModel): + prompt: str + stream: bool = False + + +class Output(BaseModel): + id: str + model: str = "codegen" + object: str = "text_completion" + created: int = int(time.time()) + choices: list = None + usage = { + "completion_tokens": None, + "prompt_tokens": None, + "total_tokens": None, + } + + +generate_config = DefaultConfig() +paddle.set_device(generate_config.device) +paddle.set_default_dtype(generate_config.default_dtype) + +tokenizer = CodeGenTokenizer.from_pretrained(generate_config.model_name_or_path) +model = CodeGenForCausalLM.from_pretrained( + generate_config.model_name_or_path, + load_state_as_np=generate_config.load_state_as_np) + +app = FastAPI() + + +def random_completion_id(): + return 'cmpl-' + ''.join( + random.choice(string.ascii_letters + string.digits) for _ in range(29)) + + +@app.post("/v1/engines/codegen/completions", status_code=200) +async def gen(item: Input): + item = item.dict() + logger.info(f"Request: {item}") + temperature = item.get("temperature", generate_config.temperature) + top_k = item.get("top_k", generate_config.top_k) + if temperature == 0.0: + temperature = 1.0 + top_k = 1 + repetition_penalty = item.get('frequency_penalty', + generate_config.repetition_penalty) + + start_time = time.time() + logger.info("Start generating code") + tokenized = tokenizer([item['prompt']], + truncation=True, + return_tensors='pd') + output, _ = model.generate( + tokenized["input_ids"], + max_length=16, + min_length=generate_config.min_length, + decode_strategy=generate_config.decode_strategy, + top_k=top_k, + repetition_penalty=repetition_penalty, + temperature=temperature, + use_faster=generate_config.use_faster, + use_fp16_decoding=generate_config.use_fp16_decoding) + logger.info("Finish generating code") + end_time = time.time() + logger.info(f"Time cost: {end_time - start_time}") + output = tokenizer.decode(output[0], + skip_special_tokens=True, + spaces_between_special_tokens=False) + logger.info(f"Generated code: {output}") + output_json = Output( + id=random_completion_id(), + choices=[{ + "text": output, + "index": 0, + "finish_reason": "stop", + "logprobs": None, + }], + usage={ + "completion_tokens": None, + "prompt_tokens": None, + "total_tokens": None, + }, + ).json() + + def stream_response(response): + yield f"{response}\n\n" + yield "data: [DONE]\n\n" + + if item.get("stream", False): + return EventSourceResponse(stream_response(output_json)) + else: + return Response( + status_code=status.HTTP_200_OK, + content=output_json, + media_type="application/json", + ) + + +if __name__ == "__main__": + uvicorn.run("codegen_server:app", host="0.0.0.0", port=8978) diff --git a/examples/code_generation/codegen/requirements.txt b/examples/code_generation/codegen/requirements.txt new file mode 100644 index 000000000000..ded81d0ca44e --- /dev/null +++ 
b/examples/code_generation/codegen/requirements.txt @@ -0,0 +1,5 @@ +fastapi==0.79.0 +pydantic==1.9.1 +python-dotenv==0.20.0 +sse_starlette==0.10.3 +uvicorn==0.17.6 \ No newline at end of file diff --git a/examples/code_generation/codegen/run_clm.py b/examples/code_generation/codegen/run_clm.py new file mode 100644 index 000000000000..1166d22af69f --- /dev/null +++ b/examples/code_generation/codegen/run_clm.py @@ -0,0 +1,375 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import argparse +import random +import time +import distutils.util +from pprint import pprint +from functools import partial +import numpy as np +from itertools import chain +from datasets import load_dataset +import math +import paddle +import paddle.nn as nn +from paddle.io import BatchSampler, DistributedBatchSampler, DataLoader +from paddlenlp.transformers import CodeGenForCausalLM, CodeGenTokenizer +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.utils.log import logger +from paddlenlp.data import DataCollatorWithPadding +from paddle.metric import Accuracy + + +def parse_args(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--model_name_or_path", + default="Salesforce/codegen-350M-mono", + type=str, + required=True, + help="Path to pre-trained model. ") + parser.add_argument( + "--output_dir", + default="output", + type=str, + required=True, + help= + "The output directory where the model predictions and checkpoints will be written." + ) + parser.add_argument("--train_file", + default=None, + type=str, + required=True, + help="The input training data file.") + parser.add_argument("--validation_file", + default=None, + type=str, + required=True, + help="The input validation data file.") + parser.add_argument( + "--block_size", + default=None, + type=int, + help= + "The training dataset will be truncated in block of this size for training. " + ) + parser.add_argument("--learning_rate", + default=1e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", + ) + parser.add_argument("--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument("--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--train_batch_size", + default=20, + type=int, + help="Batch size per GPU/CPU for training.", + ) + parser.add_argument( + "--eval_batch_size", + default=12, + type=int, + help="Batch size per GPU/CPU for evaluation.", + ) + parser.add_argument("--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help= + "Linear warmup over warmup_steps. 
If > 0: Override warmup_proportion") + parser.add_argument("--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument("--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help= + "If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument("--seed", + default=42, + type=int, + help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + choices=["cpu", "gpu", "xpu"], + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument("--overwrite_cache", + action="store_true", + help="Whether to overwrite cache for dataset.") + parser.add_argument("--use_amp", + default=False, + type=distutils.util.strtobool, + help="Enable mixed precision training.") + parser.add_argument("--scale_loss", + default=2**15, + type=float, + help="The value of scale_loss for fp16.") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, data_loader, loss_fct): + model.eval() + metric = Accuracy() + metric.reset() + losses = [] + model = model._layers if isinstance(model, paddle.DataParallel) else model + for batch in data_loader: + labels = batch.pop("labels") + logits, _ = model(**batch) + loss = loss_fct(logits[:, :-1, :], labels[:, 1:]) + correct = metric.compute(paddle.max(logits[:, :-1, :], axis=-1), + labels[:, 1:]) + losses.append(loss) + losses = paddle.concat(losses) + eval_loss = paddle.mean(losses) + perplexity = math.exp(eval_loss) + accuracy = metric.accumulate() + logger.info("[validation] accuracy: %f, loss: %f, ppl: %f" % + (accuracy, eval_loss, perplexity)) + model.train() + return perplexity + + +def convert_example(examples, tokenizer): + """convert examples into necessary features""" + # Convert raw text to feature + tokenized_examples = tokenizer(examples["code"], + return_attention_mask=True, + return_position_ids=False, + return_token_type_ids=False) + return tokenized_examples + + +def group_texts(examples, block_size): + concatenated_examples = { + k: list(chain(*examples[k])) + for k in examples.keys() + } + total_length = len(concatenated_examples[list(examples.keys())[0]]) + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + result = { + k: [t[i:i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + +def process_ds(dataset, tokenizer, overwrite_cache, block_size): + trans_func = partial(convert_example, tokenizer=tokenizer) + dataset = dataset.map(trans_func, + batched=True, + remove_columns=dataset.column_names, + load_from_cache_file=overwrite_cache) + trans_func = partial(group_texts, block_size=block_size) + dataset = dataset.map(trans_func, + batched=True, + load_from_cache_file=overwrite_cache) + return dataset + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() 
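+        # With `paddle.distributed.launch`, each process drives one GPU;
+        # init_parallel_env sets up the communication group that
+        # paddle.DataParallel later uses to synchronize gradients.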
+ + set_seed(args) + + tokenizer = CodeGenTokenizer.from_pretrained(args.model_name_or_path) + + train_set = load_dataset("json", data_files=args.train_file, split="train") + dev_set = load_dataset("json", + data_files=args.validation_file, + split="train") + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(args.block_size, tokenizer.model_max_length) + + train_set = process_ds(train_set, tokenizer, args.overwrite_cache, + block_size) + dev_set = process_ds(dev_set, tokenizer, args.overwrite_cache, block_size) + + batchify_fn = DataCollatorWithPadding(tokenizer) + + train_batch_sampler = DistributedBatchSampler( + train_set, batch_size=args.train_batch_size, shuffle=True) + + train_data_loader = DataLoader(dataset=train_set, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + dev_batch_sampler = BatchSampler(dev_set, + batch_size=args.eval_batch_size, + shuffle=False) + dev_data_loader = DataLoader(dataset=dev_set, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + model = CodeGenForCausalLM.from_pretrained(args.model_name_or_path) + + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + if args.max_steps > 0: + num_training_steps = args.max_steps + num_train_epochs = math.ceil(num_training_steps / + len(train_data_loader)) + else: + num_training_steps = len(train_data_loader) * args.num_train_epochs + num_train_epochs = args.num_train_epochs + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. 
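+    # `apply_decay_param_fun` passed to AdamW below returns True only for the
+    # names collected here, so bias and LayerNorm parameters are not decayed.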
+ decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params) + + loss_fct = nn.CrossEntropyLoss() + if args.use_amp: + scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss) + global_step = 0 + best_eval_ppl = float("inf") + tic_train = time.time() + for epoch in range(num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + labels = batch.pop("labels") + with paddle.amp.auto_cast( + args.use_amp, + custom_white_list=["layer_norm", "softmax", "gelu"]): + logits, _ = model(**batch) + loss = loss_fct(logits[:, :-1, :], labels[:, 1:]) + if args.use_amp: + scaled_loss = scaler.scale(loss) + scaled_loss.backward() + scaler.minimize(optimizer, scaled_loss) + else: + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + if global_step % args.logging_steps == 0: + logger.info( + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, ppl: %f, lr: %.10f, speed: %.4f step/s" + % (global_step, num_training_steps, epoch, step, + paddle.distributed.get_rank(), loss, math.exp(loss), + optimizer.get_lr(), args.logging_steps / + (time.time() - tic_train))) + tic_train = time.time() + if global_step % args.save_steps == 0 or global_step == num_training_steps: + tic_eval = time.time() + ppl = evaluate(model, dev_data_loader, loss_fct) + logger.info("eval done total : %s s" % (time.time() - tic_eval)) + if best_eval_ppl > ppl and paddle.distributed.get_rank() == 0: + best_eval_ppl = ppl + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + if global_step >= num_training_steps: + break + if global_step >= num_training_steps: + break + + if paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "codegen_model_final_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + +if __name__ == "__main__": + args = parse_args() + pprint(args) + do_train(args) diff --git a/examples/code_generation/codegen/run_clm.sh b/examples/code_generation/codegen/run_clm.sh new file mode 100644 index 000000000000..083abcf4cf47 --- /dev/null +++ b/examples/code_generation/codegen/run_clm.sh @@ -0,0 +1,11 @@ +python -m paddle.distributed.launch --gpus 0,1 run_clm.py \ + --model_name_or_path Salesforce/codegen-350M-mono \ + --output_dir output \ + --train_file train.json \ + --validation_file test.json \ + --num_train_epochs 5 \ + --logging_steps 10 \ + --save_steps 1000 \ + --train_batch_size 2 \ + --eval_batch_size 2 \ + --device gpu diff --git a/faster_generation/README.md b/faster_generation/README.md index 6adb349cd1ca..58cb2a6f4f99 100644 --- a/faster_generation/README.md +++ b/faster_generation/README.md @@ -11,7 +11,7 @@ FasterGeneration是PaddleNLP v2.2版本加入的文本生成高性能加速功 ## Featrues -- 
全面支持生成式预训练模型。包括GPT、OPT、BART、mBART、UnifiedTransformer和UNIMO-text。 +- 全面支持生成式预训练模型。包括GPT、OPT、CodeGen、GPTJ、BART、mBART、UnifiedTransformer和UNIMO-text。 - 支持大多数主流解码策略。包括Beam Search、Sampling、Greedy Search。以及Diverse Sibling Search、Length Penalty等子策略。 - 解码速度快。最高可达非加速版generate函数的**18倍**。**并支持FP16混合精度计算**。 - 易用性强。功能的入口为`model.generate`,与非加速版生成api的使用方法相同,当满足加速条件时使用jit即时编译高性能算子并用于生成,不满足则自动切换回非加速版生成api。 @@ -20,17 +20,17 @@ FasterGeneration是PaddleNLP v2.2版本加入的文本生成高性能加速功 ### Inference Model Support 下表为PaddleNLP FasterGeneration对预训练模型和解码策略的支持情况(GPU)。 -| Model Name | GPT2 | OPT | BART | mBART | UnifiedTransformer | -|------------------------|---------|---------|-----------------|-----------------|--------------------| -| Model Structure | Decoder | Decoder | Encoder-Decoder | Encoder-Decoder | Prefix-LM | -| Beam Search | ❌ | ❌ | ✅ | ✅ | ✅ | -| Top-K Sampling | ✅ | ✅ | ✅ | ✅ | ✅ | -| Top-P Sampling | ✅ | ✅ | ✅ | ✅ | ✅ | -| Diverse Sibling Search | ❌ | ❌ | ✅ | ✅ | ✅ | -| Forced Decoding | ❌ | ❌ | ❌ | ✅ | ❌ | -| Length Penalty | ❌ | ❌ | ✅ | ✅ | ✅ | -| Temperature | ✅ | ✅ | ✅ | ✅ | ✅ | -| Repetition Penalty | ✅ | ✅ | ❌ | ❌ | ❌ | +| Model Name | GPT2 | OPT | CodeGen| GPTJ| BART | mBART | UnifiedTransformer | +|------------------------|---------|---------| ---------| ---------|-----------------|-----------------|--------------------| +| Model Structure | Decoder | Decoder |Decoder|Decoder| Encoder-Decoder | Encoder-Decoder | Prefix-LM | +| Beam Search | ❌ | ❌ |❌|❌| ✅ | ✅ | ✅ | +| Top-K Sampling | ✅ | ✅ |✅|✅| ✅ | ✅ | ✅ | +| Top-P Sampling | ✅ | ✅ |✅|✅| ✅ | ✅ | ✅ | +| Diverse Sibling Search | ❌ | ❌ |❌|❌| ✅ | ✅ | ✅ | +| Forced Decoding | ❌ | ❌ |❌|❌| ❌ | ✅ | ❌ | +| Length Penalty | ❌ | ❌ |❌|❌| ✅ | ✅ | ✅ | +| Temperature | ✅ | ✅ |✅|✅| ✅ | ✅ | ✅ | +| Repetition Penalty | ✅ | ✅ |✅|✅| ❌ | ❌ | ❌ | ## Performence @@ -61,6 +61,26 @@ FasterGeneration的高性能解码相比原版generate方法加速明显,并

+**CodeGen:** +* 环境和超参 +- Platform: Tesla V100-SXM2-32GB +- CUDA 10.1 +- CUDNN 7.6.5 +- PaddlePaddle-gpu 2.3.1.post101 +- transformers==4.21.1 +- torch==1.11.0 +- Batch Size: 1 +- Input Length: 60 +- Output Length: 20 +
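+在上述环境下,单条输入的加速用法大致如下(仅为示意,模型名与解码参数可按需调整;完整的测速脚本见 `faster_generation/perf/codegen_perf.py`):
+
+```python
+import paddle
+from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM
+
+model_name = "Salesforce/codegen-350M-mono"
+tokenizer = CodeGenTokenizer.from_pretrained(model_name)
+model = CodeGenForCausalLM.from_pretrained(model_name, load_state_as_np=True)
+model.eval()
+
+input_ids = tokenizer(["def hello_world():"], return_tensors="pd")["input_ids"]
+# use_faster 启用高性能算子,use_fp16_decoding 使用 FP16 解码
+output_ids, _ = model.generate(input_ids=input_ids,
+                               max_length=20,
+                               decode_strategy="sampling",
+                               top_k=4,
+                               use_faster=True,
+                               use_fp16_decoding=True)
+print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+```
+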
+（此处原为 CodeGen 在 Tesla V100-SXM2-32GB 上的性能对比图）
+
+- Platform: A100-40G
+
+（此处原为 CodeGen 在 A100-40G 上的性能对比图）
+ 更详细的性能数据请参见[这里](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/faster_generation/perf) ## Quick Start diff --git a/faster_generation/perf/README.md b/faster_generation/perf/README.md index c1c1e1d53090..2f89147651ba 100644 --- a/faster_generation/perf/README.md +++ b/faster_generation/perf/README.md @@ -154,6 +154,49 @@ transformers: 4.20.1 | | top_k=16 | 110.80 | 78.19 | 488.34 | 4.41 | 6.25 | | | top_p=0.4 | 128.33 | 92.57 | 544.18 | 4.24 | 5.88 | +**CodeGen:** +* 环境和超参 + +- Platform: Tesla V100-SXM2-32GB +- CUDA 10.1 +- CUDNN 7.6.5 +- PaddlePaddle-gpu 2.3.1.post101 +- transformers==4.21.1 +- torch==1.11.0 +- Batch Size: 1 +- Input Length: 60 +- Output Length: 20 + +* 模型参数 + +| Model Name | num_layers | num_attention_heads | hidden_size | +|------------|------------|---------------------|-------------| +| Salesforce/codegen-350M-mono | 20 | 16 | 1024 | +| Salesforce/codegen-2B-mono | 32 | 32 | 2560 | +| Salesforce/codegen-6B-mono | 33 | 16 | 4096 | +| Salesforce/codegen-16B-mono | 34 | 24 | 6144 | + + + +* 性能结果报表 + +| Model | Decoding Strategy | Faster Generation(FP32)(ms) | Faster Generation(FP16)(ms) | HF Generation(ms) | Speed Up Rate(Faster32/HF) | Speed Up Rate(Faster16/HF) | +|:--------:|:-------------------:|:-----------------------------:|:-----------------------------:|:-------------------:|:----------------------------:|:----------------------------:| +| Salesforce/codegen-350M-mono | top_k=1 | 57.76 | 51.35 | 709.62 | 12.29 | 13.82 | +| | top_k=4 | 57.42 | 50.88 | 639.58 | 11.14 | 12.57 | +| | top_k=8 | 57.24 | 51.67 | 685.82 | 11.98 | 13.27 | +| | top_k=16 | 57.57 | 51.61 | 686.62 | 11.93 | 13.30 | +| | top_p=0.4 | 67.26 | 57.35 | 656.12 | 9.75 | 11.44 | +| Salesforce/codegen-2B-mono| top_k=1 | 319.03 | 207.41 | 1040.71 | 3.26 | 5.02 | +| | top_k=4 | 318.98 | 207.37 | 1014.32 | 3.18 | 4.89 | +| | top_k=8 | 319.66 | 207.26 | 1084.09 | 3.39 | 5.23 | +| | top_k=16 | 320.04 | 207.74 | 1040.28 | 3.25 | 5.01 | +| | top_p=0.4 | 329.07 | 213.97 | 1055.55 | 3.21 | 4.93 | +| Salesforce/codegen-6B-mono| top_k=1 | 762.91 | 411.94 | 1384.90 | 1.82 | 3.36 | +| | top_k=4 | 762.58 | 412.79 | 1378.32 | 1.81 | 3.34 | +| | top_k=8 | 763.43 | 413.32 | 1366.45 | 1.79 | 3.31 | +| | top_k=16 | 762.79 | 413.83 | 1376.69 | 1.80 | 3.33 | +| | top_p=0.4 | 771.77 | 419.16 | 1366.49 | 1.77 | 3.26 | ## 测试方法 运行如下命令即可bart性能测试: diff --git a/faster_generation/perf/codegen_perf.py b/faster_generation/perf/codegen_perf.py new file mode 100644 index 000000000000..dc391956d0a1 --- /dev/null +++ b/faster_generation/perf/codegen_perf.py @@ -0,0 +1,190 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
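+"""Benchmark CodeGen generation latency and GPU memory.
+
+Runs the same sampling configuration through normal PaddleNLP decoding,
+FasterGeneration (FP32 / FP16) or HuggingFace transformers, selected via
+--perf_type, and reports the average generation time together with memory
+usage measured through pynvml.
+
+Example (350M model, FP16 FasterGeneration):
+    python codegen_perf.py --perf_type pd_faster_fp16 \
+        --model_name_or_path Salesforce/codegen-350M-mono --top_k 4
+"""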
+ +import argparse +import time +from pprint import pprint +import numpy as np +import paddle +from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM + +import pynvml + +pynvml.nvmlInit() + + +def query_by_id(gpu_id=2): + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) + meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) + return meminfo.used // 1024 // 1024 + + +def perf_pd(args): + start_mem = query_by_id() + place = "gpu" + place = paddle.set_device(place) + tokenizer = CodeGenTokenizer.from_pretrained(args.model_name_or_path) + model = CodeGenForCausalLM.from_pretrained(args.model_name_or_path, + load_state_as_np=True) + model.eval() + load_mem = query_by_id() + + input_ids_np = [ + np.random.choice(list(tokenizer.decoder.keys())[:-1], args.input_len) + for _ in range(args.batch_size) + ] + input_ids = paddle.to_tensor(input_ids_np) + + num_loop = 100 + with paddle.no_grad(): + for i in range(num_loop): + # For warmup. + if num_loop // 2 == i: + # PaddlePaddle >= 2.2 + paddle.device.cuda.synchronize(place) + start = time.perf_counter() + output, _ = model.generate(input_ids=input_ids, + max_length=args.generate_len, + min_length=args.generate_len, + decode_strategy="sampling", + top_k=args.top_k, + top_p=args.top_p, + use_faster=args.use_faster, + use_fp16_decoding=args.use_fp16_decoding) + generate_mem = query_by_id() + paddle.device.cuda.synchronize(place) + pd_cost = (time.perf_counter() - start) / (num_loop - + num_loop // 2) * 1000 + return pd_cost, load_mem - start_mem, generate_mem - start_mem + + +def perf_hf(args): + import torch + from transformers import CodeGenTokenizer as hf_tokenizer, CodeGenForCausalLM as hf_codegen + start_mem = query_by_id() + device = torch.device("cuda") + tokenizer = hf_tokenizer.from_pretrained(args.model_name_or_path) + model = hf_codegen.from_pretrained(args.model_name_or_path) + model.to(device) + model.eval() + load_mem = query_by_id() + + input_ids_np = [ + np.random.choice(list(tokenizer.decoder.keys()), args.input_len) + for _ in range(args.batch_size) + ] + input_ids = torch.tensor(input_ids_np) + input_ids = input_ids.to(device) + num_loop = 100 + with torch.no_grad(): + for i in range(num_loop): + # For warmup. + if num_loop // 2 == i: + torch.cuda.synchronize() + start = time.perf_counter() + output = model.generate( + input_ids, + do_sample=True, + max_length=args.generate_len + input_ids.shape[-1], + min_length=args.generate_len + input_ids.shape[-1], + top_k=args.top_k, + top_p=args.top_p) + generate_mem = query_by_id() + torch.cuda.synchronize() + hf_cost = (time.perf_counter() - start) / (num_loop - + num_loop // 2) * 1000 + return hf_cost, load_mem - start_mem, generate_mem - start_mem + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perf_type", + default="pd", + type=str, + choices=['pd', 'pd_faster_fp32', 'pd_faster_fp16', 'hf'], + help="The type of perf. ") + parser.add_argument("--model_name_or_path", + default="Salesforce/codegen-350M-mono", + type=str, + choices=[ + 'Salesforce/codegen-350M-mono', + 'Salesforce/codegen-2B-mono', + 'Salesforce/codegen-6B-mono', + 'Salesforce/codegen-16B-mono' + ], + help="The model name to specify the bart to use. ") + parser.add_argument( + "--top_k", + default=4, + type=int, + help="The number of candidate to procedure topk sampling. ") + parser.add_argument( + "--top_p", + default=1.0, + type=float, + help="The probability threshold to procedure topp sampling. 
") + parser.add_argument("--batch_size", + default=1, + type=int, + help="The size of input batch. ") + parser.add_argument("--input_len", + default=60, + type=int, + help="The size of model input. ") + parser.add_argument("--generate_len", + default=20, + type=int, + help="Length of output . ") + parser.add_argument( + '--use_faster', + action='store_true', + help='Whether to process inference using faster codegen. ') + + parser.add_argument("--use_fp16_decoding", + action="store_true", + help="Whether to use fp16 decoding to predict. ") + args = parser.parse_args() + return args + + +def do_predict(args): + try: + if args.perf_type == 'pd': + args.use_faster = False + cost, load_mem, generate_mem = perf_pd(args) + elif args.perf_type == 'pd_faster_fp32': + args.use_faster = True + args.use_fp16_decoding = False + cost, load_mem, generate_mem = perf_pd(args) + elif args.perf_type == 'pd_faster_fp16': + args.use_faster = True + args.use_fp16_decoding = True + paddle.set_default_dtype('float16') + cost, load_mem, generate_mem = perf_pd(args) + else: + cost, load_mem, generate_mem = perf_hf(args) + pprint(args) + print( + f'CodeGenPerfResult: cost_time: {cost} ms, load_mem: {load_mem} MB, generate_mem:{generate_mem} MB, args:{args}\n' + ) + except Exception as e: + pprint(args) + print(f'CodeGenPerfResult: ERROR: {e}, args:{args}\n') + + +if __name__ == "__main__": + args = parse_args() + do_predict(args) diff --git a/faster_generation/perf/run_perf_codegen.sh b/faster_generation/perf/run_perf_codegen.sh new file mode 100644 index 000000000000..0d02575af92d --- /dev/null +++ b/faster_generation/perf/run_perf_codegen.sh @@ -0,0 +1,47 @@ +export CUDA_VISIBLE_DEVICES=2 + +for model_name in Salesforce/codegen-350M-mono Salesforce/codegen-2B-mono Salesforce/codegen-6B-mono; + do + for top_k in 1 4 8 16; + do + for input_len in 60; + do + for generate_len in 20; + do + for perf_type in pd pd_faster_fp32 pd_faster_fp16 hf; + do + echo model_name: $model_name, perf_type: $perf_type, top_k: $top_k, top_p: 1.0, input_len: $input_len, generate_len: $generate_len + python codegen_perf.py \ + --model_name_or_path=$model_name \ + --perf_type=$perf_type \ + --top_k=$top_k \ + --top_p=1.0 \ + --input_len=$input_len \ + --generate_len=$generate_len + sleep 3s + done + done + done + done + for top_p in 0.4; + do + for input_len in 60; + do + for generate_len in 20; + do + for perf_type in pd pd_faster_fp32 pd_faster_fp16 hf; + do + echo model_name: $model_name, perf_type: $perf_type, top_k: 0, top_p: $top_p, input_len: $input_len, generate_len: $generate_len + python codegen_perf.py \ + --model_name_or_path=$model_name \ + --perf_type=$perf_type \ + --top_k=0 \ + --top_p=$top_p \ + --input_len=$input_len \ + --generate_len=$generate_len + sleep 3s + done + done + done + done + done \ No newline at end of file diff --git a/faster_generation/samples/codegen_sample.py b/faster_generation/samples/codegen_sample.py new file mode 100644 index 000000000000..050073d8f95c --- /dev/null +++ b/faster_generation/samples/codegen_sample.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddlenlp.transformers import CodeGenTokenizer, CodeGenForCausalLM + +model_name = 'Salesforce/codegen-350M-mono' + +tokenizer = CodeGenTokenizer.from_pretrained(model_name) +model = CodeGenForCausalLM.from_pretrained(model_name) +model.eval() + +inputs = 'def hello' +input_ids = tokenizer([inputs], return_tensors='pd')['input_ids'] + +outputs, _ = model.generate(input_ids=input_ids, + max_length=128, + decode_strategy='greedy_search', + use_fp16_decoding=True, + use_faster=True) + +result = tokenizer.decode(outputs[0], + truncate_before_pattern=[r"\n\n^#", "^'''", "\n\n\n"]) + +print("Model input:", inputs) +print("Result:", result) +# Result: _world(): +# print("Hello World") + +# hello_world() diff --git a/faster_generation/samples/gptj_sample.py b/faster_generation/samples/gptj_sample.py new file mode 100644 index 000000000000..90f041e0585d --- /dev/null +++ b/faster_generation/samples/gptj_sample.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddlenlp.transformers import GPTJTokenizer, GPTJForCausalLM + +paddle.set_default_dtype('float16') +model_name = 'EleutherAI/gpt-j-6B' + +tokenizer = GPTJTokenizer.from_pretrained(model_name) +model = GPTJForCausalLM.from_pretrained(model_name, load_state_as_np=True) +model.eval() + +inputs = "What is PaddleNLP?" 
+input_ids = tokenizer([inputs], return_tensors='pd')['input_ids'] + +outputs, _ = model.generate(input_ids=input_ids, + max_length=100, + decode_strategy='sampling', + temperature=0.8, + top_p=0.9, + use_fp16_decoding=True, + use_faster=True) + +result = tokenizer.decode(outputs[0]) + +print("Model input:", inputs) +print("Result:", result) \ No newline at end of file diff --git a/paddlenlp/ops/CMakeLists.txt b/paddlenlp/ops/CMakeLists.txt index c0c1e2d21e8b..194f32167ebe 100644 --- a/paddlenlp/ops/CMakeLists.txt +++ b/paddlenlp/ops/CMakeLists.txt @@ -36,6 +36,7 @@ option(WITH_BART "Compile with BART" option(WITH_MBART "Compile with MBART" ON) option(WITH_PARALLEL "Compile with model parallel for GPT" OFF) option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) +option(WITH_GPTJ "Compile with GPTJ" ON) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") if(WITH_PARALLEL) @@ -85,8 +86,12 @@ if(WITH_MBART) list(APPEND decoding_op_files fusion_mbart_decoding_op.cc fusion_mbart_decoding_op.cu) endif() -if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART) - message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON must be set to use FasterTransformer. ") +if(WITH_GPTJ) + list(APPEND decoding_op_files fusion_gptj_op.cc fusion_gptj_op.cu) +endif() + +if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER AND NOT WITH_BART AND NOT WITH_MBART AND NOT WITH_GPTJ) + message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON or/and -DWITH_BART=ON or/and -DWITH_MBART=ON or/and -DWITH_GPTJ=ON must be set to use FasterTransformer. 
") endif() set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) @@ -321,6 +326,21 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME} file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/opt.h opt_h_src) file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/opt.h opt_h_dst) +file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/gptj.h gptj_h_src) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/gptj.h gptj_h_dst) + +file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h masked_multihead_attention_utils_h_src) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention_utils.h masked_multihead_attention_utils_h_dst) + +file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h masked_multihead_attention_h_src) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/masked_multihead_attention.h masked_multihead_attention_h_dst) + +file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu attention_kernels_cu_src) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cu attention_kernels_cu_dst) + +file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_src) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/attention_kernels.cuh attention_kernels_cuh_dst) + # Encoder patches. 
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src) file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst) @@ -360,6 +380,11 @@ set(FT_PATCH_COMMAND && cp ${fastertransformer_cmakelists_src} ${fastertransformer_cmakelists_dst} && cp ${gpt_h_src} ${gpt_h_dst} && cp ${opt_h_src} ${opt_h_dst} + && cp ${gptj_h_src} ${gptj_h_dst} + && cp ${masked_multihead_attention_h_src} ${masked_multihead_attention_h_dst} + && cat blank_lines ${masked_multihead_attention_utils_h_src} >> ${masked_multihead_attention_utils_h_dst} + && cat blank_lines ${attention_kernels_cu_src} >> ${attention_kernels_cu_dst} + && cat blank_lines ${attention_kernels_cuh_src} >> ${attention_kernels_cuh_dst} && cat blank_lines ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst} && cat blank_lines ${lightseq_kernels_cu_src} >> ${topk_kernels_dst} && cat blank_lines ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst} diff --git a/paddlenlp/ops/faster_transformer/src/CMakeLists.txt b/paddlenlp/ops/faster_transformer/src/CMakeLists.txt index 6de779f002de..46588a22151a 100644 --- a/paddlenlp/ops/faster_transformer/src/CMakeLists.txt +++ b/paddlenlp/ops/faster_transformer/src/CMakeLists.txt @@ -221,6 +221,16 @@ else(ON_INFER) set(PYTHON_PATH ${PY_CMD} CACHE STRING "Python path") endif() + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import paddle; print(paddle.__version__)" + RESULT_VARIABLE _INC_PYTHON_SUCCESS + OUTPUT_VARIABLE _INC_PYTHON_VALUES) + message(STATUS "PADDLE_VERSION: ${_INC_PYTHON_VALUES}") + + # TODO(gongenlei): support PADDLE_NEW_ALLOCATOR for ON_INFER + if(_INC_PYTHON_VALUES VERSION_GREATER_EQUAL "2.3.0") + add_definitions(-DPADDLE_NEW_ALLOCATOR) + endif() + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import paddle; print(paddle.sysconfig.get_include())" RESULT_VARIABLE _INC_PYTHON_SUCCESS OUTPUT_VARIABLE _INC_PYTHON_VALUES) diff --git a/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc new file mode 100644 index 000000000000..81d5411eb4be --- /dev/null +++ b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cc @@ -0,0 +1,203 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +#include "fusion_gptj_op.h" +#include "pd_traits.h" + + +std::vector GPTJForward( + const paddle::Tensor& input, + const paddle::Tensor& attn_mask, + const paddle::Tensor& start_length, + const paddle::Tensor& word_embedding, + const std::vector& self_ln_weight, + const std::vector& self_ln_bias, + const std::vector& self_q_weight, + const std::vector& self_out_weight, + const std::vector& ffn_inter_weight, + const std::vector& ffn_inter_bias, + const std::vector& ffn_out_weight, + const std::vector& ffn_out_bias, + const paddle::Tensor& decoder_ln_weight, + const paddle::Tensor& decoder_ln_bias, + const paddle::Tensor& emb_weight, + const paddle::Tensor& emb_bias, + const int topk, + const float topp, + const int max_len, + const int n_head, + const int size_per_head, + const int num_layer, + const int bos_id, + const int eos_id, + const float temperature, + const int rotary_embedding_dim, + const float repetition_penalty, + const int min_length, + const bool use_fp16 = false, + const int tensor_para_size = 1, + const int layer_para_size = 1, + const int layer_para_batch_size = 1) { + int batch_size = input.shape()[0]; + int start_len = input.shape()[1]; + int total_len = max_len + start_len; + std::vector output_dims({total_len, batch_size}); + +#ifdef PADDLE_NEW_ALLOCATOR + // For PaddlePaddle>=2.3.0 + auto output_ids = paddle::empty(output_dims, paddle::DataType::INT32, input.place()); + auto gpu_place = paddle::GPUPlace(); +#else + auto output_ids = paddle::Tensor(input.place(), output_dims); + auto gpu_place = paddle::PlaceType::kGPU; +#endif + + if (word_embedding.place() == gpu_place) { + return GPTJCUDAForward(input, + attn_mask, + start_length, + word_embedding, + self_ln_weight, + self_ln_bias, + self_q_weight, + self_out_weight, + ffn_inter_weight, + ffn_inter_bias, + ffn_out_weight, + ffn_out_bias, + decoder_ln_weight, + decoder_ln_bias, + emb_weight, + emb_bias, + output_ids, + topk, + topp, + total_len, + n_head, + size_per_head, + num_layer, + bos_id, + eos_id, + temperature, + rotary_embedding_dim, + repetition_penalty, + min_length, + use_fp16, + tensor_para_size, + layer_para_size, + layer_para_batch_size); + } else { + PD_THROW("Not implemented place. Only GPU is supported. 
"); + } +} + +std::vector> GPTJInferShape( + const std::vector& input_shape, + const std::vector& attn_mask_shape, + const std::vector& start_length, + const std::vector& word_embedding_shape, + const std::vector>& self_ln_weight_shapes, + const std::vector>& self_ln_bias_shapes, + const std::vector>& self_q_weight_shapes, + const std::vector>& self_out_weight_shapes, + const std::vector>& ffn_inter_weight_shapes, + const std::vector>& ffn_inter_bias_shapes, + const std::vector>& ffn_out_weight_shapes, + const std::vector>& ffn_out_bias_shapes, + const std::vector& decoder_ln_weight_shape, + const std::vector& decoder_ln_bias_shape, + const std::vector& emb_weight_shape, + const std::vector& emb_bias_shape, + const int topk, + const float topp, + const int max_len, + const int n_head, + const int size_per_head, + const int num_layer, + const int bos_id, + const int eos_id, + const float temperature, + const int rotary_embedding_dim, + const float repetition_penalty, + const int min_length, + const bool use_fp16 = false, + const int tensor_para_size = 1, + const int layer_para_size = 1, + const int layer_para_batch_size = 1) { + int64_t batch_size = input_shape[0]; + int64_t start_len = input_shape[1]; + std::vector output_dims({max_len + start_len, batch_size}); + return {output_dims}; +} + +std::vector GPTJInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& attn_mask_dtype, + const paddle::DataType& start_length_dtype, + const paddle::DataType& word_embedding_dtype, + const std::vector& self_ln_weight_dtype, + const std::vector& self_ln_bias_dtype, + const std::vector& self_q_weight_dtype, + const std::vector& self_out_weight_dtype, + const std::vector& ffn_inter_weight_dtype, + const std::vector& ffn_inter_bias_dtype, + const std::vector& ffn_out_weight_dtype, + const std::vector& ffn_out_bias_dtype, + const paddle::DataType& decoder_ln_weight_dtype, + const paddle::DataType& decoder_ln_bias_dtype, + const paddle::DataType& emb_weight_dtype, + const paddle::DataType& emb_bias_dtype) { + return {paddle::DataType::INT32}; +} + +PD_BUILD_OP(fusion_gptj) + .Inputs({"Input", + "AttentionMask", + "StartLength", + "WordEmbedding", + paddle::Vec("SelfLayernormWeight"), + paddle::Vec("SelfLayernormBias"), + paddle::Vec("SelfQueryWeight"), + paddle::Vec("SelfOutWeight"), + paddle::Vec("FFNInterWeight"), + paddle::Vec("FFNInterBias"), + paddle::Vec("FFNOutWeight"), + paddle::Vec("FFNOutBias"), + "DecoderLayernormWeight", + "DecoderLayernormBias", + "EmbWeight", + "EmbBias"}) + .Outputs({"OutputIds"}) + .Attrs({"topk: int", + "topp: float", + "max_len: int", + "n_head: int", + "size_per_head: int", + "num_layer: int", + "bos_id: int", + "eos_id: int", + "temperature: float", + "rotary_embedding_dim: int", + "repetition_penalty: float", + "min_length: int", + "use_fp16: bool", + "tensor_para_size: int", + "layer_para_size: int", + "layer_para_batch_size: int"}) + .SetKernelFn(PD_KERNEL(GPTJForward)) + .SetInferShapeFn(PD_INFER_SHAPE(GPTJInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GPTJInferDtype)); diff --git a/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cu b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cu new file mode 100644 index 000000000000..c2217a84a0c1 --- /dev/null +++ b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.cu @@ -0,0 +1,334 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Use the global cublas handle +#include "cublas_handle.h" + +// TODO(guosheng): `HOST` conflict exists in float.h of paddle and mpi.h of mpi +#include "fusion_gptj_op.h" +#include "pd_traits.h" +#ifdef HOST +#undef HOST +#endif +#include "fastertransformer/cuda/cub/cub.cuh" +#include "fastertransformer/utils/common.h" + +#ifdef BUILD_GPT // consistent with FasterTransformer +#include "parallel_utils.h" +#endif + +template +std::vector gptj_kernel( + const paddle::Tensor& input, + const paddle::Tensor& attn_mask, + const paddle::Tensor& start_length, + const paddle::Tensor& word_emb, + const std::vector& self_ln_weight, + const std::vector& self_ln_bias, + const std::vector& self_q_weight, + const std::vector& self_out_weight, + const std::vector& ffn_inter_weight, + const std::vector& ffn_inter_bias, + const std::vector& ffn_out_weight, + const std::vector& ffn_out_bias, + const paddle::Tensor& decoder_ln_weight, + const paddle::Tensor& decoder_ln_bias, + const paddle::Tensor& emb_weight, + const paddle::Tensor& emb_bias, + paddle::Tensor& output_ids, + const int topk, + const float topp, + const int max_len, + const int n_head, + const int size_per_head, + const int num_layer, + const int bos_id, + const int eos_id, + const float temperature, + const int rotary_embedding_dim, + const float repetition_penalty, + const int min_length, + cudaStream_t stream, + const int tensor_para_size = 1, + const int layer_para_size = 1, + const int layer_para_batch_size = 1) { + auto input_dims = input.shape(); + int batch_size_ = input_dims[0]; + int start_len = input_dims[1]; + const int vocab_size = emb_bias.shape()[0]; + + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t_; + + DecodingInitParam decoding_params; + decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_; + decoding_params.cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_; + +#ifdef PADDLE_NEW_ALLOCATOR + // For PaddlePaddle>=2.3.0 + decoding_params.output_ids = output_ids.data(); +#else + decoding_params.output_ids = output_ids.mutable_data(word_emb.place()); +#endif + + typedef DecoderTransformerTraits DecodingTraits_; + decoding_params.stream = stream; + fastertransformer::Allocator allocator_(stream); + + const int hidden_unit = size_per_head * n_head; + +#ifdef BUILD_GPT + auto* model_para_desc = ModelParaDescFactory::CreateModelParaDesc( + n_head, + size_per_head, + num_layer, + tensor_para_size, + layer_para_size, + layer_para_batch_size, + const_cast(word_emb.data())); + auto& tensor_parallel_param = model_para_desc->tensor_parallel_param; + auto& layer_parallel_param = model_para_desc->layer_parallel_param; + auto seed = model_para_desc->dist(model_para_desc->gen); +#else + TensorParallelParam tensor_parallel_param; + LayerParallelParam layer_parallel_param; + tensor_parallel_param.rank = 0; + tensor_parallel_param.world_size = 1; + tensor_parallel_param.local_head_num_ = n_head; + tensor_parallel_param.local_hidden_units_ = hidden_unit; 
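+  // Defaults for single-card inference (BUILD_GPT not defined): a single
+  // tensor-parallel rank owns all attention heads and hidden units.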
+ + layer_parallel_param.rank = 0; + layer_parallel_param.world_size = 1; + layer_parallel_param.layers_per_group = num_layer; + layer_parallel_param.local_batch_size = batch_size_; + int seed = -1; +#endif + + DecodingGptJ* gptj_decoding; + + decoding_params.request_batch_size = batch_size_; + decoding_params.max_input_len = start_len; + decoding_params.request_input_len = start_len; + decoding_params.request_output_len = max_len - start_len; + + decoding_params.d_start_ids = const_cast(input.data()); + decoding_params.d_attn_mask = + reinterpret_cast(const_cast(attn_mask.data())); + decoding_params.d_start_lengths = start_length.data(); + + gptj_decoding = + new DecodingGptJ(allocator_, + batch_size_, + max_len, + n_head, + size_per_head, + vocab_size, + num_layer, + bos_id, + eos_id, + topk, + topp, + temperature, + tensor_para_size, + layer_para_size, + true, /*is_fuse_QKV*/ + repetition_penalty, /*repetition_penalty*/ + seed, + rotary_embedding_dim, + min_length); + + gptj_decoding->set_tensor_parallel_param(tensor_parallel_param); + gptj_decoding->set_layer_parallel_param(layer_parallel_param); + + DecoderInitParam* params = + new DecoderInitParam[num_layer]; + + for (int i = 0; i < self_ln_weight.size(); ++i) { + // Allow python passing weights of all layers or only passing the + // corresponding layers to save memory. + int layer_idx = self_ln_weight.size() != num_layer + ? layer_parallel_param.rank * + layer_parallel_param.layers_per_group + + i + : i; + + params[layer_idx].stream = stream; + params[layer_idx].cublas_handle = CublasHandle::GetInstance()->cublas_handle_; + params[layer_idx].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_; + + params[layer_idx].request_batch_size = batch_size_; + params[layer_idx].request_max_mem_seq_len = start_len; + + params[layer_idx].self_layernorm.gamma = + reinterpret_cast(self_ln_weight[i].data()); + params[layer_idx].self_layernorm.beta = + reinterpret_cast(self_ln_bias[i].data()); + + params[layer_idx].self_attention.query_weight.kernel = + reinterpret_cast(self_q_weight[i].data()); + params[layer_idx].self_attention.query_weight.bias = nullptr; + + params[layer_idx].self_attention.attention_output_weight.kernel = + reinterpret_cast(self_out_weight[i].data()); + params[layer_idx].self_attention.attention_output_weight.bias = nullptr; + + params[layer_idx].ffn.intermediate_weight.kernel = + reinterpret_cast(ffn_inter_weight[i].data()); + params[layer_idx].ffn.intermediate_weight.bias = + reinterpret_cast(ffn_inter_bias[i].data()); + params[layer_idx].ffn.output_weight.kernel = + reinterpret_cast(ffn_out_weight[i].data()); + params[layer_idx].ffn.output_weight.bias = + reinterpret_cast(ffn_out_bias[i].data()); + } + + decoding_params.layernorm.gamma = + reinterpret_cast(decoder_ln_weight.data()); + decoding_params.layernorm.beta = + reinterpret_cast(decoder_ln_bias.data()); + decoding_params.embedding_table = + reinterpret_cast(word_emb.data()); + decoding_params.embedding_kernel = + reinterpret_cast(emb_weight.data()); + decoding_params.embedding_bias = + reinterpret_cast(emb_bias.data()); + + gptj_decoding->forward_context(params, decoding_params); + gptj_decoding->forward(params, decoding_params); + + delete gptj_decoding; + delete[] params; + + return {output_ids}; +} + +std::vector GPTJCUDAForward( + const paddle::Tensor& input, + const paddle::Tensor& attn_mask, + const paddle::Tensor& start_length, + const paddle::Tensor& word_embedding, + const std::vector& self_ln_weight, + const std::vector& self_ln_bias, + const 
std::vector& self_q_weight, + const std::vector& self_out_weight, + const std::vector& ffn_inter_weight, + const std::vector& ffn_inter_bias, + const std::vector& ffn_out_weight, + const std::vector& ffn_out_bias, + const paddle::Tensor& decoder_ln_weight, + const paddle::Tensor& decoder_ln_bias, + const paddle::Tensor& emb_weight, + const paddle::Tensor& emb_bias, + paddle::Tensor& output_ids, + const int topk, + const float topp, + const int max_len, + const int n_head, + const int size_per_head, + const int num_layer, + const int bos_id, + const int eos_id, + const float temperature, + const int rotary_embedding_dim, + const float repetition_penalty, + const int min_length, + const bool use_fp16 = false, + const int tensor_para_size = 1, + const int layer_para_size = 1, + const int layer_para_batch_size = 1) { + + auto stream = word_embedding.stream(); + cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream); + + if (use_fp16) { + return gptj_kernel(input, + attn_mask, + start_length, + word_embedding, + self_ln_weight, + self_ln_bias, + self_q_weight, + self_out_weight, + ffn_inter_weight, + ffn_inter_bias, + ffn_out_weight, + ffn_out_bias, + decoder_ln_weight, + decoder_ln_bias, + emb_weight, + emb_bias, + output_ids, + topk, + topp, + max_len, + n_head, + size_per_head, + num_layer, + bos_id, + eos_id, + temperature, + rotary_embedding_dim, + repetition_penalty, + min_length, + stream, + tensor_para_size, + layer_para_size, + layer_para_batch_size); + } else { + return gptj_kernel(input, + attn_mask, + start_length, + word_embedding, + self_ln_weight, + self_ln_bias, + self_q_weight, + self_out_weight, + ffn_inter_weight, + ffn_inter_bias, + ffn_out_weight, + ffn_out_bias, + decoder_ln_weight, + decoder_ln_bias, + emb_weight, + emb_bias, + output_ids, + topk, + topp, + max_len, + n_head, + size_per_head, + num_layer, + bos_id, + eos_id, + temperature, + rotary_embedding_dim, + repetition_penalty, + min_length, + stream, + tensor_para_size, + layer_para_size, + layer_para_batch_size); + } +} diff --git a/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.h b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.h new file mode 100644 index 000000000000..0470de164433 --- /dev/null +++ b/paddlenlp/ops/faster_transformer/src/fusion_gptj_op.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include + +#include "fastertransformer/gptj.h" +#include "fastertransformer/open_decoder.h" +#include "fastertransformer/utils/common.h" + +#ifdef PADDLE_ON_INFERENCE +#include "paddle/include/experimental/ext_all.h" +#else +#include "paddle/extension.h" +#endif + + +std::vector GPTJCUDAForward( + const paddle::Tensor& input, + const paddle::Tensor& attn_mask, + const paddle::Tensor& start_length, + const paddle::Tensor& word_embedding, + const std::vector& self_ln_weight, + const std::vector& self_ln_bias, + const std::vector& self_q_weight, + const std::vector& self_out_weight, + const std::vector& ffn_inter_weight, + const std::vector& ffn_inter_bias, + const std::vector& ffn_out_weight, + const std::vector& ffn_out_bias, + const paddle::Tensor& decoder_ln_weight, + const paddle::Tensor& decoder_ln_bias, + const paddle::Tensor& emb_weight, + const paddle::Tensor& emb_bias, + paddle::Tensor& output_ids, + const int topk, + const float topp, + const int max_len, + const int n_head, + const int size_per_head, + const int num_layer, + const int bos_id, + const int eos_id, + const float temperature, + const int rotary_embedding_dim, + const float repetition_penalty, + const int min_length, + const bool use_fp16, + const int tensor_para_size, + const int layer_para_size, + const int layer_para_batch_size); diff --git a/paddlenlp/ops/faster_transformer/transformer/decoding.py b/paddlenlp/ops/faster_transformer/transformer/decoding.py index a6cf4091f731..3e53ea89680d 100644 --- a/paddlenlp/ops/faster_transformer/transformer/decoding.py +++ b/paddlenlp/ops/faster_transformer/transformer/decoding.py @@ -622,6 +622,68 @@ def infer_mbart_decoding( attrs_names, attrs_val, outputs_names, outputs_dtype) +def infer_gptj_decoding(input, attn_mask, mem_seq_len, word_emb, slf_ln_weight, + slf_ln_bias, slf_q_weight, slf_out_weight, + ffn_inter_weight, ffn_inter_bias, ffn_out_weight, + ffn_out_bias, decoder_ln_weight, decoder_ln_bias, + linear_weight, linear_bias, topk, topp, max_out_len, + head_num, size_per_head, num_layer, bos_id, eos_id, + temperature, rotary_embedding_dim, repetition_penalty, + min_length, use_fp16_decoding): + tensor_para_size = get_ft_para_conf().tensor_para_size + layer_para_size = get_ft_para_conf().layer_para_size + layer_para_batch_size = get_ft_para_conf().layer_para_batch_size + + inputs = { + "Input": input, + "AttentionMask": attn_mask, + "StartLength": mem_seq_len, + "WordEmbedding": word_emb, + "SelfLayernormWeight@VECTOR": slf_ln_weight, + "SelfLayernormBias@VECTOR": slf_ln_bias, + "SelfQueryWeight@VECTOR": slf_q_weight, + "SelfOutWeight@VECTOR": slf_out_weight, + "FFNInterWeight@VECTOR": ffn_inter_weight, + "FFNInterBias@VECTOR": ffn_inter_bias, + "FFNOutWeight@VECTOR": ffn_out_weight, + "FFNOutBias@VECTOR": ffn_out_bias, + "DecoderLayernormWeight": decoder_ln_weight, + "DecoderLayernormBias": decoder_ln_bias, + "EmbWeight": linear_weight, + "EmbBias": linear_bias + } + + attrs = { + "topk": topk, + "topp": topp, + "max_len": max_out_len, + "n_head": head_num, + "size_per_head": size_per_head, + "num_layer": num_layer, + "bos_id": bos_id, + "eos_id": eos_id, + "temperature": temperature, + "rotary_embedding_dim": rotary_embedding_dim, + "repetition_penalty": repetition_penalty, + "min_length": min_length, + "use_fp16": use_fp16_decoding, + "tensor_para_size": tensor_para_size, + "layer_para_size": layer_para_size, + "layer_para_batch_size": layer_para_batch_size + } + + outputs_names = ["OutputIds"] + outputs_dtype = ["int32"] + + return 
run_custom(op_name="fusion_gptj", + inputs_names=inputs.keys(), + inputs_var=inputs.values(), + attrs_names=attrs.keys(), + attrs_val=attrs.values(), + outputs_names=outputs_names, + outputs_dtype=outputs_dtype) + + def finalize(beam_size, output_ids, parent_ids, @@ -2516,3 +2578,265 @@ def forward(self, sequence_length, decoding_strategy=decoding_strategy) return ids + + +def convert_gptj_params(faster_model, + model, + fuse_qkv=1, + use_fp16=False, + restore_data=False, + permutation=None): + r""" + Convert parameters included in Transformer layer from original models + to the format of faster models. + + Args: + faster_model (Layer): The faster model object. + model (Layer): The Transformer layer. + fuse_qkv (int): 0 for nofuse, 1 for fuse, 2 for fuse and delete the + unfused parameters. If environment variable `PPFG_QKV_MEM_OPT` is + set and the weights of q/k/v is fused, it will try to delete the + original unfused weights. Note the rollback to original model would + not be guarantee anymore when the faster model failed if the original + weights are deleted. Default to 1. + use_fp16 (bool): Whether to use float16. Maybe we should use the default + dtype as the highest priority later. Default to `False`. + restore_data (bool): If `False`, need to reload the weight values. It + should be `True` for weight loaded models. Default to `False`. + + Returns: + defaultdict: Each value is a list including converted parameters in all + layers. For other parameters not included in Transformer module to + be converted, such as embeddings, you can achieve it by using the + returned dict `params` though `params['word_emb'].append()` directly + which would do CPU/GPU and fp32/fp16 transfer automatically. + """ + if fuse_qkv == 1: + fuse_qkv = 2 if os.getenv("PPFG_QKV_MEM_OPT", "0") == "1" else 1 + ft_para_conf = get_ft_para_conf() + + class _list(list): + + def append(self, item): + if isinstance(item[0], nn.Layer): + # Axis is used for tensor slice in tensor parallel. + # Use None to make no slice on the tensor. + if len(item) == 2: + layer, attr = item + axis = None + else: + layer, attr, axis = item + param = getattr(layer, attr) + if axis is not None and isinstance(layer, nn.Linear): + param = ft_para_conf.slice_weight(param, axis) + param = transfer_param( + param, + is_bias=attr.endswith("bias"), + dtype="float16" if use_fp16 else "float32", + restore_data=restore_data) + # NOTE: Assignment to parameter 'weight' should be of type + # Parameter or None, thus delete first in case of param is + # a tensor. + # TODO(guosheng): Make slice_weight use `output_param=True` + # and remove delattr. Currently, if `param` is Tensor rather + # than Parameter, it would not be in state_dict. + delattr(layer, attr) + setattr(layer, attr, param) + else: + # NOTE: Compared with if branch, there is no layer attribute + # refered to the transfered param, thus we should set it as + # the layer attribute to be able to convert to static graph. + # Additionally, we suppose no need to process tensor parallel + # here since the param passed in might have been processed. 
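+                # Items arrive as (param, is_bias) or (param, is_bias,
+                # attr_handle); the optional handle lets the caller re-attach
+                # the transferred tensor (e.g. as a module attribute) so it
+                # stays reachable for static graph conversion.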
+ if len(item) == 2: + param, is_bias = item + attr_handle = lambda x: x + else: + param, is_bias, attr_handle = item + param = transfer_param( + param, + is_bias=is_bias, + dtype="float16" if use_fp16 else "float32", + restore_data=restore_data) + attr_handle(param) + return super().append(param) + + params = defaultdict(_list) + + def _convert(module): + num_layer = len(module) + for i, layer in enumerate(module): + if not ft_para_conf.is_load(i, num_layer): + continue + # TODO(guosheng): Tensor with size 0 might be failed in + # paddle develop, thus use tensor with size 1 instead + # temporarily. Besides, we use 2D tensor since jit log + # requires that on linear weight. While size 0 seems all + # right in jit.to_static/jit.save. + dummy_tensor = paddle.zeros([1, 1]) + if permutation is not None: + qkv = layer.attn.qkv_proj.weight.numpy() + qkv = qkv[:, permutation] + if fuse_qkv == 2: + del layer.attn.qkv_proj.weight + setattr(layer.attn.qkv_proj, "weight", dummy_tensor) + w = paddle.to_tensor(qkv) + else: + w = _convert_qkv(layer.attn.q_proj, + layer.attn.k_proj, + layer.attn.v_proj, + attr="weight", + use_numpy=fuse_qkv == 2, + del_param=fuse_qkv == 2, + dummy_tensor=dummy_tensor) + params["slf_q_weight"].append((w, False)) + # NOTE: Use `params["slf_q_weight"][-1]` rather than `w`, + # since the appended tensor might be a new transfered tensor. + # Besides, to allow convert_params be called more than once, + # we find a attr name not existing to avoid overwriting the + # existing attr. + attr = "slf_q_weight_" + str(i) + while hasattr(faster_model, attr): + attr += "_" + setattr(faster_model, attr, params["slf_q_weight"][-1]) + + params["slf_out_weight"].append((layer.attn.out_proj, "weight", 0)) + params["slf_ln_weight"].append((layer.ln_1, "weight")) + params["slf_ln_bias"].append((layer.ln_1, "bias")) + # Slice tensor when append according to axis(1 or 0) if parallel + # is enable. + params["ffn_inter_weight"].append((layer.mlp.fc_in, "weight", 1)) + params["ffn_inter_bias"].append((layer.mlp.fc_in, "bias", 1)) + params["ffn_out_weight"].append((layer.mlp.fc_out, "weight", 0)) + params["ffn_out_bias"].append((layer.mlp.fc_out, "bias")) + + _convert(model) + return params + + +class InferGptJDecoding(nn.Layer): + + def __init__(self, + model, + decoding_lib=None, + use_fp16_decoding=False, + transpose_qkv=False): + if decoding_lib is not None and os.path.isfile(decoding_lib): + if "FasterTransformer" not in LOADED_EXT.keys(): + ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op( + decoding_lib) + LOADED_EXT["FasterTransformer"] = ops + else: + if decoding_lib is not None: + logger.warning( + "The specified decoding_lib does not exist, and it will be built automatically." + ) + load("FasterTransformer" + if get_ft_para_conf().no_para else "FasterTransformerParallel", + verbose=True, + need_parallel=not get_ft_para_conf().no_para) + + super(InferGptJDecoding, self).__init__() + + self.use_fp16_decoding = use_fp16_decoding + self.model = model + self.head_num = self.model.transformer.config['n_head'] + self.size_per_head = int(self.model.transformer.config['n_embd'] / + self.head_num) + self.num_layer = self.model.transformer.config['n_layer'] + self.rotary_embedding_dim = self.model.transformer.config['rotary_dim'] + logger.info("Converting model weights, it will cost a few seconds.....") + permutation = None + if transpose_qkv: + # GPTJ is different with CodeGen in attention project layer. 
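+            # CodeGen's fused qkv_proj appears to pack the projection into 4
+            # groups, each holding per-group q/v/k slices of width
+            # n_embd // 4, while the fused GPT-J kernel reads contiguous
+            # Q, K, V blocks; base_permutation reorders the 12 column blocks
+            # accordingly (grouping assumed from CodeGen's qkv_proj layout).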
+ local_dim = self.model.transformer.config['n_embd'] // 4 + base_permutation = [0, 3, 6, 9, 2, 5, 8, 11, 1, 4, 7, 10] + permutation = np.concatenate([ + np.arange(i * local_dim, (i + 1) * local_dim) + for i in base_permutation + ]) + params = convert_gptj_params(self, + model.transformer.h, + fuse_qkv=2, + use_fp16=use_fp16_decoding, + restore_data=True, + permutation=permutation) + + params["word_emb"].append((self.model.transformer.wte, "weight")) + params["decoder_ln_weight"].append( + (self.model.transformer.ln_f, "weight")) + params["decoder_ln_bias"].append((self.model.transformer.ln_f, "bias")) + params["linear_weight"].append((self.model.lm_head.weight.t(), + partial(setattr, self, + "linear_weight_out"))) + params["linear_bias"].append((self.model.lm_head, "bias")) + + for k, v in params.items(): + setattr(self, k, v) + logger.info("Already converted model weights.") + + def forward(self, + input_ids, + mem_seq_len, + attention_mask=None, + topk=4, + topp=0.0, + bos_token_id=None, + eos_token_id=None, + pad_token_id=None, + forced_eos_token_id=None, + max_out_len=256, + temperature=1, + repetition_penalty=1.0, + min_length=0): + if attention_mask is None: + batch_size, input_length = paddle.shape(input_ids) + attention_mask = paddle.unsqueeze( + (input_ids != pad_token_id).astype("float32"), axis=[1]) + causal_mask = paddle.tril( + paddle.ones([batch_size, input_length, input_length], + dtype="float32")) + attention_mask = paddle.logical_and(attention_mask, causal_mask) + if not self.use_fp16_decoding: + attention_mask = paddle.cast(attention_mask, dtype="float32") + else: + attention_mask = paddle.cast(attention_mask, dtype="float16") + + if self.use_fp16_decoding and attention_mask.dtype == paddle.float32: + attention_mask = paddle.cast(attention_mask, dtype="float16") + + output_ids, = infer_gptj_decoding( + input=[input_ids], + attn_mask=[attention_mask], + mem_seq_len=[mem_seq_len], + word_emb=self.word_emb, + slf_ln_weight=self.slf_ln_weight, + slf_ln_bias=self.slf_ln_bias, + slf_q_weight=self.slf_q_weight, + slf_out_weight=self.slf_out_weight, + ffn_inter_weight=self.ffn_inter_weight, + ffn_inter_bias=self.ffn_inter_bias, + ffn_out_weight=self.ffn_out_weight, + ffn_out_bias=self.ffn_out_bias, + decoder_ln_weight=self.decoder_ln_weight, + decoder_ln_bias=self.decoder_ln_bias, + linear_weight=self.linear_weight, + linear_bias=self.linear_bias, + topk=topk, + topp=topp, + max_out_len=max_out_len, + head_num=self.head_num, + size_per_head=self.size_per_head, + num_layer=self.num_layer, + bos_id=bos_token_id, + eos_id=eos_token_id, + temperature=temperature, + rotary_embedding_dim=self.rotary_embedding_dim, + repetition_penalty=repetition_penalty, + min_length=min_length, + use_fp16_decoding=self.use_fp16_decoding) + + output_ids = output_ids[paddle.shape(input_ids)[-1]:, :] + if forced_eos_token_id is not None: + output_ids[:, -1] = forced_eos_token_id + return output_ids diff --git a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py index 66dce6a0f195..5b9d7fae1f1d 100644 --- a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py +++ b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py @@ -24,7 +24,8 @@ InferTransformerModel, GPTModel) from paddlenlp.ops import (InferTransformerDecoding, InferGptDecoding, InferUnifiedDecoding, InferBartDecoding, - InferMBartDecoding, InferOptDecoding) + InferMBartDecoding, InferOptDecoding, + InferGptJDecoding) from .encoder 
import enable_faster_encoder, disable_faster_encoder from paddlenlp.ops.ext_utils import load @@ -33,7 +34,8 @@ UnifiedTransformerPretrainedModel, UNIMOPretrainedModel, BartPretrainedModel, GPTPretrainedModel, MBartPretrainedModel, - OPTPretrainedModel) + OPTPretrainedModel, GPTJPretrainedModel, + CodeGenPreTrainedModel) class FasterTransformer(TransformerModel): @@ -1481,3 +1483,140 @@ def forward(self, early_stopping=early_stopping) generate = forward + + +class FasterGPTJ(GPTJPretrainedModel): + + def __init__(self, model, decoding_lib=None, use_fp16_decoding=False): + super(FasterGPTJ, self).__init__() + self._model = model + self.use_fp16_decoding = use_fp16_decoding + self.decoding = InferGptJDecoding(model=model, + decoding_lib=decoding_lib, + use_fp16_decoding=use_fp16_decoding) + + def forward(self, + input_ids, + seq_len=None, + attention_mask=None, + top_k=4, + top_p=0.0, + min_length=0, + max_length=256, + bos_token_id=None, + eos_token_id=None, + pad_token_id=None, + forced_eos_token_id=None, + temperature=0, + repetition_penalty=1.0, + decode_strategy="sampling", + num_return_sequences=1, + **model_kwargs): + if input_ids.dtype == paddle.int64: + input_ids = paddle.cast(input_ids, "int32") + + # change top_p to zero if not using top_p sampling for FT + if decode_strategy == "greedy_search": + top_p = 0.0 + top_k = 1 + if top_p == 1.0: + top_p = 0.0 + if seq_len is None: + seq_len = paddle.sum(paddle.cast(input_ids != pad_token_id, + dtype="int32"), + axis=-1, + dtype="int32") + + if num_return_sequences > 1: + input_ids, model_kwargs = self.expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + seq_len=seq_len, + attention_mask=attention_mask) + seq_len = model_kwargs["seq_len"] + attention_mask = model_kwargs.get("attention_mask", None) + + return self.decoding(input_ids, + mem_seq_len=seq_len, + attention_mask=attention_mask, + topk=top_k, + topp=top_p, + max_out_len=max_length, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + forced_eos_token_id=forced_eos_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + min_length=min_length) + + generate = forward + + +class FasterCodeGen(CodeGenPreTrainedModel): + + def __init__(self, model, decoding_lib=None, use_fp16_decoding=False): + super(FasterCodeGen, self).__init__() + self._model = model + self.use_fp16_decoding = use_fp16_decoding + self.decoding = InferGptJDecoding(model=model, + decoding_lib=decoding_lib, + use_fp16_decoding=use_fp16_decoding, + transpose_qkv=True) + + def forward(self, + input_ids, + seq_len=None, + attention_mask=None, + top_k=4, + top_p=0.0, + min_length=0, + max_length=256, + bos_token_id=None, + eos_token_id=None, + pad_token_id=None, + forced_eos_token_id=None, + temperature=0, + repetition_penalty=1.0, + decode_strategy="sampling", + num_return_sequences=1, + **model_kwargs): + if input_ids.dtype == paddle.int64: + input_ids = paddle.cast(input_ids, "int32") + + # change top_p to zero if not using top_p sampling for FT + if decode_strategy == "greedy_search": + top_p = 0.0 + top_k = 1 + if top_p == 1.0: + top_p = 0.0 + if seq_len is None: + seq_len = paddle.sum(paddle.cast(input_ids != pad_token_id, + dtype="int32"), + axis=-1, + dtype="int32") + + if num_return_sequences > 1: + input_ids, model_kwargs = self.expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + seq_len=seq_len, + attention_mask=attention_mask) + seq_len = model_kwargs["seq_len"] + attention_mask = 
model_kwargs.get("attention_mask", None) + + return self.decoding(input_ids, + mem_seq_len=seq_len, + attention_mask=attention_mask, + topk=top_k, + topp=top_p, + max_out_len=max_length, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + forced_eos_token_id=forced_eos_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + min_length=min_length) + + generate = forward diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu new file mode 100644 index 000000000000..32c824ef01b0 --- /dev/null +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cu @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fastertransformer/cuda/masked_multihead_attention_utils.h" +namespace fastertransformer +{ + +template +struct Vec_t {}; +template<> +struct Vec_t { + using Type = float2; +}; +template<> +struct Vec_t { + using Type = uint32_t; +}; + +#ifdef ENABLE_BF16 +template<> +struct Vec_t<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; +#endif + + +template +__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, + T* k_buf, + T* v_buf, + const T* __restrict QKV, + const T* __restrict qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim) +{ + using Vec_t = typename Vec_t::Type; + const int batch_idx = blockIdx.z; + const int head_idx = blockIdx.y; + const int seq_idx = blockIdx.x; + const int tidx = threadIdx.x; + if (tidx * 2 >= size_per_head) { + return; + } + + const int batch_time_idx = seq_len * batch_idx + seq_idx; + const int hidden_idx = head_idx * size_per_head + tidx * 2; + const int n = head_num * size_per_head; + + // src QKV: [batch, time, 3, head, hidden] + const int q_idx = batch_time_idx * 3 * n + hidden_idx; + const int k_idx = batch_time_idx * 3 * n + hidden_idx + n; + const int v_idx = batch_time_idx * 3 * n + hidden_idx + 2 * n; + + Vec_t q = *reinterpret_cast(&QKV[q_idx]); + Vec_t k = *reinterpret_cast(&QKV[k_idx]); + Vec_t v = *reinterpret_cast(&QKV[v_idx]); + + if(qkv_bias != nullptr){ + // qkv_bias: [3, head, hidden] + Vec_t q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); + Vec_t k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + n]); + Vec_t v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + 2 * n]); + + q = mmha::add(q, q_bias); + k = mmha::add(k, k_bias); + v = mmha::add(v, v_bias); + } + + mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, seq_idx); + + // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] + const int dest_idx = size_per_head * seq_len * head_num * batch_idx + size_per_head * seq_len * head_idx + + size_per_head * seq_idx + tidx * 2; + + 
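+  // Each thread stores one 2-element vector per Q/K/V at dest_idx, i.e. into
+  // the transposed [batch, head_num, seq_len, size_per_head] buffers.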
*reinterpret_cast(&q_buf[dest_idx]) = q; + *reinterpret_cast(&k_buf[dest_idx]) = k; + *reinterpret_cast(&v_buf[dest_idx]) = v; +} + +template +void add_fusedQKV_bias_transpose_kernelLauncher( + T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const T* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + cudaStream_t stream) +{ + if (rotary_embedding_dim == 0) { + const int m = batch_size * seq_len; + const int n = head_num * size_per_head; + dim3 block(384); + dim3 grid((int)(ceil(1.0 * m * n / 384))); + add_fusedQKV_bias_transpose_kernel<<>>( + q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); + } + else { + // To implement rotary embeddings, each thread processes two QKV elems: + dim3 block((size_per_head / 2 + 31) / 32 * 32); + dim3 grid(seq_len, head_num, batch_size); + add_fusedQKV_bias_transpose_kernel<<>>( + q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); + } +} + +template void add_fusedQKV_bias_transpose_kernelLauncher( + float* q_buf, + float* k_buf, + float* v_buf, + float* QKV, + const float* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + cudaStream_t stream); + +template void add_fusedQKV_bias_transpose_kernelLauncher( + half* q_buf, + half* k_buf, + half* v_buf, + half* QKV, + const half* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + cudaStream_t stream); + +} // namespace fastertransformer \ No newline at end of file diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh new file mode 100644 index 000000000000..0c1fd232994a --- /dev/null +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/attention_kernels.cuh @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +namespace fastertransformer +{ + +template +void add_fusedQKV_bias_transpose_kernelLauncher( + T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const T* qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int rotary_embedding_dim, + cudaStream_t stream); + + +} // namespace fastertransformer \ No newline at end of file diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h index 409919775c3f..8149895a9e7f 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/cuda_kernels.h @@ -150,6 +150,31 @@ void apply_logits_mask_kernelLauncher(T* log_probs, cudaStream_t stream, const T* logits_mask = nullptr, const bool min_penalty = false, - const int end_id = -1); + const int end_id = -1, + const T* bias = nullptr); + +template +void gptj_start_id_embedding_lookups_kernel_launcher(T* from_tensor, + int* output_ids, + const T* embedding_table, + const int* word_ids, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + +template +void gpj_embedding_lookups_kernel_launcher(T* from_tensor, + const T* embedding_table, + const int* word_ids, + const int local_batch_size, + const int batch_size, + const int hidden_units, + int step, + int ite, + int max_input_len, + const int* start_lengths, + cudaStream_t stream); } // namespace fastertransformer diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu index 48d32a44a8ae..23c3aaf3a160 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/decoding_kernels.cu @@ -340,7 +340,8 @@ __global__ void apply_logits_mask_kernel(int vocab_size_padded, const bool* finished, const T* logits_mask = nullptr, const bool min_penalty = false, - const int end_id = -1) { + const int end_id = -1, + const T* bias = nullptr) { int tid = threadIdx.x; int bid = blockIdx.x; int bbid = blockIdx.y; // batch_size * beam_size: index @@ -355,6 +356,8 @@ __global__ void apply_logits_mask_kernel(int vocab_size_padded, log_probs[i + bbid * vocab_size_padded] += -MAX_T_VAL; } else if (logits_mask) { log_probs[i + bbid * vocab_size_padded] += logits_mask[i]; + } else if (bias) { + log_probs[i + bbid * vocab_size_padded] += bias[i]; } else { continue; } @@ -372,8 +375,9 @@ void apply_logits_mask_kernelLauncher(T* log_probs, cudaStream_t stream, const T* logits_mask, const bool min_penalty, - const int end_id) { - if (logits_mask == nullptr && !min_penalty) return; + const int end_id, + const T* bias) { + if (logits_mask == nullptr && !min_penalty && bias == nullptr) return; dim3 block(256); dim3 grid((vocab_size_padded + block.x - 1) / block.x, @@ -386,7 +390,122 @@ void apply_logits_mask_kernelLauncher(T* log_probs, finished, logits_mask, min_penalty, - end_id); + end_id, + bias); +} + + + template __launch_bounds__(1024, 1) + __global__ void gptj_start_id_embedding_lookups_kernel(T* from_tensor, + int* output_ids, + const T* embedding_table, + const int* word_ids, + const int length, + const int max_length, + const int batch_size, + const int hidden_units) + { + for(int index = blockIdx.x * blockDim.x + 
threadIdx.x; index < batch_size * length * hidden_units; index += blockDim.x * gridDim.x) + { + // transpose the word_ids [batch, length] (part of [batch, max_length]) to output_ids [length, batch] + if(index < batch_size * max_length) + { + const int seq_id = index % max_length; + const int batch_id = index / max_length; + if(seq_id < length) + output_ids[seq_id * batch_size + batch_id] = word_ids[index]; + // output_ids[index] = word_ids[index]; + } + + // embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate embedding [batch, length, hidden] + const int word_index = index / hidden_units; + const int word_index_row = word_index / length; + const int word_index_col = word_index % length; + const int real_word_index = word_index_row * max_length + word_index_col; + const int col_index = index % hidden_units; + from_tensor[index] = embedding_table[word_ids[real_word_index] * hidden_units + col_index]; + } + } + + + template + void gptj_start_id_embedding_lookups_kernel_launcher(T* from_tensor, + int *output_ids, + const T* embedding_table, + const int* word_ids, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream) + { + dim3 grid(min(batch_size * length, 65536)); + dim3 block(min(hidden_units, 1024)); + gptj_start_id_embedding_lookups_kernel<<>>(from_tensor, + output_ids, + embedding_table, + word_ids, + length, + max_length, + batch_size, + hidden_units); + } + + + // TODO Add half2 implementation +template +__global__ void gptj_embedding_lookups_kernel( + T* from_tensor, + const T* embedding_table, + const int* word_ids, + const int local_batch_size, + const int batch_size, + const int hidden_units, + int step, + int ite, + int max_input_len, + const int* start_lengths) { + int timestep = step - 1; + // if the input is padded in the batch, indices of the word_id + // should be shifted forward by the length of the padding. + int len_padding = + max_input_len - start_lengths[local_batch_size * ite + blockIdx.x]; + int idx_word_id = (step == max_input_len) ? 
timestep - len_padding : timestep; + + int* word_ids_buf = + (int*)word_ids + idx_word_id * batch_size + local_batch_size * ite; + T* from_tensor_buf = from_tensor + blockIdx.x * hidden_units; + for (int index = threadIdx.x; index < hidden_units; index += blockDim.x) { + from_tensor_buf[index] = + embedding_table[word_ids_buf[blockIdx.x] * hidden_units + index]; + } +} + +template +void gpj_embedding_lookups_kernel_launcher(T* from_tensor, + const T* embedding_table, + const int* word_ids, + const int local_batch_size, + const int batch_size, + const int hidden_units, + int step, + int ite, + int max_input_len, + const int* start_lengths, + cudaStream_t stream) { + dim3 grid(min(local_batch_size, 65536)); + dim3 block(min(hidden_units, 1024)); + gptj_embedding_lookups_kernel + <<>>(from_tensor, + embedding_table, + word_ids, + local_batch_size, + batch_size, + hidden_units, + step, + ite, + max_input_len, + start_lengths); } template void init_kernelLauncher_v2(bool* finished, @@ -527,7 +646,8 @@ template void apply_logits_mask_kernelLauncher( cudaStream_t stream, const float* logits_mask, const bool min_penalty, - const int end_id); + const int end_id, + const float* bias); template void apply_logits_mask_kernelLauncher( half* log_probs, @@ -539,6 +659,55 @@ template void apply_logits_mask_kernelLauncher( cudaStream_t stream, const half* logits_mask, const bool min_penalty, - const int end_id); + const int end_id, + const half* bias); + + template + void gptj_start_id_embedding_lookups_kernel_launcher(float* from_tensor, + int* output_ids, + const float* embedding_table, + const int* word_ids, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + + template + void gptj_start_id_embedding_lookups_kernel_launcher(half* from_tensor, + int* output_ids, + const half* embedding_table, + const int* word_ids, + const int length, + const int max_length, + const int batch_size, + const int hidden_units, + cudaStream_t stream); + + template void gpj_embedding_lookups_kernel_launcher( + float* from_tensor, + const float* embedding_table, + const int* word_ids, + const int local_batch_size, + const int batch_size, + const int hidden_units, + int step, + int ite, + int max_input_len, + const int* start_lengths, + cudaStream_t stream); + +template void gpj_embedding_lookups_kernel_launcher( + half* from_tensor, + const half* embedding_table, + const int* word_ids, + const int local_batch_size, + const int batch_size, + const int hidden_units, + int step, + int ite, + int max_input_len, + const int* start_lengths, + cudaStream_t stream); } // end of name space fastertransformer diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu index 00b15c2e5a6f..aa03d980d2f6 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.cu @@ -1,4 +1,5 @@ /*************************************************************************************************** + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, are not permit- @@ -80,9 +81,11 @@ struct Qk_vec_ {}; template<> struct Qk_vec_ { using Type = float; }; template<> struct Qk_vec_ { using Type = float2; }; template<> struct Qk_vec_ { using Type = float4; }; +template<> struct Qk_vec_ { using Type = float4; }; template<> struct Qk_vec_ { using Type = uint32_t; }; template<> struct Qk_vec_ { using Type = uint32_t; }; template<> struct Qk_vec_ { using Type = uint2; }; +template<> struct Qk_vec_ { using Type = uint4; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -430,7 +433,7 @@ inline size_t smem_size_in_bytes( // The number of partial rows to reduce in the final reduction. // int rows_per_red = threads_per_block / threads_per_value; // to solve `threads_per_block / threads_per_value` is not 2^n - int rows_per_red = pad_active_groups; + int rows_per_red = params.rotary_embedding_dim>0 ? threads_per_block / threads_per_value: pad_active_groups; // The amount of storage needed to finalize the outputs. size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; @@ -438,6 +441,7 @@ inline size_t smem_size_in_bytes( return max(softmax_sz, red_sz); } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ constexpr uint32_t shfl_mask(int threads) { @@ -500,7 +504,7 @@ __global__ void masked_multihead_attention_kernel(Masked_multihead_attention_par // The number of elements per vector. constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, ""); + static_assert(Dh % QK_VEC_SIZE == 0, ""); // The number of vectors per warp. constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE; @@ -837,6 +841,568 @@ __global__ void masked_multihead_attention_kernel(Masked_multihead_attention_par } } +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK + > +__global__ void gptj_masked_multihead_attention_kernel(Masked_multihead_attention_params params, int pad_active_groups) +{ + + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; + + // DO_CROSS_ATTENTION = false + constexpr bool DO_CROSS_ATTENTION = false; + +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TODO - cahnge to tlength + logits_smem_ += + (DO_CROSS_ATTENTION) ? 
div_up(params.seq_length + 1, 4) * 16 : div_up(params.timestep + 1, 4) * 16; + } + T* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + T* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + + // Use alignment for safely casting the shared buffers as Qk_vec. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + + // This is one of the reasons we should have a separate kernel for cross attention + __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + // The head. + const int hi = blockIdx.x; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + // The thread in the block. + const int tidx = threadIdx.x; + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + + // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : + (params.length_per_sample == nullptr) ? params.timestep : + params.length_per_sample[bi]; + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + if (tidx < QK_VECS_PER_WARP) { + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + + // Trigger the loads from the Q and K buffers. + Qk_vec q; + zero(q); + q = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? 
*reinterpret_cast(¶ms.q[qk_offset]) : q; + Qk_vec k; + zero(k); + if (DO_CROSS_ATTENTION) { + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.seq_length * Dh + co * params.seq_length * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; + k = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast(¶ms.k_cache[offset]) : + k; + } + else { + k = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? *reinterpret_cast(¶ms.k[qk_offset]) : k; + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec q_bias; + zero(q_bias); + q_bias = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + *reinterpret_cast(¶ms.q_bias[qk_bias_offset]) : + q_bias; + Qk_vec k_bias; + zero(k_bias); + + if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { + k_bias = (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? + *reinterpret_cast(¶ms.k_bias[qk_bias_offset]) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { + k = add(k, k_bias); + if (params.rotary_embedding_dim > 0) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep); + } + } + else { + if (params.rotary_embedding_dim > 0) { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep); + } + } + + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Store Dh values of k_bias into smem, since will need to add later + // if params.timestep == 0 + if (DO_CROSS_ATTENTION && params.timestep == 0) { + *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; + } + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.seq_length * Dh + co * params.seq_length * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; + + if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = k; + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. 
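+        // Each of the first QK_VECS_PER_WARP threads holds a partial dot
+        // product for the current token; the shuffle reduction below sums
+        // them within the warp, and a block-wide reduction follows when the
+        // head size spans more than one warp.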
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + + if (params.relative_attention_bias_float != nullptr) { + qk = qk + + params.relative_attention_bias_float[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + tlength]; + } + else if (params.relative_attention_bias_half != nullptr) { + qk = qk + + (float) + params.relative_attention_bias_half[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + tlength]; + } + qk_max = qk; + qk_smem[tlength] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec = typename K_vec_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec q[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + K_vec k_bias[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; + if (DO_CROSS_ATTENTION && params.timestep == 0) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + k_bias[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. + T* k_cache = ¶ms.k_cache[bhi * params.seq_length * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbhi * params.seq_length * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength, K_PER_WARP) * K_PER_WARP; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. 
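+    // Keys come from the [B, H, Dh/x, L, x] cache; THREADS_PER_KEY threads
+    // cooperate on each timestep, and cache_indir redirects the reads to the
+    // right beam when beam search is used.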
+ for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + + // The keys loaded from the key cache. + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.seq_length + ti; + // if( ti < params.timestep ) { + if (ti < tlength) { + const int beam_src = + (params.cache_indir != nullptr) ? + params.cache_indir[(bbi * params.beam_width + beami) * params.seq_length + ti] : + 0; + const int beam_offset = beam_src * params.num_heads * params.seq_length * Dh; + k[ii] = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.seq_length) ? + *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) : + k_vec_zero; + // add bias and update k_cache + if (DO_CROSS_ATTENTION && params.timestep == 0) { + k[ii] = add(k[ii], k_bias[ii]); + if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.seq_length) { + *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q, k) * params.inv_sqrt_dh; + bool is_mask = (params.input_lengths != nullptr && ti >= params.input_lengths[bi] && ti < params.max_input_len); + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias_float != nullptr) { + qk = qk + + params.relative_attention_bias_float[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]; + } + else if (params.relative_attention_bias_half != nullptr) { + qk = qk + + (float) + params.relative_attention_bias_half[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]; + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.input_lengths != nullptr && ti >= params.input_lengths[bi] && ti < params.max_input_len); + float logit = is_mask ? 
0.f : __expf(qk_smem[ti] - qk_max); + sum += logit; + qk_smem[ti] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec = typename V_vec_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + + // The base pointer for the value in the cache buffer. + T* v_cache = ¶ms.v_cache[bhi * params.seq_length * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbhi * params.seq_length * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = *reinterpret_cast(¶ms.v_bias[hi * Dh + vi]); + } + if (DO_CROSS_ATTENTION) { + *reinterpret_cast(&bias_smem[vi]) = v_bias; + } + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. + __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = vo; ti < tlength; ti += V_PER_ITER) { + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = (params.cache_indir != nullptr) ? + params.cache_indir[(bbi * params.beam_width + beami) * params.seq_length + ti] : + 0; + const int beam_offset = beam_src * params.num_heads * params.seq_length * Dh; + // Load the values from the cache. + V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh]); + if (DO_CROSS_ATTENTION && params.timestep == 0) { + v = add(v, *reinterpret_cast(&bias_smem[vi])); + *reinterpret_cast(&v_cache[ti * Dh]) = v; + } + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti]; + + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + } + + // One group of threads computes the product(s) for the current timestep. 
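+    // The group owning row (tlength % V_PER_ITER) also loads the value
+    // vector of the current step, adds the bias when present, appends it to
+    // the V cache and folds it into the running output accumulator.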
+ // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec v; + if (DO_CROSS_ATTENTION) { + v = *reinterpret_cast(&v_cache[tlength * Dh]); + } + else { + // Trigger the loads from the V buffer. + v = *reinterpret_cast(¶ms.v[qkv_base_offset + vi]); + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + } + + // Compute the V values with bias. + if (!DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0)) { + v = add(v, v_bias); + + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength * Dh]) = v; + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength], cast_to_float(v), out); +#else + // out = fma(logits_smem[params.timestep], v, out); + out = fma(logits_smem[tlength], v, out); +#endif + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + + // Output the final values. 
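+  // After the tree reduction above, the thread group with vo == 0 holds the complete sum over all
+  // timesteps of logits[ti] * V[ti]; it writes the attention output for this (batch, head) slice
+  // to params.out.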
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); +#else + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; +#endif + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace mmha @@ -846,21 +1412,40 @@ __global__ void masked_multihead_attention_kernel(Masked_multihead_attention_par #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ int pad_active_groups = 1 << static_cast(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK, pad_active_groups); \ - dim3 grid(params.num_heads, params.batch_size); \ + dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel \ <<>>(params, pad_active_groups) + +#define GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + int pad_active_groups = 1 << static_cast(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK, pad_active_groups); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::gptj_masked_multihead_attention_kernel \ + <<>>(params, pad_active_groups) + //////////////////////////////////////////////////////////////////////////////////////////////////// template < typename T, int Dh, int Dh_MAX> void mmha_launch_kernel(const Masked_multihead_attention_params ¶ms, const cudaStream_t &stream) { - constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16; - if( params.timestep < 32 ) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); - } else if( params.timestep < 2048 ) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); - } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); + if(params.rotary_embedding_dim>0){ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + if( params.timestep < 32 ) { + GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); + } else if( params.timestep < 2048 ) { + GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); + } else { + GPTJ_MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); + } + }else{ + constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16; + if( params.timestep < 32 ) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); + } else if( params.timestep < 2048 ) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); + } else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); + } } } @@ -884,6 +1469,18 @@ void masked_multihead_attention_(const Masked_multihead_attention_params &par case 128: mmha_launch_kernel(params, stream); break; + case 160: + mmha_launch_kernel(params, stream); + break; + case 192: + mmha_launch_kernel(params, stream); + break; + case 224: + mmha_launch_kernel(params, stream); + break; + case 256: // GPTJ/CodeGen + mmha_launch_kernel(params, stream); + break; default: // assert(false); throw std::runtime_error("Unsupported model size."); diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h new file mode 100644 index 000000000000..992a598536ee --- /dev/null +++ 
b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) do { \ + cudaError_t status_ = call; \ + if( status_ != cudaSuccess ) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ +} while(0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + +template< typename T > +struct Masked_multihead_attention_params { + + // The output buffer. Dimensions B x D. + T *out; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q, *q_bias; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k, *k_bias; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v, *v_bias; + + // The cache for the Ks. The size must be at least B x L x D. + T *k_cache; + // The cache for the Vs. The size must be at least B x L x D. + T *v_cache; + + // The indirections to use for cache when beam sampling. + const int* cache_indir = nullptr; + + // allows to exist attention eary + bool *finished; + + // Stride to handle the case when KQV is a single buffer + int stride; + + // The batch size. + int batch_size; + // The sequence length. + int seq_length; + // The number of heads (H). + int num_heads; + // The hidden dimension per head (Dh). + int hidden_size_per_head; + // The current timestep. + int timestep; + + // The per-head latent space reserved for rotary embeddings. + int rotary_embedding_dim = 0; + + // The 1.f / sqrt(Dh). Computed on the host. + float inv_sqrt_dh; + + // params for masking. 
+  bool is_mask;
+  const int *input_lengths = nullptr;
+  int max_input_len = 0;
+
+  const float* relative_attention_bias_float = nullptr;
+  const half* relative_attention_bias_half = nullptr;
+  int relative_attention_bias_stride;
+  // The beam width
+  int beam_width = 1;
+  // required in case of cross attention
+  int* memory_length_per_sample = nullptr;
+  // required in case of masked attention with different length
+  const int* length_per_sample = nullptr;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_multihead_attention(const Masked_multihead_attention_params<float> &params, const cudaStream_t &stream);
+void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t> &params, const cudaStream_t &stream);
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h
new file mode 100644
index 000000000000..cdcc47a06180
--- /dev/null
+++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/masked_multihead_attention_utils.h
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +namespace mmha { + +inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step) +{ + const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim); + return {cos(inv_freq), sin(inv_freq)}; +} + +inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) +{ + float2 rot_v; + rot_v.x = coef.x * v.x - coef.y * v.y; + rot_v.y = coef.x * v.y + coef.y * v.x; + return rot_v; +} + +inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) +{ + float2 fv = half2_to_float2(v); + float2 rot_fv = rotary_embedding_transform(fv, coef); + return float2_to_half2(rot_fv); +} + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) +{ + float2 fv = bf1622float2(v); + float2 rot_fv = rotary_embedding_transform(fv, coef); + return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); +} +#endif + +inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step) +{ + return; +} + +inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step) +{ + return; +} + +inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + + Float4_& q_ = *reinterpret_cast(&q); + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); + q_.x = rotary_embedding_transform(q_.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); + q_.y = rotary_embedding_transform(q_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + + Float4_& q_ = *reinterpret_cast(&q); + Float4_& k_ = *reinterpret_cast(&k); + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); + q_.x = rotary_embedding_transform(q_.x, coef0); + k_.x = rotary_embedding_transform(k_.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); + q_.y = rotary_embedding_transform(q_.y, coef1); + k_.y = rotary_embedding_transform(k_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void 
apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} + +#ifdef ENABLE_BF16 +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void +apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); + q.y = 
rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} +#endif // ENABLE_BF16 + +} // namespace mmha \ No newline at end of file diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h index b2d657a18610..6702087e9b11 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_attention.h @@ -1,4 +1,5 @@ /* + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu index 1d2aa51410e0..6245cdd3e46f 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cu @@ -159,4 +159,96 @@ template void transpose_general_kernelLauncher(half* dst, const int head_num, const int size_per_head, cudaStream_t stream); + + + +template +void fusedQKV_masked_attention_dispatch_v2( + const T* qkv_buf, const T* qkv_bias, + T* key_cache, T* value_cache, + T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size, + int head_num, int size_per_head, const int step, const int max_seq_len, + const int max_input_len, const int* input_lengths, const int rotary_embedding_dim, cudaStream_t stream) +{ + using DataType = typename std::conditional::type; + // Prepare the parameters. + Masked_multihead_attention_params params; + memset(¶ms, 0, sizeof(params)); + int hidden_units = head_num * size_per_head; + if (qkv_bias != nullptr) { + params.q_bias = reinterpret_cast(qkv_bias); + params.k_bias = reinterpret_cast(qkv_bias) + hidden_units; + params.v_bias = reinterpret_cast(qkv_bias) + 2 * hidden_units; + } + else { + // gptj/codegen no bias + params.q_bias = nullptr; + params.k_bias = nullptr; + params.v_bias = nullptr; + } + + // Set the output buffer. + params.out = reinterpret_cast(context_buf); + + // Set the input buffers. + params.q = reinterpret_cast(qkv_buf); + params.k = reinterpret_cast(qkv_buf) + hidden_units; + params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; + params.stride = 3 * hidden_units; + params.finished = const_cast(finished); + + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); + params.batch_size = inference_batch_size; + params.seq_length = max_seq_len; + params.timestep = step-1; + params.num_heads = head_num; + params.hidden_size_per_head = size_per_head; + // GptJ: rotary_embedding + params.rotary_embedding_dim = rotary_embedding_dim; + params.inv_sqrt_dh = 1.F / sqrtf((float) params.hidden_size_per_head); + + params.is_mask = true; + params.input_lengths = input_lengths; + params.max_input_len = max_input_len; + + masked_multihead_attention(params, stream); +} + +template void fusedQKV_masked_attention_dispatch_v2( + const float* qkv_buf, + const float* qkv_bias, + float* key_cache, + float* value_cache, + float* context_buf, + const bool* finished, + int max_batch_size, + int inference_batch_size, + int head_num, + int size_per_head, + const int step, + const int max_seq_len, + const int max_input_len, + const int* input_lengths, + const int rotary_embedding_dim, + cudaStream_t stream); + +template void fusedQKV_masked_attention_dispatch_v2( + const half* qkv_buf, + const half* qkv_bias, + half* key_cache, + half* value_cache, + half* context_buf, + const bool* finished, + int max_batch_size, + int inference_batch_size, + int head_num, + int size_per_head, + const int step, + const int max_seq_len, + const int max_input_len, + const int* input_lengths, + const int rotary_embedding_dim, + cudaStream_t stream); + } diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh index cabc89591971..624f36235c4b 100644 --- 
a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/open_decoder.cuh @@ -46,4 +46,12 @@ __global__ void transpose(T* src, const int seq_len, const int head_num, const int size_per_head); + +template +void fusedQKV_masked_attention_dispatch_v2( + const T* qkv_buf, const T* qkv_bias, + T* key_cache, T* value_cache, + T* context_buf, const bool* finished, int max_batch_size, int inference_batch_size, + int head_num, int size_per_head, const int step, const int max_seq_len, + const int max_input_len, const int* input_lengths, const int rotary_embedding_dim, cudaStream_t stream); } diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu index bfd680a6b24f..2f364f9c880d 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/topk_kernels.cu @@ -35,13 +35,28 @@ __global__ void ker_curand_setup(curandState_t* state, &state[blockIdx.x * blockDim.x + threadIdx.x]); } +__global__ void ker_curand_setup_bsz_one(curandState_t* state, + const int size, + const int seed) { + if (threadIdx.x + blockIdx.x * blockDim.x < size) + curand_init(seed, + 0, + seed, + &state[blockIdx.x * blockDim.x + threadIdx.x]); +} + void ker_curand_setupLauncher(curandState_t* state, DecodingSamplingArguments args, cudaStream_t stream) { dim3 block(256); dim3 grid((int)(ceil(args.batch_size_ * 1.0 / 256))); int seed = args.seed_ != -1 ? args.seed_ : clock() % INT_MAX; - ker_curand_setup<<>>(state, args.batch_size_, seed); + if(args.batch_size_ != 1) + ker_curand_setup<<>>(state, args.batch_size_, seed); + else + // Reduce the huge occupation of gpu memory due to curand_init func when bsz=1. + // TODO(gongenlei): Solve above problem when bsz > 1. + ker_curand_setup_bsz_one<<>>(state, args.batch_size_, seed); } diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu index 46a5efc9f57d..365b2fc0b388 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/cuda/transformer_kernels.cu @@ -404,11 +404,17 @@ __global__ void add_bias_input_layernorm_2(const T* __restrict input, for (int i = tid; i < n; i += blockDim.x) { float local_out = (float)(__ldg(&input[blockIdx.x * n + i])); local_out += (float)(output[blockIdx.x * n + i]); - local_out += (float)(__ldg(&bias[i])); + if(bias != nullptr){ + local_out += (float)(__ldg(&bias[i])); + } output[blockIdx.x * n + i] = (T)local_out; local_sum += local_out; } + if (gamma == nullptr || beta == nullptr){ + return; + } + mean = blockReduceSum(local_sum); if (threadIdx.x == 0) s_mean = mean / n; diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/gptj.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/gptj.h new file mode 100644 index 000000000000..208fb4663b7f --- /dev/null +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/gptj.h @@ -0,0 +1,946 @@ +/* + * Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Decoder transformer + **/ + +#pragma once + +#include "fastertransformer/utils/common.h" +#include "fastertransformer/utils/functions.h" +#include "fastertransformer/utils/allocator.h" +#include "fastertransformer/utils/arguments.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/open_decoder.h" +#include +#include +#include "fastertransformer/utils/nvtx_utils.h" + +namespace fastertransformer +{ + +template +class DecodingGptJ +{ +private: + typedef DecoderTransformerTraits Traits_; + typedef typename Traits_::DataType DataType_; + const IAllocator &allocator_; + struct GptJArguments args_; + TensorParallelParam t_parallel_param_; + LayerParallelParam l_parallel_param_; + + const cudaDataType_t computeType_ = Traits_::computeType; + const cudaDataType_t AType_ = Traits_::AType; + const cudaDataType_t BType_ = Traits_::BType; + const cudaDataType_t CType_ = Traits_::CType; + std::map cublasAlgoMap_; + + DataType_ *embedding_kernel_padded_; + DataType_ *embedding_bias_padded_; + + OpenDecoder *decoder_; + DataType_ **K_cache_; + DataType_ **V_cache_; + DataType_ *from_tensor_[2]; + DataType_ *decoder_buf_; + DataType_ *decoder_normed_result_buf_; + DataType_ *logits_buf_; + void *buf_; + + void *topk_workspace_ = nullptr; + size_t topk_workspace_size_ = 0; + void *topp_workspace_ = nullptr; + size_t topp_workspace_size_ = 0; + void *topk_topp_workspace_ = nullptr; + size_t topk_topp_workspace_size_ = 0; + void *cublas_workspace_ = nullptr; + int *topp_id_vals_buf_; + int *topp_offset_buf_; + curandState_t *curandstate_buf_; + int *begin_topp_offset_buf_; + + size_t nccl_buf_size_; + DataType_ *nccl_logits_buf_; + + bool *finished_buf_; + bool *h_finished_buf_; + +public: + DecodingGptJ(const IAllocator &allocator, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int vocab_size, + const int decoder_layers, + const int start_id, + const int end_id, + const int candidate_num = 1, + const float probability_threshold = 0.0, + const float temperature = 1.0, + const int tensor_para_size = 1, + const int layer_para_size = 1, + const bool is_fuse_QKV = true, + const float repetition_penalty = 1.0, + const int seed = -1, + const int rotary_embedding_dim = 0, + const int min_length = 0) : allocator_(allocator) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + assert(temperature != 0.0); + assert(repetition_penalty > 0.0); + assert(candidate_num > 0 || probability_threshold > 0.0); + assert(decoder_layers % layer_para_size == 0); + + args_.batch_size_ = batch_size; + args_.seq_len_ = seq_len; + args_.head_num_ = head_num; + args_.size_per_head_ = size_per_head; + args_.hidden_units_ = head_num * size_per_head; + args_.decoder_layers_ = decoder_layers; + args_.vocab_size_ = vocab_size; + args_.start_id_ = start_id; + args_.end_id_ = end_id; + args_.candidate_num_ = candidate_num; + args_.probability_threshold_ = probability_threshold; + 
args_.temperature_ = temperature; + args_.repetition_penalty_ = repetition_penalty; + /***** newly added by PaddleNLP *****/ + args_.seed_ = seed; + args_.rotary_embedding_dim_ = rotary_embedding_dim; + args_.min_length_ = min_length; + + K_cache_ = new DataType_ *[1]; + V_cache_ = new DataType_ *[1]; + + decoder_ = new OpenDecoder(args_.head_num_, size_per_head, 0 /* memory_hidden_units */, is_fuse_QKV); + decoder_->set_max_batch_size(args_.batch_size_); + + + // args_.vocab_size_padded_ = div_up(args_.vocab_size_, 64) * 64; + if(std::is_same::value) + { + args_.vocab_size_padded_ = args_.vocab_size_; + } + else + { + args_.vocab_size_padded_ = div_up(args_.vocab_size_, 64) * 64; + } + + size_t from_tensor_size = args_.batch_size_ * args_.hidden_units_; // type T + size_t decoder_workspace_size = (size_t)decoder_->getWorkspaceSize(); // type T + size_t decoder_normed_result_buffer_size = args_.batch_size_ * args_.hidden_units_; // type T + // cache costs lots of memory, so we only store part of them when we use multi-gpu for inference + size_t cache_size = args_.batch_size_ * args_.seq_len_ * args_.hidden_units_ / tensor_para_size; // type T + size_t logits_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type T + + size_t topp_id_vals_buf_size = args_.batch_size_ * args_.vocab_size_padded_; // type int + size_t topp_offset_buf_size = args_.batch_size_ + 1; + size_t begin_topp_offset_buf_size = topp_offset_buf_size; + size_t curandState_size = args_.batch_size_; + size_t finished_buf_size = args_.batch_size_; + + const int MEM_C = 128; + size_t embedding_kernel_transposed_padded_size = args_.hidden_units_ * args_.vocab_size_padded_; + embedding_kernel_transposed_padded_size = div_up(embedding_kernel_transposed_padded_size, MEM_C) * MEM_C; + + + size_t padded_embedding_bias_size = args_.vocab_size_padded_; + if(std::is_same::value || (std::is_same::value && args_.vocab_size_padded_ == args_.vocab_size_)) + { + padded_embedding_bias_size = 0; + } + + // prevent memory misalinged address + logits_buf_size = (size_t)(ceil(logits_buf_size / 4.)) * 4; + + topp_id_vals_buf_size = (size_t)(ceil(topp_id_vals_buf_size / 4.)) * 4; + topp_offset_buf_size = (size_t)(ceil(topp_offset_buf_size / 4.)) * 4; + begin_topp_offset_buf_size = topp_offset_buf_size; + curandState_size = (size_t)(ceil(curandState_size / 32.)) * 32; + finished_buf_size = (size_t)(ceil(finished_buf_size / 32.)) * 32; + + topP_sampling_kernel_kernelLauncher_v2(topp_workspace_, + topp_workspace_size_, + logits_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + nullptr, + curandstate_buf_, + args_, + nullptr, + nullptr, + args_.vocab_size_padded_, + 0, + args_.batch_size_); + + topK_sampling_kernel_kernelLauncher_v2(topk_workspace_, + topk_workspace_size_, + logits_buf_, + nullptr, + nullptr, + nullptr, + curandstate_buf_, + args_, + 0, + args_.batch_size_); + + topK_topP_sampling_kernel_kernelLauncher_v2(topk_topp_workspace_, + topk_topp_workspace_size_, + nullptr, + logits_buf_, + nullptr, + curandstate_buf_, + args_, + 0, + args_.batch_size_); + + size_t datatype_buf_size = from_tensor_size * 2 + decoder_workspace_size + + cache_size * 2 * (args_.decoder_layers_ / layer_para_size) + decoder_normed_result_buffer_size; + + nccl_buf_size_ = args_.batch_size_ * args_.vocab_size_padded_; + nccl_buf_size_ = (size_t)(ceil(nccl_buf_size_ / 4.)) * 4; + + buf_ = reinterpret_cast(allocator_.malloc( + ((sizeof(DataType_) == sizeof(half)) ? 
CUBLAS_WORKSPACE_SIZE : 0) + + sizeof(DataType_) * embedding_kernel_transposed_padded_size + + sizeof(DataType_) * (datatype_buf_size + logits_buf_size) + + sizeof(DataType_) * (padded_embedding_bias_size) + + sizeof(int) * (topp_id_vals_buf_size + topp_offset_buf_size + begin_topp_offset_buf_size) + + topp_workspace_size_ + topk_workspace_size_ + topk_topp_workspace_size_ + sizeof(DataType_) * nccl_buf_size_ + + finished_buf_size + curandState_size * sizeof(curandState_t))); + + if (sizeof(DataType_) == sizeof(half)) + { + cublas_workspace_ = buf_; + embedding_kernel_padded_ = (DataType_ *)((char*)cublas_workspace_ + CUBLAS_WORKSPACE_SIZE); + } + else + { + cublas_workspace_ = nullptr; + embedding_kernel_padded_ = (DataType_ *)buf_; + } + embedding_bias_padded_ = (DataType_ *)(embedding_kernel_padded_ + embedding_kernel_transposed_padded_size); + from_tensor_[0] = (DataType_ *)(embedding_bias_padded_ + padded_embedding_bias_size); + from_tensor_[1] = (DataType_ *)(from_tensor_[0] + from_tensor_size); + + K_cache_[0] = from_tensor_[1] + from_tensor_size + 0 * cache_size * args_.decoder_layers_ / layer_para_size; + V_cache_[0] = from_tensor_[1] + from_tensor_size + 1 * cache_size * args_.decoder_layers_ / layer_para_size; + + decoder_buf_ = V_cache_[0] + cache_size * args_.decoder_layers_ / layer_para_size; + decoder_normed_result_buf_ = (decoder_buf_ + decoder_workspace_size); + logits_buf_ = decoder_normed_result_buf_ + decoder_normed_result_buffer_size; + topp_id_vals_buf_ = (int *)((DataType_*)logits_buf_ + logits_buf_size); + begin_topp_offset_buf_ = (int *)(topp_id_vals_buf_ + topp_id_vals_buf_size); + topp_offset_buf_ = (int *)((int*)begin_topp_offset_buf_ + begin_topp_offset_buf_size); + topp_workspace_ = (void *)((int*)topp_offset_buf_ + topp_offset_buf_size); + topk_workspace_ = (void *)((char*)topp_workspace_ + topp_workspace_size_); + topk_topp_workspace_ = (void *)((char*)topk_workspace_ + topk_workspace_size_); + nccl_logits_buf_ = (DataType_ *)((char*)topk_topp_workspace_ + topk_topp_workspace_size_); + curandstate_buf_ = (curandState_t*)(nccl_logits_buf_ + nccl_buf_size_); + finished_buf_ = (bool*)(curandstate_buf_ + curandState_size); + h_finished_buf_ = new bool[args_.batch_size_]; + + cudaMemset(embedding_kernel_padded_, 0, embedding_kernel_transposed_padded_size * sizeof(DataType_)); + + int isConfigExist = access("decoding_gemm_config.in", 0); + if (isConfigExist == -1) + printf("[WARNING] decoding_gemm_config.in is not found\n"); + else + { + readAlgoFromConfig(cublasAlgoMap_, 1); + // check that the gemm_config setting is runnable + for (auto iter = cublasAlgoMap_.begin() ; iter != cublasAlgoMap_.end() ; iter++) + { + int algoId = iter->second.algoId; + int stages = iter->second.stages; + //only check for cublas + if (stages != -1) + continue; + if (Traits_::OpType == OperationType::FP32) + { + if (algoId > CUBLAS_GEMM_ALGO23 || algoId < CUBLAS_GEMM_DEFAULT) + { + // the algorithm is not for FP32 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP32. \n", algoId); + exit(-1); + } + } + else + { + if (algoId > CUBLAS_GEMM_ALGO15_TENSOR_OP || algoId < CUBLAS_GEMM_DEFAULT_TENSOR_OP) + { + // the algorithm is not for FP16 + printf("[ERROR] cuBLAS Algorithm %d is not used in FP16. 
\n", algoId); + exit(-1); + } + } + } + } + } + + void set_tensor_parallel_param(const TensorParallelParam param) + { + t_parallel_param_ = param; + decoder_->set_tensor_parallel_param(param); + } + + void set_layer_parallel_param(const LayerParallelParam param) + { + l_parallel_param_ = param; + decoder_->set_layer_parallel_param(param); + } + + + void forward_context(const DecoderInitParam *decoder_param, + const DecodingInitParam decoding_params) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + const int input_len = decoding_params.request_input_len; + const int max_len = (decoding_params.request_output_len > 0 && input_len + decoding_params.request_output_len <= args_.seq_len_) ? + input_len + decoding_params.request_output_len : + args_.seq_len_; + const int request_batch_size = decoding_params.request_batch_size; + cudaMemsetAsync(decoding_params.output_ids, 0, sizeof(int) * request_batch_size * max_len, decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + // const int input_len = decoding_params.request_input_len; + const int max_input_len = decoding_params.max_input_len; + + // d_start_ids: [batch * seqlen] + if(input_len == 1) + { + cudaMemcpyAsync(decoding_params.output_ids, decoding_params.d_start_ids, + sizeof(int) * request_batch_size, cudaMemcpyDeviceToDevice, decoding_params.stream); + return; + } + const int local_batch_size = ceil(request_batch_size * 1.0 / l_parallel_param_.world_size); + const int m = local_batch_size * input_len; + const int h_1 = args_.hidden_units_; + + DataType_* from_tensor[2]; + DataType_* decoder_output; + DataType_* decoder_workspace; + void *buf = reinterpret_cast(allocator_.malloc( + decoder_->getContextWorkspaceSize(input_len, local_batch_size) + + (m * h_1 + 2 * request_batch_size * input_len * h_1) * sizeof(DataType_) + )); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + from_tensor[0] = (DataType_*) buf; + from_tensor[1] = from_tensor[0] + request_batch_size * input_len * h_1; + decoder_output = from_tensor[1] + request_batch_size * input_len * h_1; + decoder_workspace = decoder_output + m * h_1; + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + if(l_parallel_param_.rank == 0) + { + PUSH_RANGE("Before Transformer/Embedding") + gptj_start_id_embedding_lookups_kernel_launcher(from_tensor[0], + decoding_params.output_ids, + decoding_params.embedding_table, + decoding_params.d_start_ids, + input_len, + max_input_len, + request_batch_size, + args_.hidden_units_, + decoding_params.stream); + POP_RANGE +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + + int ite_num = (int)(ceil(request_batch_size * 1.0 / local_batch_size)); + for(int ite = 0; ite < ite_num; ite++) + { + int in_id, out_id; + for (int layer = 0; layer < args_.decoder_layers_; ++layer) + { + if(l_parallel_param_.is_valid(layer)) + { + in_id = layer & 0x1; + out_id = 1 - in_id; + + if(layer == l_parallel_param_.layers_per_group * l_parallel_param_.rank && layer != 0 && l_parallel_param_.world_size > 1) + { + const int size = m * t_parallel_param_.local_hidden_units_; + nccl_recv(from_tensor[in_id] + ite * m * h_1 + size * t_parallel_param_.rank, size, l_parallel_param_.rank - 1, + l_parallel_param_, decoding_params.stream); + all2all_gather(from_tensor[in_id] + ite * m * h_1, from_tensor[in_id] + ite * m * h_1, size, + t_parallel_param_, decoding_params.stream); + } + + 
decoder_->initialize(decoder_param[layer], decoder_buf_, cublas_workspace_, false); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + int dummy_decoder_max_seq_len = args_.seq_len_; + // int dummy_decoder_max_seq_len = -1; + size_t cache_offset; + if(dummy_decoder_max_seq_len == -1) + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_; + } + else + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_ + + ite * local_batch_size * args_.seq_len_ * t_parallel_param_.local_hidden_units_; + } + decoder_->forward_context(decoder_workspace, + from_tensor[out_id] + ite * m * h_1, + K_cache_[0] + cache_offset, + V_cache_[0] + cache_offset, + from_tensor[in_id] + ite * m * h_1, + decoding_params.d_attn_mask + ite * local_batch_size * input_len * input_len, + local_batch_size, + input_len, + ite, + dummy_decoder_max_seq_len, + layer == args_.decoder_layers_ - 1, + nullptr, + args_.rotary_embedding_dim_); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + if(layer == l_parallel_param_.layers_per_group * (l_parallel_param_.rank + 1) - 1 && layer != args_.decoder_layers_ - 1 && l_parallel_param_.world_size > 1) + { + const int size = m * t_parallel_param_.local_hidden_units_; + nccl_send(from_tensor[out_id] + ite * m * h_1 + size * t_parallel_param_.rank, size, l_parallel_param_.rank + 1, + l_parallel_param_, decoding_params.stream); + } + } + } // end of for loop of layer + } // end of for loop of ite + allocator_.free(buf); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + + void forward(const DecoderInitParam *decoder_param, + DecodingInitParam decoding_params) + { +#ifndef NDEBUG + PRINT_FUNC_NAME_(); +#endif + const int input_len = decoding_params.request_input_len; + const int max_input_len = decoding_params.max_input_len; + const int request_batch_size = decoding_params.request_batch_size; + const int max_len = (decoding_params.request_output_len > 0 && input_len + decoding_params.request_output_len <= args_.seq_len_) ? + input_len + decoding_params.request_output_len : + args_.seq_len_; + + assert(request_batch_size <= args_.batch_size_); + assert(request_batch_size % l_parallel_param_.local_batch_size == 0); + const int m = request_batch_size; + const int k = args_.hidden_units_; + const DataType_* embedding_kernel_ptr = nullptr; + const DataType_* embedding_bias_ptr = nullptr; + + cudaMemsetAsync(finished_buf_, false, sizeof(finished_buf_[0]) * request_batch_size, decoding_params.stream); + if (args_.probability_threshold_ != 0.0) + { + topp_initialization_kernelLauncher_v2(nullptr, + nullptr, + nullptr, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + args_.candidate_num_ > 0 ? 
args_.candidate_num_ : args_.vocab_size_padded_, + args_, + decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + ker_curand_setupLauncher(curandstate_buf_, + args_, + decoding_params.stream); + + if(std::is_same::value || (std::is_same::value && args_.vocab_size_padded_ == args_.vocab_size_)) + { + embedding_kernel_ptr = (const DataType_ *)decoding_params.embedding_kernel; + embedding_bias_ptr = (const DataType_ *)decoding_params.embedding_bias; + } + else + { + cudaMemcpyAsync(embedding_kernel_padded_, decoding_params.embedding_kernel, + sizeof(DataType_) * args_.vocab_size_ * args_.hidden_units_, cudaMemcpyDeviceToDevice, decoding_params.stream); + embedding_kernel_ptr = (const DataType_ *)embedding_kernel_padded_; + bias_padding_kernelLauncher(embedding_bias_padded_, + decoding_params.embedding_bias, // GPTJ/CodeGen bias + args_.vocab_size_, + args_.vocab_size_padded_, + decoding_params.stream); + embedding_bias_ptr = embedding_bias_padded_; + } +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + bool is_generation_done = false; + const int local_batch = l_parallel_param_.local_batch_size; + for (size_t step = input_len; step < max_len; ++step) + { + const int ite_num = request_batch_size / local_batch; + for(size_t ite = 0; ite < ite_num; ite++) + { + if(l_parallel_param_.rank == 0 && l_parallel_param_.world_size > 1) + { + if(step != (size_t)input_len) + { + PUSH_RANGE("token/recv") + nccl_recv(decoding_params.output_ids + (step - 1) * m + ite * local_batch, local_batch, + l_parallel_param_.world_size - 1, l_parallel_param_, decoding_params.stream); + POP_RANGE + } + } + + if(l_parallel_param_.rank < l_parallel_param_.world_size - 1 && l_parallel_param_.world_size > 1) + { + if(step != (size_t)input_len) + { + nccl_broadcast(finished_buf_ + ite * local_batch, local_batch, l_parallel_param_.world_size - 1, l_parallel_param_, decoding_params.stream); + } + } + if(ite == 0) + { + cudaMemcpyAsync(h_finished_buf_, finished_buf_, sizeof(bool) * request_batch_size, cudaMemcpyDeviceToHost, decoding_params.stream); + cudaStreamSynchronize(decoding_params.stream); + uint sum = 0; + for (uint i = 0; i < request_batch_size; i++) + { + sum += (int)h_finished_buf_[i]; + } + if (sum == request_batch_size) + { + is_generation_done = true; + break; + } + } + + if(l_parallel_param_.rank == 0) + { + PUSH_RANGE("Before Transformer/Embedding") + /***** newly fixed by PaddleNLP *****/ + gpj_embedding_lookups_kernel_launcher(from_tensor_[0], + decoding_params.embedding_table, + decoding_params.output_ids, + local_batch, + m, + args_.hidden_units_, + step, + ite, + max_input_len, + decoding_params.d_start_lengths, + decoding_params.stream); + POP_RANGE +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + + //we use two-way buffer + int from_id, out_id; + for (int layer = 0; layer < args_.decoder_layers_; ++layer) + { + if(l_parallel_param_.is_valid(layer)) + { + /* + For the first layer (layer-0), from_id is 0. 
We also stored the embedding lookup + result in from_tensor_[0] + */ + from_id = layer & 0x1; + out_id = 1 - from_id; + + if(layer == l_parallel_param_.layers_per_group * l_parallel_param_.rank && layer != 0 && l_parallel_param_.world_size > 1) + { + const int size = local_batch * t_parallel_param_.local_hidden_units_; + nccl_recv(from_tensor_[from_id] + size * t_parallel_param_.rank, size, l_parallel_param_.rank - 1, + l_parallel_param_, decoding_params.stream); + all2all_gather(from_tensor_[from_id], from_tensor_[from_id], size, + t_parallel_param_, decoding_params.stream); + } + + /* + We use one decoder_ object to process multiple decoder layers. + At the beginning of each decoder layer, we initialize the decoder object + with corresponding weights and decoder_buf_. + The decoder_buf_ is reused. + */ + decoder_->initialize(decoder_param[layer], decoder_buf_, cublas_workspace_, false); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + int dummy_decoder_max_seq_len = args_.seq_len_; + // int dummy_decoder_max_seq_len = -1; + size_t cache_offset; + if(dummy_decoder_max_seq_len == -1) + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_ + + ite * local_batch * t_parallel_param_.local_hidden_units_; + } + else + { + cache_offset = (layer - l_parallel_param_.layers_per_group * l_parallel_param_.rank) * + args_.batch_size_ * args_.seq_len_ * t_parallel_param_.local_hidden_units_ + + ite * local_batch * args_.seq_len_ * t_parallel_param_.local_hidden_units_; + } + decoder_->forward_v2(from_tensor_[from_id], + nullptr, // memory_tensor should be nullptr + K_cache_[0] + cache_offset, + V_cache_[0] + cache_offset, + nullptr, nullptr, // key_mem_cache_ and value_mem_cache_ should be nullptr + nullptr, // memory_sequence_length should be nullptr + from_tensor_[out_id], step, dummy_decoder_max_seq_len, + false, + finished_buf_ + ite * local_batch, + max_input_len, + decoding_params.d_start_lengths + ite * local_batch, + args_.rotary_embedding_dim_); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + if(layer == l_parallel_param_.layers_per_group * (l_parallel_param_.rank + 1) - 1 && layer != args_.decoder_layers_ - 1 && l_parallel_param_.world_size > 1) + { + const size_t size = local_batch * t_parallel_param_.local_hidden_units_; + nccl_send(from_tensor_[out_id] + size * t_parallel_param_.rank, size, l_parallel_param_.rank + 1, + l_parallel_param_, decoding_params.stream); + } + } + } + + if(l_parallel_param_.rank == l_parallel_param_.world_size - 1) + { + + layer_norm(from_tensor_[out_id], + decoding_params.layernorm.gamma, + decoding_params.layernorm.beta, + decoder_normed_result_buf_, + local_batch, + k, + decoding_params.stream); + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + DataType_ alpha = DataType_(1.0f); + DataType_ beta = DataType_(0.0f); + assert(args_.vocab_size_padded_ % t_parallel_param_.world_size == 0); + int n = args_.vocab_size_padded_ / t_parallel_param_.world_size; + + if(t_parallel_param_.world_size == 1) + { + PUSH_RANGE("After Transformer/GEMM") + cublasMM_cublasLtMM_wrapper_decoder(decoding_params.cublaslt_handle, + decoding_params.cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + n, local_batch, k, + &alpha, + embedding_kernel_ptr, AType_, k, + decoder_normed_result_buf_, BType_, k, + &beta, + logits_buf_, CType_, n, + 
decoding_params.stream, cublasAlgoMap_, + cublas_workspace_); + POP_RANGE + } + else + { + PUSH_RANGE("After Transformer/GEMM") + cublasMM_cublasLtMM_wrapper_decoder(decoding_params.cublaslt_handle, + decoding_params.cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + n, local_batch, k, + &alpha, + embedding_kernel_ptr + t_parallel_param_.rank * n * k, + AType_, k, + decoder_normed_result_buf_, BType_, k, + &beta, + nccl_logits_buf_ + t_parallel_param_.rank * local_batch * n, + CType_, n, + decoding_params.stream, cublasAlgoMap_, + cublas_workspace_); + POP_RANGE + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + apply_logits_mask_kernelLauncher(logits_buf_, + finished_buf_, + args_.batch_size_, + 1, + args_.vocab_size_padded_, + args_.vocab_size_, + decoding_params.stream, + (DataType_*) nullptr, + (args_.min_length_ != 0 && step-input_len < args_.min_length_), + args_.end_id_, + embedding_bias_ptr); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + + if(t_parallel_param_.world_size == 1) + { + apply_temperature_penalty_kernelLauncher(logits_buf_, + (DataType_) args_.temperature_, + local_batch, + args_.vocab_size_, + n, + decoding_params.stream); + } + else + { + if(t_parallel_param_.rank == t_parallel_param_.world_size - 1) + { + apply_temperature_penalty_kernelLauncher(nccl_logits_buf_ + t_parallel_param_.rank * local_batch * n, + (DataType_) args_.temperature_, + local_batch, + args_.vocab_size_ - n * t_parallel_param_.rank, + n, + decoding_params.stream); + } + else + { + apply_temperature_penalty_kernelLauncher(nccl_logits_buf_ + t_parallel_param_.rank * local_batch * n, + (DataType_) args_.temperature_, + local_batch, + n, + n, + decoding_params.stream); + } + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + // reduce and concat the reuslt + if(t_parallel_param_.world_size > 1) + { + PUSH_RANGE("After Transformer/all2all_gather") + all2all_gather(nccl_logits_buf_, nccl_logits_buf_, local_batch * n, + t_parallel_param_, decoding_params.stream); + POP_RANGE + + transpose_axis_01_kernelLauncher(logits_buf_, nccl_logits_buf_, + t_parallel_param_.world_size, local_batch, n, decoding_params.stream); + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + n = args_.vocab_size_padded_; + + // Apply repetition penalty. 
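+                    // (CTRL-style penalty: for every token id that already appears in the prompt or the
+                    // generated output, its logit is divided by repetition_penalty_ when positive and
+                    // multiplied by it when negative, so already-emitted tokens become less likely.)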
+ if (args_.repetition_penalty_ != 1.0) { + PUSH_RANGE("After Transformer/Repetition_penalty") + apply_repetition_penalty_kernelLauncher(logits_buf_, + args_.repetition_penalty_, + decoding_params.d_start_ids, + decoding_params.output_ids, + m, + local_batch, + args_.vocab_size_, + n, + decoding_params.d_start_lengths, + max_input_len, + step, + ite, + decoding_params.stream); + POP_RANGE + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + // Sampling + if(args_.candidate_num_ > 0 && args_.probability_threshold_ == 0.0) + { + PUSH_RANGE("After Transformer/Sampling") + // top k sampling + topK_sampling_kernel_kernelLauncher_v2(topk_workspace_, + topk_workspace_size_, + logits_buf_, + decoding_params.output_ids + step * m + ite * local_batch, + nullptr, + finished_buf_ + ite * local_batch, + curandstate_buf_, // used as random number + args_, + decoding_params.stream, + local_batch); + POP_RANGE + } + else if(args_.candidate_num_ == 0 && args_.probability_threshold_ > 0.0f) + { + PUSH_RANGE("After Transformer/Sampling") + // top p sampling + softmax_kernelLauncher(logits_buf_, + (DataType_*) nullptr, + args_.end_id_, + finished_buf_ + ite * local_batch, + local_batch, + args_.vocab_size_padded_, + args_.vocab_size_, + decoding_params.stream); +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + topP_sampling_kernel_kernelLauncher_v2(topp_workspace_, + topp_workspace_size_, + logits_buf_, + topp_id_vals_buf_, + topp_offset_buf_, + begin_topp_offset_buf_, + finished_buf_ + ite * local_batch, + curandstate_buf_, + args_, + decoding_params.output_ids + step * m + ite * local_batch, + nullptr, + n, + decoding_params.stream, + local_batch); + + POP_RANGE + } + else if(args_.candidate_num_ > 0 && args_.probability_threshold_ > 0.0f) + { + PUSH_RANGE("After Transformer/Sampling") + topK_topP_sampling_kernel_kernelLauncher_v2(topk_topp_workspace_, + topk_topp_workspace_size_, + decoding_params.output_ids + step * m + ite * local_batch, + logits_buf_, + finished_buf_ + ite * local_batch, + curandstate_buf_, + args_, + decoding_params.stream, + local_batch); + POP_RANGE + } +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } + if(step < (size_t)max_input_len) + { + // Replace the sampled id by start ids + set_start_ids_kernelLauncher(decoding_params.output_ids, decoding_params.d_start_ids, max_input_len, + step, ite, request_batch_size, local_batch, args_.end_id_, decoding_params.stream); + } + + if(l_parallel_param_.rank == l_parallel_param_.world_size - 1 && l_parallel_param_.world_size > 1) + { + PUSH_RANGE("token/send") + nccl_send(decoding_params.output_ids + step * m + ite * local_batch, local_batch, 0, l_parallel_param_, decoding_params.stream); + POP_RANGE + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + + if(l_parallel_param_.rank == l_parallel_param_.world_size - 1 && l_parallel_param_.world_size > 1 && step < max_len - 1) + { + nccl_broadcast(finished_buf_ + ite * local_batch, local_batch, l_parallel_param_.world_size - 1, l_parallel_param_, decoding_params.stream); + } +#ifndef NDEBUG + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +#endif + } // end for ite for loop + + if (is_generation_done) { + break; + } + } // end for decoding step for loop + if(l_parallel_param_.rank == 0 && l_parallel_param_.world_size > 1) + { + for(size_t ite = 0; ite < request_batch_size / local_batch; ite++) + { + 
nccl_recv(decoding_params.output_ids + (max_len - 1) * m + ite * local_batch, + local_batch, l_parallel_param_.world_size - 1, + l_parallel_param_, decoding_params.stream); + } + } + } // end of forward + + virtual ~DecodingGptJ() + { + delete[] K_cache_; + delete[] V_cache_; + delete decoder_; + allocator_.free(buf_); + delete [] h_finished_buf_; + } + + inline int get_num_layer() {return args_.decoder_layers_;} + + inline void set_local_batch_size(int local_batch) + { + l_parallel_param_.local_batch_size = local_batch; + decoder_->set_local_batch_size(local_batch); + } +}; + +} //namespace fastertransformer diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/open_decoder.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/open_decoder.h index e8354be4aa46..fb30f5e41d7f 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/open_decoder.h +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/open_decoder.h @@ -588,7 +588,8 @@ class OpenDecoder { const bool is_cross_attention, const bool *finished = nullptr, const int max_input_len = 0, - const int *input_lengths = nullptr) { + const int *input_lengths = nullptr, + const int rotary_embedding_dim = 0) { #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif @@ -618,7 +619,8 @@ class OpenDecoder { step, decoder_max_seq_len, max_input_len, - input_lengths); + input_lengths, + rotary_embedding_dim); POP_RANGE #ifndef NDEBUG @@ -697,7 +699,20 @@ class OpenDecoder { hidden_units_, param_.stream); } else { - add_bias_input_layernorm_2_kernelLauncher( + if (rotary_embedding_dim > 0){ + add_bias_input_layernorm_2_kernelLauncher( + from_tensor, + (DataType_*) nullptr, + (DataType_*) nullptr, + (DataType_*) nullptr, + masked_output_buf_, + norm_masked_output_buf_, + m, + hidden_units_, + param_.stream); + + }else{ + add_bias_input_layernorm_2_kernelLauncher( from_tensor, param_.ffn_layernorm.gamma, param_.ffn_layernorm.beta, @@ -707,12 +722,14 @@ class OpenDecoder { m, hidden_units_, param_.stream); + } + #ifndef NDEBUG cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif PUSH_RANGE("Transformer/MLP") - ffn(norm_masked_output_buf_, + ffn(rotary_embedding_dim > 0 ? 
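The `rotary_embedding_dim` parameter added in the hunks around here is only plumbing: it lets the fused-QKV attention path apply GPT-J/CodeGen rotary position embeddings (RoPE) to q and k before the q·k dot product, using the `rotary_embedding_coefficient` / `rotary_embedding_transform` helpers introduced in `masked_multihead_attention_utils.h` above. A minimal host-side sketch of that rotation with the same formula (illustrative only, not code from this patch; the names below mirror but do not reuse the CUDA helpers):

```cpp
#include <cmath>

struct f2 { float x, y; };

// Same formula as mmha::rotary_embedding_coefficient: theta = t / 10000^(zid / dim).
static f2 rotary_coefficient(int zid, int rot_embed_dim, float t_step) {
    const float inv_freq = t_step / std::pow(10000.0f, zid / (float)rot_embed_dim);
    return {std::cos(inv_freq), std::sin(inv_freq)};
}

// Complex multiplication (v.x + i*v.y) * (c.x + i*c.y), as in rotary_embedding_transform.
static f2 rotate(f2 v, f2 c) {
    return {c.x * v.x - c.y * v.y, c.x * v.y + c.y * v.x};
}

// Apply RoPE to the first rot_embed_dim channels of q and k for absolute timestep t_step.
void apply_rope(float* q, float* k, int rot_embed_dim, int t_step) {
    for (int zid = 0; zid + 1 < rot_embed_dim; zid += 2) {
        const f2 c = rotary_coefficient(zid, rot_embed_dim, (float)t_step);
        f2 qv{q[zid], q[zid + 1]};
        f2 kv{k[zid], k[zid + 1]};
        qv = rotate(qv, c);
        kv = rotate(kv, c);
        q[zid] = qv.x; q[zid + 1] = qv.y;
        k[zid] = kv.x; k[zid + 1] = kv.y;
    }
}
```

The vectorized `apply_rotary_embedding` overloads in the patch do the same pairwise rotation, just on packed float2/float4/half2/bf16 lanes inside the kernel.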
norm_from_tensor_buf_:norm_masked_output_buf_, ffn_inner_buf_, decoder_output, m, @@ -917,7 +934,8 @@ class OpenDecoder { const int ite, const int max_seq_len, const bool is_final, - const int* memory_sequence_length = nullptr) { + const int* memory_sequence_length = nullptr, + const int rotary_embedding_dim = 0) { #ifndef NDEBUG PRINT_FUNC_NAME_(); #endif @@ -968,7 +986,8 @@ class OpenDecoder { ite, max_seq_len, is_final, - memory_sequence_length); + memory_sequence_length, + rotary_embedding_dim); if (is_final) return; POP_RANGE @@ -976,17 +995,30 @@ class OpenDecoder { cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif + if (rotary_embedding_dim > 0){ + add_bias_input_layernorm_2_kernelLauncher( + from_tensor, + (DataType_*) nullptr, + (DataType_*) nullptr, + (DataType_*) nullptr, + masked_output_buf, + norm_masked_output_buf, + m, + hidden_units_, + param_.stream); + }else{ + add_bias_input_layernorm_2_kernelLauncher( + from_tensor, + param_.ffn_layernorm.gamma, + param_.ffn_layernorm.beta, + param_.self_attention.attention_output_weight.bias, + masked_output_buf, + norm_masked_output_buf, + m, + hidden_units_, + param_.stream); + } - add_bias_input_layernorm_2_kernelLauncher( - from_tensor, - param_.ffn_layernorm.gamma, - param_.ffn_layernorm.beta, - param_.self_attention.attention_output_weight.bias, - masked_output_buf, - norm_masked_output_buf, - m, - hidden_units_, - param_.stream); #ifndef NDEBUG cudaDeviceSynchronize(); @@ -995,7 +1027,7 @@ class OpenDecoder { // For GPT decoder PUSH_RANGE("Transformer/MLP"); - ffn(norm_masked_output_buf, + ffn(rotary_embedding_dim > 0 ? norm_from_tensor_buf:norm_masked_output_buf, ffn_inner_buf, decoder_output, m, @@ -1008,7 +1040,6 @@ class OpenDecoder { cudaDeviceSynchronize(); check_cuda_error(cudaGetLastError()); #endif - add_bias_input_kernelLauncher(decoder_output, param_.ffn.output_weight.bias, masked_output_buf, @@ -1310,7 +1341,8 @@ class OpenDecoder { const int step, const int max_seq_len, const int max_input_len, - const int *input_lengths) { + const int *input_lengths, + const int rotary_embedding_dim = 0) { assert(is_fuse_QKV_in_normal_gemm_ == true); // only support for is_fuse_QKV = True. 
@@ -1361,6 +1393,7 @@ class OpenDecoder { max_seq_len, max_input_len, input_lengths, + rotary_embedding_dim, param_.stream); k = t_parallel_param_.local_hidden_units_; @@ -1616,7 +1649,8 @@ class OpenDecoder { const int ite, const int max_seq_len, const bool is_final, - const int* memory_sequence_length = nullptr) { + const int* memory_sequence_length = nullptr, + const int rotary_embedding_dim = 0) { const DataType_ scalar = 1 / sqrtf(size_per_head_ * 1.0f); const int m = local_batch_size * seq_len; @@ -1663,7 +1697,6 @@ class OpenDecoder { param_.stream, cublasAlgoMap_, cublas_workspace_); - add_fusedQKV_bias_transpose_kernelLauncher( q_buf, k_buf, @@ -1674,6 +1707,7 @@ class OpenDecoder { seq_len, t_parallel_param_.local_head_num_, size_per_head_, + rotary_embedding_dim, param_.stream); } else { const int n = t_parallel_param_.local_hidden_units_; diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/allocator.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/allocator.h index 0e6c7062bb28..8a399cf37fd6 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/allocator.h +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/allocator.h @@ -87,10 +87,16 @@ class Allocator : public IAllocator { int64_t buf_size = static_cast(size); std::vector buf_dims({buf_size}); +#ifdef PADDLE_NEW_ALLOCATOR + // For PaddlePaddle>=2.3.0 + auto buf = paddle::empty(buf_dims, paddle::DataType::UINT8, paddle::GPUPlace()); + allocated_tensor_vector->push_back(buf); + auto *flat = buf.data(); +#else auto buf = paddle::Tensor(paddle::PlaceType::kGPU, buf_dims); allocated_tensor_vector->push_back(buf); - auto *flat = buf.mutable_data(paddle::PlaceType::kGPU); +#endif void *ptr = reinterpret_cast(flat); return ptr; } diff --git a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/arguments.h b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/arguments.h index a4fed1d32d28..c1f60807be08 100644 --- a/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/arguments.h +++ b/paddlenlp/ops/patches/FasterTransformer/fastertransformer/utils/arguments.h @@ -155,6 +155,11 @@ struct GptArguments : public DecodingSamplingArguments { int min_gpu_num_{1}; }; + +struct GptJArguments : public GptArguments { + int rotary_embedding_dim_{0}; +}; + struct TransformerSamplingArguments : public DecodingSamplingArguments { int **start_ids_; int start_len_; diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 460f9422f0ea..58d814796ede 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -121,6 +121,8 @@ from .artist.tokenizer import * from .dallebart.modeling import * from .dallebart.tokenizer import * +from .gptj.modeling import * +from .gptj.tokenizer import * # For faster tokenizer from ..utils.import_utils import is_faster_tokenizer_available diff --git a/paddlenlp/transformers/codegen/modeling.py b/paddlenlp/transformers/codegen/modeling.py index 44a597ec0a8a..2c1c0bc07b53 100644 --- a/paddlenlp/transformers/codegen/modeling.py +++ b/paddlenlp/transformers/codegen/modeling.py @@ -289,7 +289,9 @@ class CodeGenPreTrainedModel(PretrainedModel): def init_weights(self, layer): """Initialize the weights.""" if isinstance(layer, (nn.Linear, nn.Embedding)): - if isinstance(layer.weight, paddle.Tensor): + if isinstance( + layer.weight, + paddle.Tensor) and paddle.get_default_dtype() == "float32": layer.weight.set_value( paddle.tensor.normal( mean=0.0, 
@@ -532,6 +534,33 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
+    def prepare_faster_entry(self, kwargs):
+        from paddlenlp.ops import FasterCodeGen
+        use_fp16_decoding = kwargs.get('use_fp16_decoding', False)
+        decoding_lib = kwargs.get('decoding_lib', None)
+        decode_strategy = kwargs.get('decode_strategy')
+        if decode_strategy == "beam_search":
+            raise AttributeError(
+                "'beam_search' is not supported yet in the faster version of CodeGen"
+            )
+        # Currently, FasterTransformer only supports a restricted set of size_per_head values.
+        size_per_head = self.transformer.config[
+            "n_embd"] // self.transformer.config["n_head"]
+        if size_per_head not in [32, 64, 80, 96, 128, 160, 192, 224, 256]:
+            raise AttributeError(
+                "'size_per_head = %d' is not supported yet in the faster version of CodeGen"
+                % size_per_head)
+        if kwargs['forced_bos_token_id'] is not None:
+            # forced_bos_token_id is not supported yet in the faster version
+            raise AttributeError(
+                "'forced_bos_token_id != None' is not supported yet in the faster version"
+            )
+        self._faster_entry = FasterCodeGen(
+            self,
+            decoding_lib=decoding_lib,
+            use_fp16_decoding=use_fp16_decoding).forward
+        return self._faster_entry
+
     def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if cache:
diff --git a/paddlenlp/transformers/codegen/tokenizer.py b/paddlenlp/transformers/codegen/tokenizer.py
index cfbf8ac0a8a8..5daedc4004ba 100644
--- a/paddlenlp/transformers/codegen/tokenizer.py
+++ b/paddlenlp/transformers/codegen/tokenizer.py
@@ -28,6 +28,26 @@ class CodeGenTokenizer(GPTTokenizer):
     pretrained_resource_files_map = {"vocab_file": {}, "merges_file": {}}
     pretrained_init_configuration = {}
 
+    def __init__(self,
+                 vocab_file,
+                 merges_file,
+                 errors='replace',
+                 max_len=None,
+                 pad_token='<|endoftext|>',
+                 eos_token='<|endoftext|>',
+                 unk_token='<|endoftext|>',
+                 eol_token='\u010a',
+                 **kwargs):
+        super().__init__(vocab_file=vocab_file,
+                         merges_file=merges_file,
+                         errors=errors,
+                         max_len=max_len,
+                         pad_token=pad_token,
+                         eos_token=eos_token,
+                         unk_token=unk_token,
+                         eol_token=eol_token,
+                         **kwargs)
+
     def decode(self,
                token_ids,
                skip_special_tokens=False,
diff --git a/paddlenlp/transformers/gptj/__init__.py b/paddlenlp/transformers/gptj/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/paddlenlp/transformers/gptj/modeling.py b/paddlenlp/transformers/gptj/modeling.py
new file mode 100644
index 000000000000..46d9c8857e2a
--- /dev/null
+++ b/paddlenlp/transformers/gptj/modeling.py
@@ -0,0 +1,802 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 The EleutherAI Authors and The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Layer, Embedding
+
+from ..nezha.modeling import ACT2FN
+from ..
import PretrainedModel, register_base_model + +__all__ = [ + "GPTJModel", "GPTJPretrainedModel", "GPTJForCausalLM", + "GPTJForSequenceClassification", "GPTJForQuestionAnswering" +] + + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000**(paddle.arange(0, dim, 2) / dim)) + sinusoid_inp = (paddle.einsum("i , j -> i j", + paddle.arange(seq_len, dtype="float32"), + inv_freq)) + return paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = paddle.stack((-x2, x1), axis=-1) + # In einsum notation: rearrange(x, '... d j -> ... (d j)') + return x.flatten(-2) + + +def duplicate_interleave(m): + return paddle.repeat_interleave(m, 2, axis=1) + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map( + lambda t: duplicate_interleave(t)[None, offset:x.shape[1] + offset, + None, :], sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class GPTJAttention(Layer): + + def __init__(self, embed_dim, rotary_dim, num_attention_heads, + max_positions, attn_pdrop, resid_pdrop): + super().__init__() + + self.register_buffer( + "causal_mask", + paddle.tril( + paddle.ones((max_positions, max_positions), + dtype=paddle.get_default_dtype())).reshape( + (1, 1, max_positions, max_positions)), + ) + + self.attn_dropout = nn.Dropout(attn_pdrop) + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.embed_dim = embed_dim + self.num_attention_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" + f" `num_attention_heads`: {self.num_attention_heads}).") + self.scale_attn = paddle.sqrt( + paddle.to_tensor(self.head_dim, dtype="float32")) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias_attr=False) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias_attr=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias_attr=False) + + self.out_proj = nn.Linear(self.embed_dim, + self.embed_dim, + bias_attr=False) + self.rotary_dim = rotary_dim + + def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + new_shape = tensor.shape[:-1] + [num_attention_heads, attn_head_size] + tensor = tensor.reshape(new_shape) + if rotary: + return tensor + if len(tensor.shape) == 5: + # (batch, blocks, head, block_length, head_features) + return tensor.transpose([0, 1, 3, 2, 4]) + elif len(tensor.shape) == 4: + # (batch, head, seq_length, head_features) + return tensor.transpose([0, 2, 1, 3]) + else: + raise ValueError( + f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}" + ) + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + if len(tensor.shape) == 5: + tensor = tensor.transpose([0, 1, 3, 2, 4]) + elif len(tensor.shape) == 4: + tensor = tensor.transpose([0, 2, 1, 3]) + else: + raise ValueError( + f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}" + ) + new_shape = tensor.shape[:-2] + [num_attention_heads * attn_head_size] + return tensor.reshape(new_shape) + + def _attn(self, query, key, value, attention_mask=None): + + # compute causal 
mask from causal mask buffer + query_length, key_length = query.shape[-2], key.shape[-2] + causal_mask = self.causal_mask[:, :, key_length - + query_length:key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = paddle.cast(query, "float32") + key = paddle.cast(key, "float32") + attn_weights = paddle.matmul(query, key, transpose_y=True) + + mask_value = paddle.to_tensor(-1e9, dtype=attn_weights.dtype) + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + attn_weights = paddle.where(causal_mask, attn_weights, mask_value) + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, axis=-1, dtype=value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + attn_output = paddle.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + use_cache=False, + cache=None, + ): + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, + self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, + True) + value = self._split_heads(value, self.num_attention_heads, + self.head_dim, False) + + seq_len = key.shape[1] + offset = 0 + + if cache is not None: + offset = cache[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, :self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim:] + + q_rot = query[:, :, :, :self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim:] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = paddle.concat([k_rot, k_pass], axis=-1) + query = paddle.concat([q_rot, q_pass], axis=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.transpose([0, 2, 1, 3]) + query = query.transpose([0, 2, 1, 3]) + + if cache is not None: + past_key = cache[0] + past_value = cache[1] + key = paddle.concat((past_key, key), axis=-2) + value = paddle.concat((past_value, value), axis=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, + attention_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, + self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + return attn_output, present + + +class GPTJMLP(Layer): + + def __init__(self, embed_dim, inner_dim, activation_function, resid_pdrop): + super().__init__() + + self.fc_in = nn.Linear(embed_dim, inner_dim) + self.fc_out = nn.Linear(inner_dim, embed_dim) + + self.act = ACT2FN[activation_function] + self.dropout = nn.Dropout(resid_pdrop) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPTJBlock(Layer): + + def __init__(self, embed_dim, rotary_dim, 
n_head, n_positions, attn_pdrop,
+                 resid_pdrop, activation_function, layer_norm_epsilon):
+        super().__init__()
+        inner_dim = 4 * embed_dim
+        self.ln_1 = nn.LayerNorm(embed_dim, epsilon=layer_norm_epsilon)
+        self.attn = GPTJAttention(embed_dim, rotary_dim, n_head, n_positions,
+                                  attn_pdrop, resid_pdrop)
+        self.mlp = GPTJMLP(embed_dim, inner_dim, activation_function,
+                           resid_pdrop)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        use_cache=False,
+        cache=None,
+    ):
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_outputs = self.attn(hidden_states,
+                                 attention_mask=attention_mask,
+                                 cache=cache,
+                                 use_cache=use_cache)
+        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = attn_output + feed_forward_hidden_states + residual
+
+        if use_cache:
+            outputs = (hidden_states, ) + outputs
+        else:
+            outputs = (hidden_states, ) + outputs[1:]
+
+        return outputs  # hidden_states, present, (attentions)
+
+
+class GPTJPretrainedModel(PretrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+    model_config_file = "model_config.json"
+    pretrained_init_configuration = {}
+    resource_files_names = {"model_state": "model_state.pdparams"}
+    pretrained_resource_files_map = {"model_state": {}}
+
+    base_model_prefix = "transformer"
+
+    def init_weights(self, layer):
+        """Initialize the weights."""
+        if isinstance(layer, (nn.Linear, nn.Embedding)):
+            if isinstance(
+                    layer.weight,
+                    paddle.Tensor) and paddle.get_default_dtype() == "float32":
+                layer.weight.set_value(
+                    paddle.tensor.normal(
+                        mean=0.0,
+                        std=self.initializer_range if hasattr(
+                            self, "initializer_range") else
+                        self.transformer.config["initializer_range"],
+                        shape=layer.weight.shape))
+        elif isinstance(layer, nn.LayerNorm):
+            layer.bias.set_value(paddle.zeros_like(layer.bias))
+            layer.weight.set_value(paddle.full_like(layer.weight, 1.0))
+            layer._epsilon = getattr(self, "layer_norm_epsilon", 1e-05)
+        if isinstance(layer, nn.Linear) and layer.bias is not None:
+            layer.bias.set_value(paddle.zeros_like(layer.bias))
+
+
+@register_base_model
+class GPTJModel(GPTJPretrainedModel):
+    r"""
+    The bare GPTJ Model outputting raw hidden-states.
+    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
+    Refer to the superclass documentation for the generic methods.
+    This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer
+    and refer to the Paddle documentation for all matters related to general usage and behavior.
+    Args:
+        vocab_size (int):
+            Vocabulary size of `inputs_ids` in `GPTJModel`. It is also the vocab size of the token embedding matrix.
+            Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `GPTJModel`.
+        bos_token_id (int, optional):
+            The beginning of sequence token that was used during pretraining. Can be
+            used as a sequence classifier token.
+            Defaults to `50256`.
+        pad_token_id (int, optional):
+            The index of padding token in the token vocabulary.
+            Defaults to `50256`.
+        eos_token_id (int, optional):
+            A special token representing the end of a sequence that was used during pretraining.
+            Defaults to `50256`.
+        n_embd (int, optional):
+            Dimensionality of the embedding layer and the decoder layers. Defaults to `4096`.
+        n_layer (int, optional):
+            Number of hidden layers. Defaults to `28`.
+        n_head (int, optional):
+            Number of attention heads for each attention layer in the Transformer decoder.
+            Defaults to `16`.
+        n_positions (int, optional):
+            The maximum sequence length that this model might ever be used with.
+            Defaults to `2048`.
+        attn_pdrop (float, optional):
+            The dropout probability used in MultiHeadAttention in all decoder layers to drop some attention targets.
+            Defaults to `0.0`.
+        resid_pdrop (float, optional):
+            The dropout probability for all residual layers in the decoder.
+            Defaults to `0.0`.
+        embd_pdrop (float, optional):
+            The dropout probability used in embedding layers. Defaults to `0.0`.
+        rotary_dim (int, optional):
+            Dimensionality of the rotary position embeddings.
+            Defaults to `64`.
+        activation_function (str, optional):
+            The non-linear activation function in the feed-forward layer.
+            ``"gelu"``, ``"relu"`` and any other paddle supported activation functions are supported.
+            Defaults to `"gelu_new"`.
+        layer_norm_epsilon (float, optional):
+            The epsilon to use in the layer normalization layers.
+            Defaults to `1e-05`.
+        initializer_range (float, optional):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+            Defaults to `0.02`.
+    """
+
+    def __init__(self,
+                 vocab_size,
+                 bos_token_id=50256,
+                 pad_token_id=50256,
+                 eos_token_id=50256,
+                 n_embd=4096,
+                 n_layer=28,
+                 n_head=16,
+                 n_positions=2048,
+                 attn_pdrop=0.0,
+                 resid_pdrop=0.0,
+                 embd_pdrop=0.0,
+                 rotary_dim=64,
+                 activation_function="gelu_new",
+                 layer_norm_epsilon=1e-05,
+                 initializer_range=0.02):
+        super().__init__()
+
+        self.vocab_size = vocab_size
+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_id = eos_token_id
+        self.embed_dim = n_embd
+        self.initializer_range = initializer_range
+        self.wte = nn.Embedding(vocab_size, self.embed_dim)
+        self.drop = nn.Dropout(embd_pdrop)
+        self.h = nn.LayerList([
+            GPTJBlock(n_embd, rotary_dim, n_head, n_positions, attn_pdrop,
+                      resid_pdrop, activation_function, layer_norm_epsilon)
+            for _ in range(n_layer)
+        ])
+        self.ln_f = nn.LayerNorm(self.embed_dim, epsilon=layer_norm_epsilon)
+
+        # Initialize weights and apply final processing
+        self.apply(self.init_weights)
+
+    def get_input_embeddings(self):
+        return self.wte
+
+    def set_input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        use_cache=False,
+        cache=None,
+    ):
+        r'''
+        The GPTJModel forward method, overrides the `__call__()` special method.
+        Args:
+            input_ids (Tensor):
+                Indices of input sequence tokens in the vocabulary. They are
+                numerical representations of tokens that build the input sequence.
+                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
+            attention_mask (Tensor, optional):
+                Mask used in multi-head attention to avoid performing attention to some unwanted positions,
+                usually the paddings or the subsequent positions.
+                Its data type can be int, float and bool.
+                When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
+                When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
+                When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
+                It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
+ For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], + [batch_size, num_attention_heads, sequence_length, sequence_length]. + Defaults to `None`, which means nothing needed to be prevented attention to. + use_cache (bool, optional): + Whether or not to use cache. Defaults to `False`. If set to `True`, key value states will be returned and + can be used to speed up decoding. + cache (list, optional): + It is a list, and each element in the list is a tuple `(incremental_cache, static_cache)`. + See `TransformerDecoder.gen_cache `__ for more details. + It is only used for inference and should be None for training. + Default to `None`. + Returns: + Tensor: Returns tensor `decoder_output`, which is the output at the last layer of the model. + Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size]. + Example: + .. code-block:: + import paddle + from paddlenlp.transformers import GPTJModel, GPTJTokenizer + tokenizer = GPTJTokenizer.from_pretrained('EleutherAI/gpt-j-6B') + model = GPTJModel.from_pretrained('EleutherAI/gpt-j-6B') + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False) + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + output = model(**inputs) + ''' + + if input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.reshape(shape=(-1, input_shape[-1])) + batch_size = input_ids.shape[0] + else: + raise ValueError("You have to specify input_ids") + + if cache is None: + past_length = 0 + cache = tuple([None] * len(self.h)) + else: + past_length = cache[0][0].shape[-2] + + # Attention mask. + if attention_mask is None: + assert input_ids is not None, "input_ids should be " \ + "specified when generating attention_mask" + attention_mask = paddle.cast( + input_ids == self.pad_token_id, + dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4 + # For 2D attention_mask from tokenizer + elif attention_mask.ndim == 2: + attention_mask = paddle.unsqueeze( + attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype()) + attention_mask = (1.0 - attention_mask) * -1e4 + attention_mask.stop_gradient = True + + inputs_embeds = self.wte(input_ids) + + hidden_states = self.drop(inputs_embeds) + output_shape = input_shape[:] + [hidden_states.shape[-1]] + + presents = () if use_cache else None + for i, (block, old_cache) in enumerate(zip(self.h, cache)): + outputs = block(hidden_states, + attention_mask=attention_mask, + use_cache=use_cache, + cache=old_cache) + + hidden_states = outputs[0] + if use_cache: + presents = presents + (outputs[1], ) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.reshape(shape=output_shape) + + last_hidden_state = hidden_states + new_cache = presents + + return last_hidden_state, new_cache + + +class GPTJForCausalLM(GPTJPretrainedModel): + r""" + GPTJ Model with a `language modeling` head on top. + Args: + GPTJ (:class:`GPTJModel`): + An instance of GPTJModel. 
+ """ + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + self.lm_head = nn.Linear(self.transformer.config["n_embd"], + self.transformer.config["vocab_size"]) + + # Initialize weights and apply final processing + self.apply(self.init_weights) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_faster_entry(self, kwargs): + from paddlenlp.ops import FasterGPTJ + use_fp16_decoding = kwargs.get('use_fp16_decoding', False) + decoding_lib = kwargs.get('decoding_lib', None) + decode_strategy = kwargs.get('decode_strategy') + if decode_strategy == "beam_search": + raise AttributeError( + "'beam_search' is not supported yet in the faster version of GPTJ" + ) + # Currently, FasterTransformer only support restricted size_per_head. + size_per_head = self.transformer.config[ + "n_embd"] // self.transformer.config["n_head"] + if size_per_head not in [32, 64, 80, 96, 128, 160, 192, 224, 256]: + raise AttributeError( + "'size_per_head = %d' is not supported yet in the faster version of GPTJ" + % size_per_head) + if kwargs['forced_bos_token_id'] is not None: + # not support for min_length yet in the faster version + raise AttributeError( + "'forced_bos_token_id != None' is not supported yet in the faster version" + ) + self._faster_entry = FasterGPTJ( + self, + decoding_lib=decoding_lib, + use_fp16_decoding=use_fp16_decoding).forward + return self._faster_entry + + def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs): + # only last token for inputs_ids if past is defined in kwargs + if cache: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if len(attention_mask.shape) == 4: + attention_mask = attention_mask[:, :, -1:, :] + + return { + "input_ids": input_ids, + "cache": cache, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + + def forward(self, + input_ids=None, + attention_mask=None, + use_cache=False, + cache=None): + r""" + The GPTJForCausalLM forward method, overrides the __call__() special method. + Args: + input_ids (Tensor): + See :class:`GPTJModel`. + attention_mask (Tensor, optional): + See :class:`GPTJModel`. + use_cache (bool, optional): + See :class:`GPTJModel`. + cache (Tensor, optional): + See :class:`GPTJModel`. + Returns: + Tensor or tuple: Returns Tensor `lm_logits` if `use_cache` is `False`, otherwise, returns tuple (`lm_logits`, `cache`). + With the fields: + - `lm_logits` (Tensor): + The generated sentence of the model. + Its data type should be float32 and has a shape of [batch_size, sequence_length, vocab_size]. + - `cache` (Tensor): + See :class:`GPTJModel`. + Example: + .. 
code-block:: + import paddle + from paddlenlp.transformers import GPTJForCausalLM, GPTJTokenizer + tokenizer = GPTJTokenizer.from_pretrained('EleutherAI/gpt-j-6B') + model = GPTJForCausalLM.from_pretrained('EleutherAI/gpt-j-6B') + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False) + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + """ + + transformer_outputs = self.transformer(input_ids, + attention_mask=attention_mask, + use_cache=use_cache, + cache=cache) + + hidden_states = transformer_outputs[0] + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = paddle.cast(self.lm_head(hidden_states), "float32") + past_key_values = transformer_outputs[1] + + return lm_logits, past_key_values + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError as e: + try: + return getattr(getattr(self, self.base_model_prefix), name) + except AttributeError: + try: + return getattr(self, self.base_model_prefix).config[name] + except KeyError: + raise e + + +class GPTJForSequenceClassification(GPTJPretrainedModel): + r""" + GPTJ Model with a linear layer on top of the pooled output, + designed for sequence classification/regression tasks like GLUE tasks. + Since it does classification on the last token, it requires to know the + position of the last token. If a `pad_token_id` is defined in the configuration, + it finds the last token that is not a padding token in each row. If no `pad_token_id` + is defined, it simply takes the last value in each row of the batch. + + Args: + GPTJ (:class:`GPTJModel`): + An instance of GPTJModel. + num_labels (int, optional): + The number of different labels. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of GPTJ. + If None, use the same value as `hidden_dropout_prob` of `GPTJModel` + instance `GPTJ`. Defaults to None. + """ + + def __init__(self, transformer, num_labels=2): + super().__init__() + self.transformer = transformer + self.classifier = nn.Linear(self.transformer.config["n_embd"], + num_labels, + bias_attr=False) + self.apply(self.init_weights) + + def forward( + self, + input_ids=None, + attention_mask=None, + use_cache=False, + cache=None, + ): + r""" + The GPTJForSequenceClassification forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`GPTJModel`. + attention_mask (Tensor, optional): + See :class:`GPTJModel`. + use_cache (bool, optional): + See :class:`GPTJModel`. + cache (Tensor, optional): + See :class:`GPTJModel`. + + Returns: + Tensor: Returns tensor `logits`, a tensor of the input text classification logits. + Shape as `[batch_size, num_labels]` and dtype as float32. + + Example: + .. 
code-block:: + + import paddle + from paddlenlp.transformers import GPTJForSequenceClassification, GPTJTokenizer + tokenizer = GPTJTokenizer.from_pretrained('EleutherAI/gpt-j-6B') + model = GPTJForSequenceClassification.from_pretrained('EleutherAI/gpt-j-6B') + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False) + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + logits = model(**inputs) + """ + transformer_outputs = self.transformer(input_ids, + attention_mask=attention_mask, + use_cache=use_cache, + cache=cache) + + hidden_states = transformer_outputs[0] + logits = self.classifier(hidden_states) + batch_size = input_ids.shape[0] + + if self.transformer.config.get('pad_token_id', + None) is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined.") + + if self.transformer.config.get('pad_token_id', None) is None: + sequence_lengths = -1 + else: + sequence_lengths = ( + input_ids != + self.transformer.config['pad_token_id']).sum(-1) - 1 + + pooled_logits = logits[paddle.arange(batch_size), sequence_lengths] + + return pooled_logits + + +class GPTJForQuestionAnswering(GPTJPretrainedModel): + r""" + GPTJ Model with a linear layer on top of the hidden-states output to + compute `span_start_logits` and `span_end_logits`, designed for question-answering tasks like SQuAD. + + Args: + GPTJ (:class:`GPTJModel`): + An instance of GPTJModel. + """ + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + self.classifier = nn.Linear(self.transformer.config['n_embd'], 2) + self.apply(self.init_weights) + + def forward( + self, + input_ids=None, + attention_mask=None, + use_cache=False, + cache=None, + ): + r""" + The GPTJForQuestionAnswering forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`GPTJModel`. + attention_mask (Tensor, optional): + See :class:`GPTJModel`. + use_cache (bool, optional): + See :class:`GPTJModel`. + cache (Tensor, optional): + See :class:`GPTJModel`. + + Returns: + tuple: Returns tuple (`start_logits`, `end_logits`). + + With the fields: + + - `start_logits` (Tensor): + A tensor of the input token classification logits, indicates the start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - `end_logits` (Tensor): + A tensor of the input token classification logits, indicates the end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + Example: + .. 
code-block:: + + import paddle + from paddlenlp.transformers import GPTJForQuestionAnswering, GPTJTokenizer + + tokenizer = GPTJTokenizer.from_pretrained('EleutherAI/gpt-j-6B') + model = GPTJForQuestionAnswering.from_pretrained('EleutherAI/gpt-j-6B') + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False) + outputs = model(**inputs) + start_logits = outputs[0] + end_logits =outputs[1] + """ + transformer_outputs = self.transformer(input_ids, + attention_mask=attention_mask, + use_cache=use_cache, + cache=cache) + + hidden_states = transformer_outputs[0] + logits = self.classifier(hidden_states) + logits = paddle.transpose(logits, perm=[2, 0, 1]) + start_logits, end_logits = paddle.unstack(x=logits, axis=0) + return start_logits, end_logits diff --git a/paddlenlp/transformers/gptj/tokenizer.py b/paddlenlp/transformers/gptj/tokenizer.py new file mode 100644 index 000000000000..f3f3993898b5 --- /dev/null +++ b/paddlenlp/transformers/gptj/tokenizer.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .. import GPTTokenizer + +__all__ = ['GPTJTokenizer'] + + +class GPTJTokenizer(GPTTokenizer): + + resource_files_names = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt" + } + pretrained_resource_files_map = {"vocab_file": {}, "merges_file": {}} + pretrained_init_configuration = {} + + def __init__(self, + vocab_file, + merges_file, + errors='replace', + max_len=None, + pad_token='<|endoftext|>', + eos_token='<|endoftext|>', + unk_token='<|endoftext|>', + eol_token='\u010a', + **kwargs): + super().__init__(vocab_file=vocab_file, + merges_file=merges_file, + errors=errors, + max_len=max_len, + pad_token=pad_token, + eos_token=eos_token, + unk_token=unk_token, + eol_token=eol_token, + **kwargs)
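
The GPT-J modeling and tokenizer files above are the Python side of the FasterTransformer patches earlier in this diff: `GPTJForCausalLM.prepare_faster_entry` hands generation over to `FasterGPTJ`, and the patched decoder threads `rotary_embedding_dim` down to the attention kernels. The sketch below is a minimal end-to-end usage example, not part of the patch; it assumes the `EleutherAI/gpt-j-6B` weights referenced in the new docstrings actually resolve for `from_pretrained` (the pretrained resource maps added here are still empty, so a local checkpoint directory may be needed), and that the installed PaddleNLP exposes the `use_faster` and `use_fp16_decoding` switches on `generate()`.

```python
import paddle
from paddlenlp.transformers import GPTJForCausalLM, GPTJTokenizer

# Assumption: this pretrained name resolves to real weights; the resource maps
# added in this patch are empty, so a local checkpoint directory may be needed.
model_name = "EleutherAI/gpt-j-6B"
tokenizer = GPTJTokenizer.from_pretrained(model_name)
model = GPTJForCausalLM.from_pretrained(model_name)
model.eval()

inputs = tokenizer("def hello_world():", return_token_type_ids=False)
inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

# Sampling decode; beam search would be rejected by prepare_faster_entry above.
# use_faster / use_fp16_decoding are assumed to dispatch into FasterGPTJ.
output_ids, _ = model.generate(inputs["input_ids"],
                               max_length=64,
                               decode_strategy="sampling",
                               top_k=5,
                               use_faster=True,
                               use_fp16_decoding=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

`decode_strategy="sampling"` is deliberate: `prepare_faster_entry` for both CodeGen and GPT-J raises an `AttributeError` on `beam_search`, so only sampling-style decoding takes the faster path.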