
Commit

[python] Update vllm rolling batch to return only deltas
xyang16 committed Dec 14, 2024
1 parent 96a0efd commit dca8119
Showing 2 changed files with 8 additions and 33 deletions.
@@ -91,47 +91,26 @@ def update_request_cache_with_output(request_cache: OrderedDict,
 
 def update_multiple_sequences(cache, request_output, vllm_request_output):
     for completion_output in vllm_request_output.outputs:
 
         sequence_index = completion_output.index
-        if f"sequence_index_{sequence_index}" not in cache:
-            cache[f"sequence_index_{sequence_index}"] = {
-                "curr_length": 0,
-                "num_generated_tokens": 0
-            }
 
         if sequence_index not in request_output.sequences:
             request_output.sequences[sequence_index] = Sequence()
 
-        # set token of the sequence
-        # previous length of token ids generated
-        prev_len = cache[f"sequence_index_{sequence_index}"][
-            'num_generated_tokens']
-        # curr length of the token ids generated so far
-        cur_len = len(completion_output.token_ids)
-        cache[f"sequence_index_{sequence_index}"][
-            "num_generated_tokens"] = cur_len
-
-        # get the newly generated token_ids
-        new_token_ids = completion_output.token_ids[
-            prev_len:
-            cur_len] if prev_len < cur_len else completion_output.token_ids
+        new_token_ids = completion_output.token_ids
 
         # get the newly generated token texts for speculative decoding
         output_token_texts = []
         if hasattr(completion_output, "output_token_texts"):
-            output_token_texts = completion_output.output_token_texts[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.output_token_texts
+            output_token_texts = completion_output.output_token_texts
 
         top_tokens = []
         token_texts = []
         # calculate log probs and token_texts
         if completion_output.logprobs:
-            new_logprobs_list = completion_output.logprobs[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.logprobs
             new_logprobs = []
-            for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            for token_id, logprobs in zip(new_token_ids,
+                                          completion_output.logprobs):
                 new_logprobs.append(logprobs[token_id].logprob)
                 decoded_token = logprobs[token_id].decoded_token if logprobs[
                     token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
                         Token(id=token_id_key,
                               text=logprob.decoded_token,
                               log_prob=logprob.logprob))
 
         elif new_token_ids:
             # TODO: Test and remove this. logprobs is always set 1. This case should never happen.
             new_logprobs = [None] * len(new_token_ids)
-            curr_length = cache[f"sequence_index_{sequence_index}"][
-                "curr_length"]
-            token_texts.append(completion_output.text[curr_length:])
+            token_texts.append(completion_output.text)
 
         if not output_token_texts:
             if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
         request_output.sequences[sequence_index].set_next_top_tokens(
             top_tokens)
 
-        cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
-            completion_output.text)
-
 
 def get_speculative_decoding_metrics_record(
         completion_output: CompletionOutput,
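Note on the first file's change: the deleted bookkeeping existed because vLLM's default cumulative outputs repeat every token generated so far, so the handler had to remember how much it had already emitted and slice off that prefix. With delta outputs, each CompletionOutput carries only the newly generated tokens. A minimal sketch of the two consumption patterns (not the repository's code; the cache dict and function names here are illustrative):

# Minimal sketch: cumulative vs. delta consumption of a vLLM CompletionOutput.
# `cache` is an illustrative per-sequence dict, not the module's real cache layout.

def consume_cumulative(cache: dict, completion_output):
    # Cumulative mode: token_ids holds *all* tokens generated so far,
    # so we must remember how many we already emitted and slice.
    prev_len = cache.get("num_generated_tokens", 0)
    cur_len = len(completion_output.token_ids)
    cache["num_generated_tokens"] = cur_len
    return completion_output.token_ids[prev_len:cur_len]

def consume_delta(completion_output):
    # Delta mode (RequestOutputKind.DELTA): token_ids holds only the
    # tokens produced since the last step, so no slicing cache is needed.
    return completion_output.token_ids
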
@@ -13,6 +13,7 @@
 from collections import OrderedDict, defaultdict
 
 from vllm import LLMEngine, SamplingParams
+from vllm.sampling_params import RequestOutputKind
 from vllm.utils import random_uuid, AtomicCounter
 
 from djl_python.request import Request
@@ -128,7 +129,8 @@ def inference(self, new_requests: List[Request]) -> List:
             request_id = random_uuid()
             prompt_inputs = get_prompt_inputs(request)
             params = self.translate_vllm_params(request.parameters)
-            sampling_params = SamplingParams(**params)
+            sampling_params = SamplingParams(
+                output_kind=RequestOutputKind.DELTA, **params)
             request_params = dict()
             if request.adapter is not None:
                 adapter_name = request.adapter.get_property("name")
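Note on the second file's change: every request is now opted into delta outputs when its SamplingParams are built. A standalone sketch of the same call, assuming a vLLM version that exposes RequestOutputKind; the params dict below is illustrative rather than the output of translate_vllm_params():

from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

# Illustrative per-request parameters.
params = {"max_tokens": 128, "temperature": 0.7}

# Ask the engine to return only newly generated tokens on each step
# instead of the cumulative output accumulated so far.
sampling_params = SamplingParams(output_kind=RequestOutputKind.DELTA, **params)

Setting output_kind per request keeps the engine-level configuration untouched while ensuring the rolling batch only ever streams deltas downstream.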
