
[CI/Build] Expand Model Testing #4510

Closed
5 changes: 4 additions & 1 deletion requirements-dev.txt
@@ -20,7 +20,10 @@ pytest-asyncio
pytest-rerunfailures
pytest-shard
httpx
einops # required for MPT
einops # required for MPT testing
bitsandbytes # required for baichuan testing
transformers_stream_generator # required for qwen1 testing
sentencepiece # required for models with sentencepiece tokenizers
requests
ray
peft
65 changes: 64 additions & 1 deletion tests/conftest.py
@@ -262,7 +262,8 @@ def generate_greedy_logprobs(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                if getattr(self.model.get_output_embeddings(), "bias",
                           None) is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
@@ -272,6 +273,68 @@ def generate_greedy_logprobs(
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:  # needs Dict from typing
        all_logprobs = []
        all_output_ids = []
        all_output_strs = []

        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

            seq_logprobs = []
            for _, hidden_states in enumerate(output.hidden_states):
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if getattr(self.model.get_output_embeddings(), "bias",
                           None) is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)

            # convert to dict
            seq_logprobs_lst = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)

                tok_logprobs_dct = {}
                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                    tok_logprobs_dct[token_id.item()] = logprob.item()

                seq_logprobs_lst.append(tok_logprobs_dct)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_len = seq_ids.shape[0] - input_ids.shape[1]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def __del__(self):
        del self.model
        cleanup()
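For orientation, here is a minimal usage sketch (not part of the diff) showing what the new helper returns, mirroring how the tests below call it through the `hf_runner` fixture; the model name, prompt, and numeric arguments are illustrative placeholders:

# Illustrative sketch only; "gpt2" and the prompt are placeholder choices.
hf_model = hf_runner(model_name="gpt2", dtype="half")
outputs = hf_model.generate_greedy_logprobs_limit(
    ["Hello, my name is"],  # prompts
    8,  # max_tokens
    5,  # num_logprobs
)
# Each element is a (generated token ids, decoded text, logprobs) triple, where
# the logprobs entry holds one {token_id: logprob} dict of the top num_logprobs
# candidates for every generated position.
output_ids, output_str, logprobs = outputs[0]
del hf_model  # the runner's __del__ releases the HF model and calls cleanup()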
File renamed without changes.
92 changes: 92 additions & 0 deletions tests/models/test_models_medium_logprobs.py
@@ -0,0 +1,92 @@
"""Compares the outputs of hf vs vllm for medium sized models.

There is not bitwise correctness for fp16 inference.
As a result, in this test, we just confirm that the top selected tokens
of the models are in the top 3 selections of each other.

Run `pytest tests/models/test_models_medium_logprobs.py` --forked.
"""
import pytest

from tests.models.utils import check_logprobs_close

SKIPPED_MODEL_REASON = {
    "THUDM/chatglm3-6b": "Hf side test broken",
    "allenai/OLMo-1B": "Hf side requirement conflict (req torch 2.2)",
    "xverse/XVERSE-7B": "Hf side test broken"
}

MAX_MODEL_LEN = 1024

MODELS = [
    "baichuan-inc/Baichuan2-7B-Chat",
    "bigscience/bloom-560m",
    "THUDM/chatglm3-6b",
    # command-r -> not tested
    # dbrx -> not tested
    "Deci/DeciLM-7B-instruct",
    "deepseek-ai/deepseek-coder-1.3b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/gemma-1.1-2b-it",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-1.4b",
    "internlm/internlm2-chat-7b",
    # jais -> not tested
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "openbmb/MiniCPM-2B-128k",
    # mixtral -> not tested
    # mixtral-quant -> not tested
    "mosaicml/mpt-7b-instruct",
    "allenai/OLMo-1B",
    "facebook/opt-125m",
    # orion -> not tested
    "microsoft/phi-2",
    "Qwen/Qwen-1_8B",
    "Qwen/Qwen1.5-1.8B",
    # qwen2 moe -> not tested
    "stabilityai/stablelm-2-1_6b-chat",
    "bigcode/starcoder2-3b",
    "xverse/XVERSE-7B",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    hf_runner,
    example_prompts,
    model,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    # Skip if explicitly skipped.
    if model in SKIPPED_MODEL_REASON:
        pytest.skip(reason=SKIPPED_MODEL_REASON[model])
    # Run HF.
    hf_model = hf_runner(model_name=model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy_logprobs_limit(
        example_prompts, max_tokens, num_logprobs)
    del hf_model

    # Run vLLM.
    vllm_model = vllm_runner(model_name=model,
                             enforce_eager=True,
                             dtype=dtype,
                             max_model_len=MAX_MODEL_LEN)
    vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)
    del vllm_model

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
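
`check_logprobs_close` is imported from `tests.models.utils`, whose body is not part of this diff. As a rough sketch (an assumption, not the actual implementation) of the containment check described in the docstring above, it could be written as follows, taking the (token ids, text, top-k logprob dicts) triples produced by the runners:

from typing import Dict, List, Tuple

def check_logprobs_close_sketch(
    outputs_0_lst: List[Tuple[List[int], str, List[Dict[int, float]]]],
    outputs_1_lst: List[Tuple[List[int], str, List[Dict[int, float]]]],
    name_0: str,
    name_1: str,
) -> None:
    """Wherever the two models pick different tokens, require that each model's
    pick still appears in the other's top-k logprobs, then stop comparing that
    prompt, since greedy sequences may legitimately diverge after a mismatch."""
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1
        for idx, (tok_0, tok_1) in enumerate(zip(output_ids_0, output_ids_1)):
            if tok_0 == tok_1:
                continue
            assert tok_0 in logprobs_1[idx], (
                f"{name_0} token not in {name_1} top-k at prompt {prompt_idx}, "
                f"position {idx}: {output_str_0!r} vs {output_str_1!r}")
            assert tok_1 in logprobs_0[idx], (
                f"{name_1} token not in {name_0} top-k at prompt {prompt_idx}, "
                f"position {idx}: {output_str_1!r} vs {output_str_0!r}")
            break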
@@ -7,16 +7,43 @@
"""
import pytest

SKIPPED_MODEL_REASON = {
    "allenai/OLMo-1B": "Hf side requirements",
    "google/gemma-1.1-2b-it": "No bitwise correctness for fp32",
    "Qwen/Qwen-1_8B": "No bitwise correctness for fp32"
}

MODELS = [
    "facebook/opt-125m",
    # baichuan -> tested in medium
    "bigscience/bloom-560m",
    # chatglm -> tested in medium
    # command-r -> not tested
    # dbrx -> not tested
    # decilm -> tested in medium
    "deepseek-ai/deepseek-coder-1.3b-instruct",
    # falcon -> tested in medium
    "google/gemma-1.1-2b-it",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    # gpt-j -> tested in medium
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",  # Testing alibi slopes.
    # internlm2 -> tested in medium
    # jais -> not tested
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "openbmb/MiniCPM-2B-128k",
    # mixtral -> not tested
    # mixtral-quant -> not tested
    # mpt -> tested in medium
    "allenai/OLMo-1B",
    "facebook/opt-125m",
    # orion -> tested in medium
    "microsoft/phi-2",
    "stabilityai/stablelm-3b-4e1t",
    # "allenai/OLMo-1B", # Broken
    "Qwen/Qwen-1_8B",
    "Qwen/Qwen1.5-1.8B",
    # qwen2 moe -> not tested
    "stabilityai/stablelm-2-1_6b-chat",
    "bigcode/starcoder2-3b",
    # xverse -> tested in medium
]
Comment on lines 17 to 48 (Member):

Several of the models left in here are still tested in medium. Obviously this test is stricter, since it measures exact token match at fp32, but do we need to duplicate models like bloom, deepseek, and starcoder?



@@ -31,6 +58,10 @@ def test_models(
    dtype: str,
    max_tokens: int,
) -> None:
    # Skip if explicitly skipped.
    if model in SKIPPED_MODEL_REASON:
        pytest.skip(reason=SKIPPED_MODEL_REASON[model])

    # To pass the small model tests, we need full precision.
    assert dtype == "float"
