diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py index bad3cc8eb6..ee7be75b7c 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py @@ -91,47 +91,27 @@ def update_request_cache_with_output(request_cache: OrderedDict, def update_multiple_sequences(cache, request_output, vllm_request_output): for completion_output in vllm_request_output.outputs: sequence_index = completion_output.index - if f"sequence_index_{sequence_index}" not in cache: - cache[f"sequence_index_{sequence_index}"] = { - "curr_length": 0, - "num_generated_tokens": 0 - } if sequence_index not in request_output.sequences: request_output.sequences[sequence_index] = Sequence() - # set token of the sequence - # previous length of token ids generated - prev_len = cache[f"sequence_index_{sequence_index}"][ - 'num_generated_tokens'] - # curr length of the token ids generated so far - cur_len = len(completion_output.token_ids) - cache[f"sequence_index_{sequence_index}"][ - "num_generated_tokens"] = cur_len - # get the newly generated token_ids - new_token_ids = completion_output.token_ids[ - prev_len: - cur_len] if prev_len < cur_len else completion_output.token_ids + new_token_ids = completion_output.token_ids # get the newly generated token texts for speculative decoding output_token_texts = [] if hasattr(completion_output, "output_token_texts"): - output_token_texts = completion_output.output_token_texts[ - prev_len: - cur_len] if prev_len < cur_len else completion_output.output_token_texts + output_token_texts = completion_output.output_token_texts top_tokens = [] token_texts = [] # calculate log probs and token_texts if completion_output.logprobs: - new_logprobs_list = completion_output.logprobs[ - prev_len: - cur_len] if prev_len < cur_len else completion_output.logprobs 
new_logprobs = [] - for token_id, logprobs in zip(new_token_ids, new_logprobs_list): + for token_id, logprobs in zip(new_token_ids, + completion_output.logprobs): new_logprobs.append(logprobs[token_id].logprob) decoded_token = logprobs[token_id].decoded_token if logprobs[ token_id].decoded_token else "" @@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): Token(id=token_id_key, text=logprob.decoded_token, log_prob=logprob.logprob)) - elif new_token_ids: # TODO: Test and remove this. logprobs is always set 1. This case should never happen. new_logprobs = [None] * len(new_token_ids) - curr_length = cache[f"sequence_index_{sequence_index}"][ - "curr_length"] - token_texts.append(completion_output.text[curr_length:]) + token_texts.append(completion_output.text) if not output_token_texts: if len(token_texts) != len(new_token_ids): @@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): request_output.sequences[sequence_index].set_next_top_tokens( top_tokens) - cache[f"sequence_index_{sequence_index}"]["curr_length"] = len( - completion_output.text) - def get_speculative_decoding_metrics_record( completion_output: CompletionOutput, diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 66abbf811e..6ce21a9b9d 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -13,6 +13,7 @@ from collections import OrderedDict, defaultdict from vllm import LLMEngine, SamplingParams +from vllm.sampling_params import RequestOutputKind from vllm.utils import random_uuid, AtomicCounter from djl_python.request import Request @@ -128,7 +129,8 @@ def inference(self, new_requests: List[Request]) -> List: request_id = random_uuid() prompt_inputs = get_prompt_inputs(request) params = 
self.translate_vllm_params(request.parameters) - sampling_params = SamplingParams(**params) + sampling_params = SamplingParams( + output_kind=RequestOutputKind.DELTA, **params) request_params = dict() if request.adapter is not None: adapter_name = request.adapter.get_property("name")