Skip to content

Commit

Permalink
[Bugfix] Fix tensor parallel for qwen2 classification model (vllm-pro…
Browse files Browse the repository at this point in the history
…ject#10297)

Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
  • Loading branch information
Isotr0py authored and sumitd2 committed Nov 14, 2024
1 parent 36977a1 commit 633ebd9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tests/models/embedding/language/test_cls_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def test_classification_models(
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)

with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)

print(hf_outputs, vllm_outputs)

# check logits difference
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/qwen2_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

# hidden_states from Qwen2Model has been reduced,
# the input of score layer is not parallelized.
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config)
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(prefix, "score"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
Expand Down

0 comments on commit 633ebd9

Please sign in to comment.