[lmi][vllm][trtllm] add support for generation parameters from genera… #2685

Draft · wants to merge 1 commit into base: master
@@ -64,6 +64,7 @@ class Properties(BaseModel):
tgi_compat: Optional[bool] = False
bedrock_compat: Optional[bool] = False
enable_lora: Optional[bool] = False
+generation_config: Optional[str] = None

# Spec_dec
draft_model_id: Optional[str] = None
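The new generation_config property is the single opt-in switch this PR introduces: when it is set to 'auto', each backend below loads the model's generation_config.json and treats its values as request defaults. In LMI this field would presumably be surfaced as option.generation_config=auto in serving.properties (or the matching OPTION_ environment variable); only the Python side is visible in this diff. A minimal sketch of the field on a pydantic model (this cut-down Properties is a stand-in for the real class, reduced to the new field):

```python
from typing import Optional
from pydantic import BaseModel


class Properties(BaseModel):
    # stand-in for djl_python's Properties model, showing only the new field
    generation_config: Optional[str] = None


props = Properties(**{"generation_config": "auto"})
# The handlers below only act on the literal value 'auto'; anything else
# (including leaving it unset) means "do not inject model defaults".
print(props.generation_config == "auto")  # True
```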
@@ -13,14 +13,14 @@
import logging
import os
from typing import List, Optional
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict

from lmi_dist.api import Request, RequestParams
from lmi_dist.arg_utils import VllmEngineArgs
from lmi_dist.init_engine import engine_from_args
from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
from vllm import SamplingParams
from vllm.utils import AtomicCounter
+from vllm.engine.llm_engine import _load_generation_config_dict

from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
@@ -96,22 +96,16 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
kwargs = {}
logging.info(f"engine_args: {engine_args}, kwargs: {kwargs}")

-if self.lmi_dist_config.max_rolling_batch_prefill_tokens is None:
-logging.warning(
-"djl-serving/lmi has changed the default behavior for max_rolling_batch_prefill_tokens in 0.30.0 (lmi v12). "
-"Previously, when max_rolling_batch_prefill_tokens was unset, djl-serving would use a warmup prefill limit of 4096 tokens. "
-"This behavior differs from vLLM's default behavior, which (essentially) defaults to max_model_len. As a result of this change, "
-"model deployments that worked previously may fail due to higher memory requirements at model loading time for the warmup phase. "
-"For more information on this change, and guidance on what configurations to set, please see "
-"https://github.com/deepjavalibrary/djl-serving/tree/master/serving/docs/lmi/announcements/breaking_changes.md"
-)
self.engine = engine_from_args(engine_args, **kwargs)
self.request_cache = OrderedDict()
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.lmi_dist_config.tokenizer_mode == 'mistral'
self.is_t5_model = isinstance(self.engine.preprocessor,
Seq2SeqPreprocessor)
+self.default_generation_params = _load_generation_config_dict(
+self.engine.model_config
+) if self.lmi_dist_config.generation_config == 'auto' else {}

def reset(self) -> None:
"""
@@ -139,11 +133,6 @@ def get_huggingface_model_config(self):
# an interface method and retrieve it from there after v12
return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None

-def get_huggingface_model_config(self):
-# TODO: this is a hack right now to get the model config from the engine. We should expose this as
-# an interface method and retrieve it from there after v12
-return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None

def translate_lmi_dist_params(self, parameters: dict):
"""
Helper function to convert DJL Serving parameter names to parameter names
@@ -154,6 +143,11 @@ def translate_lmi_dist_params(self, parameters: dict):
:return: The same parameters dict, but with lmi-dist style parameter names.
"""
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
+# when a default generation_config.json is provided, we still respect any overrides
+# sent directly from the request
+for k, v in self.default_generation_params.items():
+if k not in parameters and k in LMI_DIST_GENERATION_PARAMS:
+parameters[k] = v
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample is False:
parameters["temperature"] = 0.0
@@ -16,7 +16,24 @@
from djl_python.request import Request
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
from djl_python.request_io import Token
-from typing import List
+from typing import List, Optional
+from transformers import GenerationConfig


+# https://github.com/vllm-project/vllm/blame/3132a933b65d8ed3383e082264c682940d92d803/vllm/config.py#L873
+def try_get_generation_config(
+model: str,
+trust_remote_code: bool,
+revision: Optional[str] = None,
+) -> dict:
+try:
+return GenerationConfig.from_pretrained(
+model,
+trust_remote_code=trust_remote_code,
+revision=revision,
+).to_diff_dict()
+except OSError: # Not found
+return {}


class TRTLLMRollingBatch(RollingBatch):
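The helper above mirrors the vLLM utility referenced in the link: transformers' GenerationConfig loads generation_config.json from a model id or local path, to_diff_dict() keeps only values that differ from the library defaults, and a missing file surfaces as OSError. A hedged usage sketch ('gpt2' is just an illustrative model id):

```python
from transformers import GenerationConfig

# Load a model's generation_config.json and keep only its non-default values;
# from_pretrained raises OSError when the file or repo cannot be found,
# which is why the helper falls back to {}.
try:
    defaults = GenerationConfig.from_pretrained("gpt2").to_diff_dict()
except OSError:
    defaults = {}
print(defaults)  # may be {} if the model ships no generation_config.json
```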
@@ -41,6 +58,9 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.model = tensorrt_llm_toolkit.init_inference(
model_id_or_path, **properties)
self.request_cache = {}
+self.default_generation_params = try_get_generation_config(
+self.model.model.model_path, configs.trust_remote_code,
+configs.revision) if configs.generation_config == 'auto' else {}

def get_tokenizer(self):
return self.model.tokenizer
@@ -63,6 +83,11 @@ def translate_triton_params(self, parameters: dict) -> dict:

:return: The same parameters dict, but with TensorRT-LLM style parameter names.
"""
+# when a default generation_config.json is provided, we still respect any overrides
+# sent directly from the request
+for k, v in self.default_generation_params.items():
+if k not in parameters:
+parameters[k] = v
if "request_output_len" not in parameters:
parameters["request_output_len"] = parameters.pop(
"max_new_tokens", 30)
@@ -72,7 +97,8 @@ def translate_triton_params(self, parameters: dict) -> dict:
parameters["runtime_top_p"] = parameters.pop("top_p")
if "seed" in parameters:
parameters["random_seed"] = int(parameters.pop("seed"))
-if parameters.pop("do_sample", False):
+do_sample = parameters.pop("do_sample", None)
+if do_sample is not None and do_sample:
parameters["runtime_top_k"] = parameters.get("runtime_top_k", 5)
parameters["runtime_top_p"] = parameters.get("runtime_top_p", 0.85)
parameters["temperature"] = parameters.get("temperature", 0.8)
@@ -10,8 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
-import logging
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict

from vllm import LLMEngine, SamplingParams
from vllm.utils import random_uuid, AtomicCounter
@@ -24,10 +23,7 @@
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List, Optional

-# FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
-VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
-SamplingParams(), "__struct_fields__") else set(
-SamplingParams().__dict__.keys())
+VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__)


class VLLMRollingBatch(RollingBatch):
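Dropping the hasattr fallback assumes every vLLM version still supported exposes SamplingParams as a msgspec Struct, which is what the deleted FIXME anticipated for releases past 0.6.0. A quick check of what the simplified expression produces (assumes a recent vLLM install):

```python
from vllm import SamplingParams

# On msgspec-based SamplingParams, __struct_fields__ lists every settable field,
# so the set doubles as the allow-list used when merging generation_config defaults.
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__)
print("temperature" in VLLM_GENERATION_PARAMS)  # True
```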
@@ -54,6 +50,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
+self.default_generation_params = self.engine.generation_config_fields if self.vllm_configs.generation_config == 'auto' else {}

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
@@ -85,6 +82,11 @@ def translate_vllm_params(self, parameters: dict) -> dict:

:return: The same parameters dict, but with VLLM style parameter names.
"""
+# when a default generation_config.json is provided, we still respect any overrides
+# sent directly from the request
+for k, v in self.default_generation_params.items():
+if k not in parameters and k in VLLM_GENERATION_PARAMS:
+parameters[k] = v
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample is False: