
Commit

[cherry-pick] add qwen2vl video infer benchmark (#952)
nemonameless authored Jan 7, 2025
1 parent 5b208af commit 09b3b44
Showing 6 changed files with 136 additions and 67 deletions.
deploy/qwen2_vl/README.md (20 changes: 10 additions & 10 deletions)
@@ -36,31 +36,31 @@ python setup_cuda.py install

### 3.1. High-performance inference with text & single-image input
```bash
-python deploy/qwen2_vl/single_image_infer.py \
+CUDA_VISIBLE_DEVICES=0 python deploy/qwen2_vl/single_image_infer.py \
    --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
    --dtype bfloat16 \
    --benchmark True \
```

-- Single-image end-to-end latency measured on an NVIDIA A100-SXM4-80GB:
+- Single-image end-to-end latency measured on an NVIDIA A800-80GB:

| model                | Paddle Inference | PyTorch | Paddle dygraph |
| -------------------- | ---------------- | ------- | -------------- |
-| Qwen2-VL-2B-Instruct | 1.44 s  | 2.35 s  | 5.215 s |
-| Qwen2-VL-7B-Instruct | 1.73 s  | 4.4 s   | 6.339 s |
+| Qwen2-VL-2B-Instruct | 1.053 s | 2.086 s | 5.766 s |
+| Qwen2-VL-7B-Instruct | 2.293 s | 3.132 s | 6.221 s |


### 3.2. High-performance inference with text & video input
```bash
-python deploy/qwen2_vl/video_infer.py \
-    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+CUDA_VISIBLE_DEVICES=0 python deploy/qwen2_vl/video_infer.py \
+    --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
    --dtype bfloat16 \
-    --benchmark 1
+    --benchmark True
```

-- Single-image end-to-end latency measured on an NVIDIA A100-SXM4-80GB:
+- Single-video end-to-end latency measured on an NVIDIA A800-80GB:

| model                | Paddle Inference | PyTorch | Paddle dygraph |
| -------------------- | ---------------- | ------- | -------------- |
-| Qwen2-VL-2B-Instruct | 1.503 s | -       | 47.922 s |
-| Qwen2-VL-7B-Instruct | 2.715 s | -       | 33.597 s |
+| Qwen2-VL-2B-Instruct | 2.890 s | 3.143 s | 6.183 s  |
+| Qwen2-VL-7B-Instruct | 2.534 s | 2.715 s | 5.721 s  |
deploy/qwen2_vl/single_image_infer.py (8 changes: 6 additions & 2 deletions)
@@ -267,11 +267,15 @@ def run_model():
        duringtime = endtime - starttime
        duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
        sumtime += duringtime
-        print(f"Single {predictor_args.model_name_or_path} end to end time : ", duringtime, "ms")
+        print(f"Single Image Inference: {predictor_args.model_name_or_path} end-to-end time : ", duringtime, "ms")
        inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
        print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")

-    print(f"Single {predictor_args.model_name_or_path} ave end to end time : ", sumtime / repeat_times, "ms")
+    print(
+        f"Single Image Inference: {predictor_args.model_name_or_path} average end-to-end time : ",
+        sumtime / repeat_times,
+        "ms",
+    )

else:
    generated_text = run_model()
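The timing logic in this hunk (and the identical one in video_infer.py below) follows a warm-up-then-average pattern. A minimal self-contained sketch of that pattern, assuming a `run_model` callable and a `repeat_times` count like the ones in the script; the explicit `paddle.device.cuda.synchronize()` calls are an addition here, on the assumption that queued GPU work should be drained so each reading covers the full generation:

```python
import datetime

import paddle


def benchmark(run_model, repeat_times: int = 10, warmup: int = 2) -> float:
    """Return the average end-to-end latency of run_model() in milliseconds."""
    for _ in range(warmup):
        # Warm-up runs absorb one-off costs (weight loading, CUDA context,
        # kernel autotuning) so they do not skew the average.
        run_model()
    sumtime = 0.0
    for _ in range(repeat_times):
        paddle.device.cuda.synchronize()  # drain queued GPU work before timing
        starttime = datetime.datetime.now()
        run_model()
        paddle.device.cuda.synchronize()  # make sure generation has finished
        endtime = datetime.datetime.now()
        duringtime = endtime - starttime
        sumtime += duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
    return sumtime / repeat_times
```

Millisecond granularity via `datetime` matches the prints above; for sub-millisecond work, `time.perf_counter()` would be the usual choice.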
deploy/qwen2_vl/video_infer.py (8 changes: 6 additions & 2 deletions)
@@ -267,11 +267,15 @@ def run_model():
        duringtime = endtime - starttime
        duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
        sumtime += duringtime
-        print(f"Single {predictor_args.model_name_or_path} end to end time : ", duringtime, "ms")
+        print(f"Single Video Inference: {predictor_args.model_name_or_path} end-to-end time : ", duringtime, "ms")
        inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
        print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")

-    print(f"Single {predictor_args.model_name_or_path} ave end to end time : ", sumtime / repeat_times, "ms")
+    print(
+        f"Single Video Inference: {predictor_args.model_name_or_path} average end-to-end time : ",
+        sumtime / repeat_times,
+        "ms",
+    )

else:
    generated_text = run_model()
paddlemix/examples/qwen2_vl/README.md (14 changes: 7 additions & 7 deletions)
@@ -42,17 +42,17 @@ python -m pip install paddlenlp==3.0.0b3

### a. Single-image prediction
```bash
-python paddlemix/examples/qwen2_vl/single_image_infer.py
+CUDA_VISIBLE_DEVICES=0 python paddlemix/examples/qwen2_vl/single_image_infer.py
```

### b. Multi-image prediction
```bash
-python paddlemix/examples/qwen2_vl/multi_image_infer.py
+CUDA_VISIBLE_DEVICES=0 python paddlemix/examples/qwen2_vl/multi_image_infer.py
```

### c. Video prediction
```bash
-python paddlemix/examples/qwen2_vl/video_infer.py
+CUDA_VISIBLE_DEVICES=0 python paddlemix/examples/qwen2_vl/video_infer.py
```
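The `CUDA_VISIBLE_DEVICES=0` prefix added to each command pins the process to a single GPU: only the listed physical device is visible, and it appears as logical device 0 inside the script. A quick sanity check (a sketch, not a file from the repository):

```python
import os

# CUDA_VISIBLE_DEVICES must be set before the first CUDA call in the process;
# the physical GPU it names is then remapped to logical device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import paddle  # noqa: E402  (imported after setting the environment variable)

print(paddle.device.get_device())         # e.g. "gpu:0"
print(paddle.device.cuda.device_count())  # 1, since only one GPU is exposed
```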

## 4 Model fine-tuning
@@ -106,19 +106,19 @@ sh paddlemix/examples/qwen2_vl/shell/basline_7b_lora_bs32_1e8.sh
Inference works the same as the model prediction in step 3; simply change the `--model_path` argument in `paddlemix/examples/qwen2_vl/single_image_infer.py` to the path of your fine-tuned model.

```bash
-python paddlemix/examples/qwen2_vl/single_image_infer.py
+CUDA_VISIBLE_DEVICES=0 python paddlemix/examples/qwen2_vl/single_image_infer.py
```

### 5 High-performance inference optimization

Results measured after [Paddle high-performance inference optimization](../../../deploy/qwen2_vl/):

-- Single-image end-to-end latency measured on an NVIDIA A100-SXM4-80GB:
+- Single-image end-to-end latency measured on an NVIDIA A800-80GB:

| model                | Paddle Inference | PyTorch | Paddle dygraph |
| -------------------- | ---------------- | ------- | -------------- |
-| Qwen2-VL-2B-Instruct | 1.44 s  | 2.35 s  | 5.215 s |
-| Qwen2-VL-7B-Instruct | 1.73 s  | 4.4 s   | 6.339 s |
+| Qwen2-VL-2B-Instruct | 1.053 s | 2.086 s | 5.766 s |
+| Qwen2-VL-7B-Instruct | 2.293 s | 3.132 s | 6.221 s |



paddlemix/examples/qwen2_vl/single_image_infer.py (8 changes: 4 additions & 4 deletions)
@@ -94,10 +94,10 @@ def main(args):
            if i > 10:
                total += time.time() - start
        print("s/it: ", total / 10)
-        print(f"\nGPU memory_allocated: {paddle.device.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
-        print(f"\nGPU max_memory_allocated: {paddle.device.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
-        print(f"\nGPU memory_reserved: {paddle.device.cuda.memory_reserved() / 1024 ** 3:.2f} GB")
-        print(f"\nGPU max_memory_reserved: {paddle.device.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
+        print(f"GPU memory_allocated: {paddle.device.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
+        print(f"GPU max_memory_allocated: {paddle.device.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
+        print(f"GPU memory_reserved: {paddle.device.cuda.memory_reserved() / 1024 ** 3:.2f} GB")
+        print(f"GPU max_memory_reserved: {paddle.device.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
        print("output_text:\n", output_text)

    else:
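The four counters printed above are easy to conflate, so a small helper can keep the reporting consistent across scripts. A sketch built on the same Paddle APIs used in this diff (`report_point` is an illustrative name):

```python
import paddle


def report_cuda_memory(report_point: str) -> None:
    # memory_allocated counts live tensors; memory_reserved counts the
    # allocator's pool, so reserved >= allocated. The max_* variants are
    # high-water marks over the life of the process.
    gib = 1024**3
    print(f"[{report_point}] GPU memory_allocated: {paddle.device.cuda.memory_allocated() / gib:.2f} GB")
    print(f"[{report_point}] GPU max_memory_allocated: {paddle.device.cuda.max_memory_allocated() / gib:.2f} GB")
    print(f"[{report_point}] GPU memory_reserved: {paddle.device.cuda.memory_reserved() / gib:.2f} GB")
    print(f"[{report_point}] GPU max_memory_reserved: {paddle.device.cuda.max_memory_reserved() / gib:.2f} GB")
```

Calling it before and after `generate` makes regressions in either peak usage or pool growth easy to spot.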
paddlemix/examples/qwen2_vl/video_infer.py (145 changes: 103 additions & 42 deletions)
@@ -12,53 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import argparse

import paddle

from paddlemix.models.qwen2_vl import MIXQwen2Tokenizer
from paddlemix.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from paddlemix.processors.qwen2_vl_processing import (
    Qwen2VLImageProcessor,
    Qwen2VLProcessor,
    process_vision_info,
)
+from paddlemix.utils.log import logger

-MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
-model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")
-
-image_processor = Qwen2VLImageProcessor()
-tokenizer = MIXQwen2Tokenizer.from_pretrained(MODEL_NAME)
-min_pixels = 256 * 28 * 28  # 200704
-max_pixels = 1280 * 28 * 28  # 1003520
-processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)
-
-# Messages containing a video and a text query
-messages = [
-    {
-        "role": "user",
-        "content": [
-            {
-                "type": "video",
-                "video": "paddlemix/demo_images/red-panda.mp4",
-                "max_pixels": 360 * 420,
-                "fps": 1.0,
-            },
-            {"type": "text", "text": "Describe this video."},
-        ],
-    }
-]
-
-image_inputs, video_inputs = process_vision_info(messages)
-question = "Describe this video."
-video_pad_token = "<|vision_start|><|video_pad|><|vision_end|>"
-text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{video_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n"
-
-inputs = processor(
-    text=[text],
-    images=image_inputs,
-    videos=video_inputs,
-    padding=True,
-    return_tensors="pd",
-)
-# Inference: Generation of the output
-generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
-# print("generated_ids:\n", generated_ids)
-output_text = processor.batch_decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
-print("output_text:\n", output_text[0])

+def main(args):
+    paddle.seed(seed=0)
+    compute_dtype = "float16" if args.fp16 else "bfloat16"
+    if "npu" in paddle.get_device():
+        is_bfloat16_supported = True
+    else:
+        is_bfloat16_supported = paddle.amp.is_bfloat16_supported()
+    if compute_dtype == "bfloat16" and not is_bfloat16_supported:
+        logger.warning("bfloat16 is not supported on your device, falling back to float32")
+        compute_dtype = "float32"
+
+    model = Qwen2VLForConditionalGeneration.from_pretrained(args.model_path, dtype=compute_dtype)
+
+    image_processor = Qwen2VLImageProcessor()
+    tokenizer = MIXQwen2Tokenizer.from_pretrained(args.model_path)
+    min_pixels = 256 * 28 * 28  # 200704
+    max_pixels = 1280 * 28 * 28  # 1003520
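+    # Editorial note (assumption): the processor resizes inputs so each
+    # image/frame stays within [min_pixels, max_pixels]; with Qwen2-VL's
+    # 28x28 merged patches this is roughly 256-1280 visual tokens per image.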
+    processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)
+
+    # Messages containing a video and a text query
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "video",
+                    "video": f"{args.video_file}",
"max_pixels": 360 * 420,
"fps": 1.0,
},
{"type": "text", "text": f"{args.question}"},
],
}
]

# Preparation for inference
image_inputs, video_inputs = process_vision_info(messages)

question = messages[0]["content"][1]["text"]
video_pad_token = "<|vision_start|><|video_pad|><|vision_end|>"
text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{video_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n"
text = [text]

inputs = processor(
text=text,
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pd",
)

if args.benchmark:
import time

start = 0.0
total = 0.0
+        # Warm up for the first 10 iterations, then time the last 10 so that
+        # total / 10 below averages exactly 10 measured runs.
+        for i in range(20):
+            if i >= 10:
+                start = time.time()
+            with paddle.no_grad():
+                generated_ids = model.generate(
+                    **inputs, max_new_tokens=args.max_new_tokens, temperature=args.temperature
+                )  # already trimmed in paddle
+            output_text = processor.batch_decode(
+                generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            if i >= 10:
+                total += time.time() - start
+        print("s/it: ", total / 10)
print(f"GPU memory_allocated: {paddle.device.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
print(f"GPU max_memory_allocated: {paddle.device.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
print(f"GPU memory_reserved: {paddle.device.cuda.memory_reserved() / 1024 ** 3:.2f} GB")
print(f"GPU max_memory_reserved: {paddle.device.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
print("output_text:\n", output_text)

else:
# Inference: Generation of the output
generated_ids = model.generate(
**inputs, max_new_tokens=args.max_new_tokens, temperature=args.temperature
) # already trimmed in paddle
output_text = processor.batch_decode(
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("output_text:\n", output_text)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="Qwen/Qwen2-VL-2B-Instruct")
parser.add_argument("--question", type=str, default="Describe this video.")
parser.add_argument("--video_file", type=str, default="paddlemix/demo_images/red-panda.mp4")
parser.add_argument("--temperature", type=float, default=0.01)
parser.add_argument("--max_new_tokens", type=int, default=128)
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--benchmark", action="store_true")
args = parser.parse_args()
main(args)
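With the argparse interface above in place, the README's video benchmark can be driven from the command line, e.g. `CUDA_VISIBLE_DEVICES=0 python paddlemix/examples/qwen2_vl/video_infer.py --benchmark` (add `--fp16` to switch the compute dtype from the bfloat16 default, or `--model_path`/`--video_file` to point at a different checkpoint or clip).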
