From bf711a297e986a67a63bc2dc859088b64ba462f8 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 14 Nov 2024 10:54:59 +0800 Subject: [PATCH] [Bugfix] Fix tensor parallel for qwen2 classification model (#10297) Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/embedding/language/test_cls_models.py | 6 +++--- vllm/model_executor/models/qwen2_cls.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py index d8ca6d361f0e3..40ee49cf60742 100644 --- a/tests/models/embedding/language/test_cls_models.py +++ b/tests/models/embedding/language/test_cls_models.py @@ -21,14 +21,14 @@ def test_classification_models( model: str, dtype: str, ) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.classify(example_prompts) - print(hf_outputs, vllm_outputs) # check logits difference diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 020af88aadd98..27eb7e8a93975 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -69,9 +69,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # hidden_states from Qwen2Model has been reduced, + # the input of score layer is not parallelized. self.score = RowParallelLinear(config.hidden_size, config.num_labels, - quant_config=quant_config) + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix(prefix, "score")) self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.LAST,