This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

noooop committed Aug 18, 2024
1 parent a96d4fe commit fc646ba
Showing 4 changed files with 301 additions and 86 deletions.
137 changes: 137 additions & 0 deletions benchmarks/baseline2.py
@@ -0,0 +1,137 @@

import os
import random
import numpy as np
import time


def benchmark(args):
    random.seed(args.seed)

    os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
    os.environ["VLLM_NO_USAGE_STATS"] = "True"

    import vllm
    from vllm import LLMEngine, EngineArgs, SamplingParams, TextPrompt

    print(vllm.__version__)

    engine_args = EngineArgs(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        enable_prefix_caching=args.enable_prefix_caching,
        download_dir=args.download_dir,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_seqs=args.max_num_seqs,
        distributed_executor_backend=args.distributed_executor_backend,
        disable_log_stats=True
    )
    engine = LLMEngine.from_engine_args(engine_args)

    prompt = "hi" * (args.input_len - 1)
    requests = [(prompt, args.input_len, args.output_len)
                for _ in range(args.num_prompts)]

    start = time.perf_counter()
    for request_id, (prompt, _, output_len) in enumerate(requests):
        inputs = TextPrompt(prompt=prompt)
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        engine.add_request(str(request_id), inputs, sampling_params)

    out = []
    while engine.has_unfinished_requests():
        request_outputs = engine.step()
        out.append((time.perf_counter(), request_outputs))
    end = time.perf_counter()

    timestamp = {}
    for t, rs in out:
        for r in rs:
            request_id = r.request_id
            if request_id not in timestamp:
                timestamp[request_id] = []
            timestamp[request_id].append(t)

    tpot = []
    for v in timestamp.values():
        dd = [v[i] - v[i - 1] for i in range(1, len(v))]
        tpot.extend(dd)

    tpot = np.mean(tpot)
    elapsed_time = end - start

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)

    print(f"Throughput: {len(requests) / elapsed_time:.4f} requests/s, "
          f"{total_num_tokens / elapsed_time:.4f} tokens/s, "
          f"Delay {tpot*1000:0.2f} ms")


if __name__ == '__main__':
    from easydict import EasyDict as edict
    args = edict()

    args.dataset = None
    args.input_len = 512
    args.output_len = 512

    args.model = "Qwen/Qwen2-7B-Instruct"
    args.trust_remote_code = False
    args.tokenizer = args.model
    args.quantization = None
    args.quantization_param_path = None
    args.tensor_parallel_size = 1
    args.seed = 0
    args.n = 1
    args.use_beam_search = False
    args.num_prompts = 1000
    args.dtype = 'auto'
    args.max_model_len = 10000
    args.enforce_eager = True
    args.kv_cache_dtype = "auto"
    args.device = "cuda"
    args.enable_prefix_caching = False
    args.gpu_memory_utilization = 0.9
    args.output_json = None
    args.distributed_executor_backend = None
    args.download_dir = None

    import sys
    from concurrent.futures import ProcessPoolExecutor

    def run(args):
        with ProcessPoolExecutor(1) as executor:
            f = executor.submit(benchmark, args)
            f.result()

    max_num_seqs_list = [256]

    print()
    print("enable_chunked_prefill = True")
    for max_num_seqs in max_num_seqs_list:
        print("max_num_seqs", max_num_seqs)
        args.enable_chunked_prefill = True
        args.max_num_seqs = max_num_seqs
        args.max_num_batched_tokens = args.max_num_seqs
        run(args)
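
The new benchmarks/baseline2.py drives LLMEngine directly: it enqueues num_prompts identical requests, timestamps every engine.step(), and reports overall request and token throughput plus the mean gap between consecutive steps of each request, printed as "Delay" (i.e. time per output token). To sweep several batch sizes or compare chunked prefill on and off, the driver loop at the bottom of the __main__ block could be extended along these lines; the sketch below reuses the script's own args and run(), and the sweep values are hypothetical, not part of the commit.

    # Illustrative extension (not part of this commit): sweep batch sizes
    # and toggle chunked prefill, reusing `args` and `run` defined above.
    for enable_chunked_prefill in [True, False]:
        print()
        print("enable_chunked_prefill =", enable_chunked_prefill)
        for max_num_seqs in [64, 128, 256]:  # hypothetical sweep values
            print("max_num_seqs", max_num_seqs)
            args.enable_chunked_prefill = enable_chunked_prefill
            args.max_num_seqs = max_num_seqs
            # With chunked prefill the per-step token budget can equal the
            # batch size; otherwise leave it as None so the engine picks a
            # default large enough to hold a full prompt.
            args.max_num_batched_tokens = (max_num_seqs
                                           if enable_chunked_prefill else None)
            run(args)
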
2 changes: 1 addition & 1 deletion vllm/core/block_manager_v1.py
@@ -700,5 +700,5 @@ def get_common_computed_block_ids(

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         if self.enable_caching:
-            for seq in seq_group.get_seqs():
+            for seq in seq_group.seqs_dict.values():
                 self.compute_full_blocks_in_seq(seq)
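
The one-line change above replaces seq_group.get_seqs() with a direct read of seq_group.seqs_dict.values() on the prefix-caching path. With no status filter, get_seqs() returns the same sequences that seqs_dict holds, so the loop is behaviorally equivalent; the direct dict read simply skips the accessor (and, in versions where get_seqs() builds a fresh list, the intermediate list as well). Below is a simplified, self-contained sketch of that equivalence; Sequence and SequenceGroupSketch are stand-ins for the real vllm.sequence classes, not the actual source.

from typing import Dict, List, Optional


class Sequence:  # stand-in for vllm.sequence.Sequence
    def __init__(self, seq_id: int, status: str = "RUNNING") -> None:
        self.seq_id = seq_id
        self.status = status


class SequenceGroupSketch:  # stand-in for vllm.sequence.SequenceGroup
    def __init__(self, seqs: List[Sequence]) -> None:
        self.seqs_dict: Dict[int, Sequence] = {s.seq_id: s for s in seqs}

    def get_seqs(self, status: Optional[str] = None) -> List[Sequence]:
        if status is None:
            return list(self.seqs_dict.values())  # extra list allocation
        return [s for s in self.seqs_dict.values() if s.status == status]


group = SequenceGroupSketch([Sequence(0), Sequence(1)])
# Old form: builds a temporary list, then iterates it.
assert [s.seq_id for s in group.get_seqs()] == [0, 1]
# New form: iterates the dict view directly, no intermediate list.
assert [s.seq_id for s in group.seqs_dict.values()] == [0, 1]
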
45 changes: 45 additions & 0 deletions vllm/core/policy.py
@@ -0,0 +1,45 @@
from collections import deque
from typing import Deque

from vllm.sequence import SequenceGroup


class Policy:

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[SequenceGroup],
    ) -> Deque[SequenceGroup]:
        return deque(
            sorted(
                seq_groups,
                key=lambda seq_group: self.get_priority(now, seq_group),
                reverse=True,
            ))


class FCFS(Policy):

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        return now - seq_group.metrics.arrival_time


class PolicyFactory:

    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        return cls._POLICY_REGISTRY[policy_name](**kwargs)
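
The new vllm/core/policy.py adds a minimal scheduling-policy layer: FCFS scores each sequence group by how long it has been waiting (now - arrival_time), and sort_by_priority orders a deque of groups by that score in descending order, so the oldest request comes first. A hypothetical usage sketch follows; the real scheduler passes vllm SequenceGroup objects, while here a SimpleNamespace stub exposing only metrics.arrival_time stands in for them, and the import assumes vllm.core.policy from this tree is on the path.

import time
from collections import deque
from types import SimpleNamespace

from vllm.core.policy import PolicyFactory

policy = PolicyFactory.get_policy("fcfs")

now = time.time()
# Three fake "sequence groups"; req1 has been waiting the longest.
groups = deque(
    SimpleNamespace(name=f"req{i}",
                    metrics=SimpleNamespace(arrival_time=now - age))
    for i, age in enumerate([0.1, 2.0, 0.5]))

ordered = policy.sort_by_priority(now, groups)
# Larger (now - arrival_time) means higher priority, and the sort is
# descending, so the oldest request is scheduled first.
print([g.name for g in ordered])  # -> ['req1', 'req2', 'req0']
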
