[Core]: Support encode-only models (xlm-roberta, bge-m3, ...) by Workflow Defined Engine #8462

Closed
wants to merge 3 commits
Empty file added demo_temporary/__init__.py
116 changes: 116 additions & 0 deletions demo_temporary/benchmarks/benchmark_bge-m3.py
@@ -0,0 +1,116 @@
import time
import random


def benchmark_hf(args):
    random.seed(args.seed)

    import torch
    from FlagEmbedding import BGEM3FlagModel

    model = BGEM3FlagModel(args.model, use_fp16=True)

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    with torch.no_grad():
        for batchsize in args.batchsize:
            start = time.perf_counter()
            n_step = 0
            for i in range(0, len(requests), batchsize):
                batch = requests[i:i + batchsize]
                output = model.encode(batch, batch_size=batchsize)
                n_step += 1
            end = time.perf_counter()

            elapsed_time = end - start
            delay = elapsed_time / n_step

            print(f"Batchsize {batchsize}, Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
                  f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")


def benchmark_vllm(args):
    random.seed(args.seed)

    import gc
    import torch
    from vllm.wde.entrypoints.llm import LLMEngine
    from vllm.wde.encode_only.arg_utils import EncodeOnlyEngineArgs as EngineArgs

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        device=args.device,
        max_num_seqs=32,
        scheduling=args.scheduling
    )

    engine = LLMEngine.from_engine_args(engine_args)

    for batchsize in args.batchsize:
        # Reconfigure the scheduler so each run uses the target batch size.
        engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

        start = time.perf_counter()
        for request_id, prompt in enumerate(requests):
            engine.add_request(str(request_id), prompt)

        n_step = 0
        while engine.has_unfinished_requests():
            engine.step()
            n_step += 1
        end = time.perf_counter()

        elapsed_time = end - start
        delay = elapsed_time / n_step

        print(f"Batchsize {batchsize}, Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
              f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

    engine.executor.shutdown_execute_loop()
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    from easydict import EasyDict as edict
    args = edict()

    args.input_len = 256
    args.num_prompts = 10000

    args.model = 'BAAI/bge-m3'

    args.trust_remote_code = False
    args.tokenizer = args.model
    args.seed = 0
    args.max_model_len = None
    args.dtype = "half"
    args.device = "cuda"
    args.batchsize = [1, 2, 4, 8, 16, 32, 64]

    from concurrent.futures import ProcessPoolExecutor

    # Each benchmark runs in its own subprocess so GPU state is released
    # between the Hugging Face and vLLM runs.
    def run_hf(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_hf, args)
            f.result()

    run_hf(args)

    def run_vllm(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_vllm, args)
            f.result()

    for scheduling in ["sync", "async", "double_buffer"]:
        print(scheduling)
        args.scheduling = scheduling
        run_vllm(args)
125 changes: 125 additions & 0 deletions demo_temporary/benchmarks/benchmark_xlm-roberta.py
@@ -0,0 +1,125 @@
import time
import random


def benchmark_hf(args):
    random.seed(args.seed)

    import torch
    from transformers import AutoTokenizer, AutoModelForMaskedLM
    from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    model = AutoModelForMaskedLM.from_pretrained(args.model, torch_dtype=torch_dtype).to(args.device)

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    with torch.no_grad():
        for batchsize in args.batchsize:
            start = time.perf_counter()
            n_step = 0
            for i in range(0, len(requests), batchsize):
                batch = requests[i:i + batchsize]
                encoded_input = tokenizer(batch, return_tensors='pt').to(args.device)
                output = model(**encoded_input)
                n_step += 1
            end = time.perf_counter()

            elapsed_time = end - start
            delay = elapsed_time / n_step

            print(f"Batchsize {batchsize}, Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
                  f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")


def benchmark_vllm(args):
    random.seed(args.seed)

    import gc
    import torch
    from vllm.wde.entrypoints.llm import LLMEngine
    from vllm.wde.encode_only.arg_utils import EncodeOnlyEngineArgs as EngineArgs

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        max_num_seqs=32,
        scheduling=args.scheduling
    )

    engine = LLMEngine.from_engine_args(engine_args)

    for batchsize in args.batchsize:
        # Reconfigure the scheduler so each run uses the target batch size.
        engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

        start = time.perf_counter()
        for request_id, prompt in enumerate(requests):
            engine.add_request(str(request_id), prompt)

        n_step = 0
        while engine.has_unfinished_requests():
            engine.step()
            n_step += 1
        end = time.perf_counter()

        elapsed_time = end - start
        delay = elapsed_time / n_step

        print(f"Batchsize {batchsize}, Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
              f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

    engine.executor.shutdown_execute_loop()
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    from easydict import EasyDict as edict
    args = edict()

    args.input_len = 256
    args.num_prompts = 10000

    args.model = 'FacebookAI/xlm-roberta-base'
    # args.model = 'FacebookAI/xlm-roberta-large'
    args.trust_remote_code = False
    args.tokenizer = args.model
    args.seed = 0
    args.quantization = None
    args.quantization_param_path = None
    args.max_model_len = None

    args.dtype = "half"
    args.device = "cuda"
    args.batchsize = [1, 2, 4, 8, 16, 32, 64]

    from concurrent.futures import ProcessPoolExecutor

    # Each benchmark runs in its own subprocess so GPU state is released
    # between the Hugging Face and vLLM runs.
    def run_hf(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_hf, args)
            f.result()

    run_hf(args)

    def run_vllm(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_vllm, args)
            f.result()

    for scheduling in ["sync", "async", "double_buffer"]:
        print(scheduling)
        args.scheduling = scheduling
        run_vllm(args)
16 changes: 16 additions & 0 deletions demo_temporary/examples/offline_inference_bge-m3.py
@@ -0,0 +1,16 @@
from vllm.wde.entrypoints.llm import LLM

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


llm = LLM(model='BAAI/bge-m3')

outputs = llm.encode(prompts)

for output in outputs:
    print(output.outputs.shape)
16 changes: 16 additions & 0 deletions demo_temporary/examples/offline_inference_xlm-roberta.py
@@ -0,0 +1,16 @@


from vllm.wde.entrypoints.llm import LLM

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

llm = LLM(model="FacebookAI/xlm-roberta-base")

outputs = llm.encode(prompts)
for output in outputs:
    print(output.outputs.shape)
117 changes: 117 additions & 0 deletions demo_temporary/profiler/encode_only_async.py
@@ -0,0 +1,117 @@
import time
import random


def patch():
    # Wrap the GPUAsyncExecutor execute loops with torch.profiler so each
    # profiled run is exported as a Chrome trace.
    from vllm.wde.encode_only.executor.gpu_executor import GPUAsyncExecutor

    simple_execute_loop = GPUAsyncExecutor.simple_execute_loop

    def p_execute_loop(self, *args, **kwargs):
        import torch
        with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ]) as prof:
            simple_execute_loop(self, *args, **kwargs)

        prof.export_chrome_trace("simple_execute_loop.json")

    GPUAsyncExecutor.simple_execute_loop = p_execute_loop

    double_buffer_execute_loop = GPUAsyncExecutor.double_buffer_execute_loop

    def p_execute_loop(self, *args, **kwargs):
        import torch
        with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ]) as prof:
            double_buffer_execute_loop(self, *args, **kwargs)
        prof.export_chrome_trace("double_buffer_execute_loop.json")

    GPUAsyncExecutor.double_buffer_execute_loop = p_execute_loop


def benchmark_vllm(args):
    random.seed(args.seed)
    patch()

    from vllm.wde.entrypoints.llm import LLMEngine
    from vllm.wde.encode_only.arg_utils import EncodeOnlyEngineArgs as EngineArgs

    prompt = "if" * args.input_len
    requests = [prompt for _ in range(args.num_prompts)]

    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        max_num_seqs=32,
        scheduling=args.scheduling
    )

    engine = LLMEngine.from_engine_args(engine_args)

    for batchsize in args.batchsize:
        engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize)

        start = time.perf_counter()
        for request_id, prompt in enumerate(requests):
            engine.add_request(str(request_id), prompt)

        n_step = 0
        while engine.has_unfinished_requests():
            engine.step()
            n_step += 1
        end = time.perf_counter()

        elapsed_time = end - start
        delay = elapsed_time / n_step

        print(f"Batchsize {batchsize}, Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
              f"Delay {delay * 1000:0.2f} ms, n_step {n_step}")

    engine.executor.shutdown_execute_loop()


if __name__ == '__main__':
    from easydict import EasyDict as edict
    args = edict()

    args.input_len = 256
    args.num_prompts = 100

    args.model = 'BAAI/bge-m3'

    args.trust_remote_code = False
    args.tokenizer = args.model
    args.seed = 0
    args.quantization = None
    args.quantization_param_path = None
    args.max_model_len = None

    args.dtype = "half"
    args.device = "cuda"
    args.batchsize = [4]

    from concurrent.futures import ProcessPoolExecutor

    def run_vllm(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark_vllm, args)
            f.result()

    for scheduling in ["async", "double_buffer"]:
        print(scheduling)
        args.scheduling = scheduling
        run_vllm(args)
Empty file added tests/wde/__init__.py