[lmi][vllm][trtllm] add support for generation parameters from generation_config.json
siddvenk committed Jan 24, 2025
1 parent 4a87746 commit 40fe0eb
Showing 3 changed files with 45 additions and 24 deletions.
@@ -13,14 +13,14 @@
import logging
import os
from typing import List, Optional
from collections import OrderedDict, defaultdict
from collections import OrderedDict

from lmi_dist.api import Request, RequestParams
from lmi_dist.arg_utils import VllmEngineArgs
from lmi_dist.init_engine import engine_from_args
from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
from vllm import SamplingParams
from vllm.utils import AtomicCounter
from vllm.engine.llm_engine import _load_generation_config_dict

from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
@@ -96,22 +96,15 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
kwargs = {}
logging.info(f"engine_args: {engine_args}, kwargs: {kwargs}")

if self.lmi_dist_config.max_rolling_batch_prefill_tokens is None:
logging.warning(
"djl-serving/lmi has changed the default behavior for max_rolling_batch_prefill_tokens in 0.30.0 (lmi v12). "
"Previously, when max_rolling_batch_prefill_tokens was unset, djl-serving would use a warmup prefill limit of 4096 tokens. "
"This behavior differs from vLLM's default behavior, which (essentially) defaults to max_model_len. As a result of this change, "
"model deployments that worked previously may fail due to higher memory requirements at model loading time for the warmup phase. "
"For more information on this change, and guidance on what configurations to set, please see "
"https://github.com/deepjavalibrary/djl-serving/tree/master/serving/docs/lmi/announcements/breaking_changes.md"
)
self.engine = engine_from_args(engine_args, **kwargs)
self.request_cache = OrderedDict()
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.lmi_dist_config.tokenizer_mode == 'mistral'
self.is_t5_model = isinstance(self.engine.preprocessor,
Seq2SeqPreprocessor)
self.default_generation_params = _load_generation_config_dict(
self.engine.model_config)

def reset(self) -> None:
"""
@@ -139,11 +132,6 @@ def get_huggingface_model_config(self):
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None

def get_huggingface_model_config(self):
# TODO: this is a hack right now to get the model config from the engine. We should expose this as
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None

def translate_lmi_dist_params(self, parameters: dict):
"""
Helper function to convert DJL Serving parameter names to parameter names
@@ -154,6 +142,11 @@ def translate_lmi_dist_params(self, parameters: dict):
:return: The same parameters dict, but with lmi-dist style parameter names.
"""
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
# when a default generation_config.json is provided, we still respect any overrides
# sent directly from the request
for k, v in self.default_generation_params.items():
if k not in parameters and k in LMI_DIST_GENERATION_PARAMS:
parameters[k] = v
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample is False:
parameters["temperature"] = 0.0
@@ -16,7 +16,24 @@
from djl_python.request import Request
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
from djl_python.request_io import Token
from typing import List
from typing import List, Optional
from transformers import GenerationConfig


# https://github.com/vllm-project/vllm/blame/3132a933b65d8ed3383e082264c682940d92d803/vllm/config.py#L873
def try_get_generation_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
) -> dict:
try:
return GenerationConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
).to_diff_dict()
except OSError: # Not found
return {}

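This helper mirrors the vLLM utility linked above: GenerationConfig.to_diff_dict() returns only the values that differ from the transformers defaults, and a repository without a generation_config.json raises OSError, which is swallowed so the batch handler falls back to an empty dict. A usage sketch with a hypothetical model id:

# "my-org/my-model" is a placeholder; pass the same id/path used to load the model.
defaults = try_get_generation_config("my-org/my-model", trust_remote_code=False)
if not defaults:
    print("no generation_config.json found; requests rely on built-in defaults")
else:
    for name, value in defaults.items():
        print(f"generation_config.json default: {name} = {value}")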

class TRTLLMRollingBatch(RollingBatch):
@@ -41,6 +58,9 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.model = tensorrt_llm_toolkit.init_inference(
model_id_or_path, **properties)
self.request_cache = {}
self.default_generation_params = try_get_generation_config(
self.model.model.model_path, configs.trust_remote_code,
configs.revision)

def get_tokenizer(self):
return self.model.tokenizer
@@ -63,6 +83,11 @@ def translate_triton_params(self, parameters: dict) -> dict:
:return: The same parameters dict, but with TensorRT-LLM style parameter names.
"""
# when a default generation_config.json is provided, we still respect any overrides
# sent directly from the request
for k, v in self.default_generation_params.items():
if k not in parameters:
parameters[k] = v
if "request_output_len" not in parameters:
parameters["request_output_len"] = parameters.pop(
"max_new_tokens", 30)
@@ -72,7 +97,8 @@ def translate_triton_params(self, parameters: dict) -> dict:
parameters["runtime_top_p"] = parameters.pop("top_p")
if "seed" in parameters:
parameters["random_seed"] = int(parameters.pop("seed"))
if parameters.pop("do_sample", False):
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample:
parameters["runtime_top_k"] = parameters.get("runtime_top_k", 5)
parameters["runtime_top_p"] = parameters.get("runtime_top_p", 0.85)
parameters["temperature"] = parameters.get("temperature", 0.8)
@@ -10,8 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict, defaultdict
from collections import OrderedDict

from vllm import LLMEngine, SamplingParams
from vllm.utils import random_uuid, AtomicCounter
@@ -24,10 +23,7 @@
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List, Optional

# FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
SamplingParams(), "__struct_fields__") else set(
SamplingParams().__dict__.keys())
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__)

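With the minimum supported vLLM now past 0.6.0, SamplingParams exposes __struct_fields__ (it is a msgspec Struct), so the fallback on __dict__ is no longer needed; that field set is what gates which generation_config.json entries get forwarded. A quick check, assuming a vLLM installation, with supported as an illustrative local name:

from vllm import SamplingParams

supported = set(SamplingParams().__struct_fields__)
print("temperature" in supported)            # True  -> forwarded from generation_config.json
print("transformers_version" in supported)   # False -> skipped by the merge below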

class VLLMRollingBatch(RollingBatch):
@@ -54,6 +50,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
self.default_generation_params = self.engine.generation_config_fields

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
@@ -85,6 +82,11 @@ def translate_vllm_params(self, parameters: dict) -> dict:
:return: The same parameters dict, but with VLLM style parameter names.
"""
# when a default generation_config.json is provided, we still respect any overrides
# sent directly from the request
for k, v in self.default_generation_params.items():
if k not in parameters and k in VLLM_GENERATION_PARAMS:
parameters[k] = v
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample is False:

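Putting the vLLM path together: self.engine.generation_config_fields holds the model's generation_config.json values, and translate_vllm_params applies them only when the request did not set the key and the name is a valid SamplingParams field. A worked example under those assumptions, with illustrative values:

config_defaults = {"temperature": 0.6, "top_p": 0.9}   # from generation_config.json
request = {"max_new_tokens": 32, "temperature": 1.0}   # sent by the client

for key, value in config_defaults.items():
    if key not in request and key in {"temperature", "top_p"}:  # stand-in for VLLM_GENERATION_PARAMS
        request[key] = value
request["max_tokens"] = request.pop("max_new_tokens", 30)

print(request)  # {'temperature': 1.0, 'top_p': 0.9, 'max_tokens': 32}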