diff --git a/vllm/config.py b/vllm/config.py index 27c61d4d50439..0b8a2a27f6d43 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -44,6 +44,9 @@ class ModelConfig: revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -70,6 +73,7 @@ def __init__( dtype: Union[str, torch.dtype], seed: int, revision: Optional[str] = None, + code_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, @@ -84,6 +88,7 @@ def __init__( self.load_format = load_format self.seed = seed self.revision = revision + self.code_revision = code_revision self.tokenizer_revision = tokenizer_revision self.quantization = quantization self.enforce_eager = enforce_eager @@ -103,7 +108,8 @@ def __init__( self.download_dir = model_path self.tokenizer = model_path - self.hf_config = get_config(self.model, trust_remote_code, revision) + self.hf_config = get_config(self.model, trust_remote_code, revision, + code_revision) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d5e63e25d6e85..8ac0157151d8e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -32,6 +32,7 @@ class EngineArgs: max_paddings: int = 256 disable_log_stats: bool = False revision: Optional[str] = None + code_revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False @@ -75,6 +76,13 @@ def add_cli_args( help='the specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') + parser.add_argument( + '--code-revision', + type=str, + default=None, + help='the specific revision to use for the model code on ' + 'Hugging Face Hub. It can be a branch name, a tag name, or a ' + 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', type=str, @@ -279,13 +287,12 @@ def create_engine_configs( ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, DeviceConfig, Optional[LoRAConfig]]: device_config = DeviceConfig(self.device) - model_config = ModelConfig(self.model, self.tokenizer, - self.tokenizer_mode, self.trust_remote_code, - self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, - self.tokenizer_revision, self.max_model_len, - self.quantization, self.enforce_eager, - self.max_context_len_to_capture) + model_config = ModelConfig( + self.model, self.tokenizer, self.tokenizer_mode, + self.trust_remote_code, self.download_dir, self.load_format, + self.dtype, self.seed, self.revision, self.code_revision, + self.tokenizer_revision, self.max_model_len, self.quantization, + self.enforce_eager, self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index b12918e41b32e..491cb4d9a427c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -16,10 +16,14 @@ def get_config(model: str, trust_remote_code: bool, - revision: Optional[str] = None) -> PretrainedConfig: + revision: Optional[str] = None, + code_revision: Optional[str] = None) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision) + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision) except ValueError as e: if (not trust_remote_code and "requires you to execute the configuration file" in str(e)): @@ -33,5 +37,7 @@ def get_config(model: str, raise e if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] - config = config_class.from_pretrained(model, revision=revision) + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) return config