Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleMIX into ops_filter
megemini committed Dec 23, 2024
2 parents cb2c07a + 1f7ab0b commit f6807df
Showing 50 changed files with 6,526 additions and 64 deletions.
54 changes: 54 additions & 0 deletions deploy/qwen2_vl/README.md
@@ -0,0 +1,54 @@
# Qwen2-VL

## 1. Model Introduction

[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) is a large-scale vision-language model. It accepts images, text, bounding boxes, and videos as input, and produces text and bounding boxes as output. This repository provides Paddle versions of the `Qwen2-VL-2B-Instruct` and `Qwen2-VL-7B-Instruct` models.

## 2. Environment Setup
- **python >= 3.10**
- **paddlepaddle-gpu: the develop version is required**
```
# Example installation
python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
```
- **paddlenlp**
```
# Example installation
git submodule update --init --recursive
cd PaddleNLP
git reset --hard e91c2d3d634b12769c30aa419ddf931c20b7ca9f
pip install -e .
cd csrc
python setup_cuda.py install
```

> Note:
* Make sure the dependencies above are installed, otherwise the demo will not run. You also need to install the custom ops under paddlemix/external_ops with `python setup.py install` (see the sketch below); if the ops still cannot be found after installation, additionally set PYTHONPATH.
* flash_attn is enabled by default and requires an A100/A800 or H20 GPU.
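
A minimal sketch of that custom-op install, assuming the commands run from the PaddleMIX repository root (the PYTHONPATH value is a placeholder):

```bash
# Build and install the custom ops shipped under paddlemix/external_ops
cd paddlemix/external_ops
python setup.py install
cd ../..

# If the ops are still not found afterwards, add the repo root to PYTHONPATH (placeholder path)
export PYTHONPATH=/path/to/PaddleMIX:$PYTHONPATH
```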

## 3. High-Performance Inference

In the inference optimization of Qwen2-VL, the vision-model part keeps the network from PaddleMIX, while the language-model part calls PaddleNLP's high-performance qwen2 language model, yielding a high-performance Qwen2-VL inference pipeline (see `deploy/qwen2_vl/single_image_infer.py` below).

### a. High-Performance Inference with Text & Single-Image Input
```bash
python deploy/qwen2_vl/single_image_infer.py \
--model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
--dtype bfloat16 \
--benchmark 1
```

- Performance measured on an NVIDIA A100-SXM4-80GB:


- Qwen2-VL-2B-Instruct

  | Paddle Inference | PyTorch | Paddle dynamic graph |
  | ---------------- | ------- | -------------------- |
  | 1.44 s           | 2.35 s  | 5.215 s              |


- Qwen2-VL-7B-Instruct

  | Paddle Inference | PyTorch | Paddle dynamic graph |
  | ---------------- | ------- | -------------------- |
  | 1.73 s           | 4.4 s   | 6.339 s              |
280 changes: 280 additions & 0 deletions deploy/qwen2_vl/single_image_infer.py
@@ -0,0 +1,280 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from dataclasses import dataclass, field

import numpy as np
import paddle
from paddlenlp.generation import GenerationConfig
from paddlenlp.trainer import PdArgumentParser
from paddlenlp.transformers import (
AutoConfig,
AutoInferenceModelForCausalLM,
Qwen2Tokenizer,
)
from paddlenlp.trl import llm_utils

from paddlemix.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2RotaryEmbedding,
Qwen2VLForConditionalGeneration,
)
from paddlemix.processors.qwen2_vl_processing import (
Qwen2VLImageProcessor,
Qwen2VLProcessor,
process_vision_info,
)

MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
# MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
vl_model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")

# NOTE: (zhoukangkang, changwenbin) Because we only use the vision model here,
# we delete the language model to reduce GPU memory usage.
del vl_model.model
paddle.device.cuda.empty_cache()

image_processor = Qwen2VLImageProcessor()
tokenizer = Qwen2Tokenizer.from_pretrained(MODEL_NAME)
processor = Qwen2VLProcessor(image_processor, tokenizer)

# min_pixels = 256*28*28 # 200704
# max_pixels = 1280*28*28 # 1003520
# processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)

messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "paddlemix/demo_images/examples_image1.jpg",
},
{"type": "text", "text": "Describe this image."},
],
}
]

# Preparation for inference
image_inputs, video_inputs = process_vision_info(messages)

question = "Describe this image."
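# The vision placeholder tokens below mark where the image embeddings are spliced into the Qwen2 chat-template prompt.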
image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n"


@dataclass
class PredictorArgument:
# NOTE: (zhoukangkang, changwenbin)
# These parameters are all copied from https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/predict/predictor.py
# For simplicity and ease of use, only the necessary parameters are retained here.
# If you want to know the exact meaning of these parameters, please refer to the link above.

model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
src_length = 1024
min_length = 2
max_length = 200
top_k = 0
top_p = 0.0
temperature = 0.95
repetition_penalty = 1.0
dtype: str = field(default=None, metadata={"help": "Model dtype"})
decode_strategy = "sampling"
mode = "dynamic"
inference_model = True
quant_type = ""
benchmark: bool = field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
)
use_fake_parameter = False
block_attn = True
block_size = 64
cachekv_int8_type = None
append_attn = True
total_max_length = 4096
speculate_method = None


@dataclass
class ModelArgument:
model_type: str = field(
default=None,
metadata={"help": "the type of the model, which can be one of ['gpt-3', 'ernie-3.5-se', 'llama-img2txt']"},
)


def init_llm_model_inputs(vision_model_inputs, inputs_embeds, arg_config: PredictorArgument):
assert len(inputs_embeds.shape) == 3
batch_size = inputs_embeds.shape[0]

model_inputs = {}
model_inputs["input_ids"] = paddle.zeros(shape=[batch_size, arg_config.total_max_length], dtype="int64")
model_inputs["inputs_embeds"] = inputs_embeds

# Require total_max_length to be a multiple of block_size, so block_nums below is simply
# total_max_length // block_size instead of (total_max_length + block_size - 1) // block_size.
assert arg_config.total_max_length % arg_config.block_size == 0

model_inputs["top_p"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.top_p, dtype="float32")
model_inputs["temperature"] = paddle.full(
shape=[batch_size, 1], fill_value=arg_config.temperature, dtype="float32"
)
model_inputs["eos_token_id"] = paddle.to_tensor(
np.array(llm_utils.get_eos_token_id(tokenizer, generation_config)).reshape(-1, 1).astype("int64")
)
model_inputs["penalty_score"] = paddle.full(
shape=[batch_size, 1], fill_value=arg_config.repetition_penalty, dtype="float32"
)
model_inputs["frequency_score"] = paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32")
model_inputs["presence_score"] = paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32")
model_inputs["min_length"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.min_length, dtype="int64")
model_inputs["max_length"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.max_length, dtype="int64")

position_ids, _ = vl_model.get_rope_index(
config.vision_config["spatial_merge_size"],
config.image_token_id,
config.video_token_id,
config.vision_start_token_id,
vision_model_inputs.get("input_ids"),
vision_model_inputs.get("image_grid_thw"),
vision_model_inputs.get("video_grid_thw", None),
vision_model_inputs.get("attention_mask"),
)
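# Extend the 3-D M-RoPE position ids beyond the prompt so positions exist for up to total_max_length (4096) tokens during decoding.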
position_start = position_ids[0][0][-1].item()
position_end = 4096 - position_ids.shape[-1] + position_start
position_value = (
paddle.arange(position_start, position_end).reshape([1, 1, -1]).expand([position_ids.shape[0], 1, -1])
)
position_ids = paddle.concat([position_ids, position_value], axis=-1)

head_dim = config.hidden_size // config.num_attention_heads
qwen2_Embedding = Qwen2RotaryEmbedding(head_dim, 4096, config.rope_theta)
cos = qwen2_Embedding.cos_cached
sin = qwen2_Embedding.sin_cached

# NOTE: (zhoukangkang, changwenbin) Copied from PaddleMIX/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py,
# for calculating M-ROPE.
cos = cos[position_ids]
sin = sin[position_ids]
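# mrope_section splits the rotary dimensions among the (temporal, height, width) position components;
# it is doubled to cover the full head_dim, and i % 3 picks the matching component for each chunk.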
mrope_section = config.rope_scaling["mrope_section"] * 2
cos = paddle.concat(x=[m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1)
sin = paddle.concat(x=[m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1)

rope_emb = paddle.stack([cos, sin], axis=0)
rope_emb = rope_emb.reshape([rope_emb.shape[0], 1, rope_emb.shape[2], 1, rope_emb.shape[-1]])
model_inputs["rope_emb"] = rope_emb

model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64")
model_inputs["is_block_step"] = paddle.full(shape=[batch_size], fill_value=False, dtype="bool")

cache_kvs_shape = fast_llm_model.get_cache_kvs_shape(fast_llm_model.config, batch_size)
cachekv_dtype = config.dtype if arg_config.cachekv_int8_type is None else "uint8"
model_inputs["cache_kvs"] = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in cache_kvs_shape]

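# Block attention: give each sequence a contiguous run of KV-cache blocks;
# block_tables maps each logical block index to a physical block id.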
block_nums = arg_config.total_max_length // arg_config.block_size
model_inputs["block_tables"] = paddle.arange(block_nums, dtype="int32").tile([batch_size, 1])

seq_lens = inputs_embeds.shape[1]
model_inputs["seq_lens_this_time"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1))
model_inputs["seq_lens_encoder"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1))
model_inputs["seq_lens_decoder"] = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int32")
model_inputs["step_idx"] = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64")
model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
model_inputs["stop_flags"] = paddle.full(shape=[batch_size, 1], fill_value=False, dtype="bool")
model_inputs["stop_nums"] = paddle.full(shape=[1], fill_value=batch_size, dtype="int64")
model_inputs["pre_ids"] = paddle.full(shape=[batch_size, arg_config.max_length], fill_value=-1, dtype="int64")
model_inputs["next_tokens"] = paddle.full(shape=[batch_size, 1], fill_value=-1, dtype="int64")

return model_inputs


parser = PdArgumentParser((PredictorArgument, ModelArgument))
predictor_args, model_args = parser.parse_args_into_dataclasses()

paddle.set_default_dtype(predictor_args.dtype)
config = AutoConfig.from_pretrained(MODEL_NAME)

# NOTE: (changwenbin) This is for using the inference optimization of paddlenlp qwen2.
config.model_type = "qwen2"
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
fast_llm_model = AutoInferenceModelForCausalLM.from_pretrained(
MODEL_NAME,
config=config,
predictor_args=predictor_args,
model_args=model_args,
dtype=predictor_args.dtype,
tensor_parallel_degree=1,
tensor_parallel_rank=0,
)
fast_llm_model.eval()

vl_model.model = fast_llm_model


def run_model():

vision_model_inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pd",
)
inputs_embeds = vl_model.vision_forward(**vision_model_inputs)
llm_model_inputs = init_llm_model_inputs(vision_model_inputs, inputs_embeds, arg_config=predictor_args)
generated_text = ""
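# Decode loop: each generate() call emits the next token; feed it back as input_ids
# and drop inputs_embeds after the prefill step.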
while llm_model_inputs["not_need_stop"]:
generated_ids = fast_llm_model.generate(**llm_model_inputs) # already trimmed in paddle
llm_model_inputs["input_ids"] = generated_ids
llm_model_inputs["inputs_embeds"] = None
new_text_piece = processor.batch_decode(
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
if new_text_piece == "<|im_end|>":
break
generated_text += new_text_piece
return generated_text


if predictor_args.benchmark:
print(f"Benchmarking {MODEL_NAME} ...")
warm_up = 3
repeat_times = 10
sumtime = 0.0
times = repeat_times + warm_up
for i in range(times):
if i >= warm_up:
paddle.device.synchronize()
starttime = datetime.datetime.now()
generated_text = run_model()
if i >= warm_up:
paddle.device.synchronize()
endtime = datetime.datetime.now()
print("Final output_text:\n", generated_text)

if i >= warm_up:
duringtime = endtime - starttime
duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
sumtime += duringtime
print(f"Single {MODEL_NAME} end to end time : ", duringtime, "ms")
inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")

print(f"Single {MODEL_NAME} ave end to end time : ", sumtime / repeat_times, "ms")

else:
generated_text = run_model()
print("Final output_text:\n", generated_text)
