Skip to content

Commit

Permalink
[Bugfix] Fix edge cases for MistralTokenizer (vllm-project#9625)
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Co-authored-by: Prashant Gupta <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
  • Loading branch information
3 people authored and sumitd2 committed Nov 14, 2024
1 parent f62b1f5 commit b119bc4
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 39 deletions.
80 changes: 63 additions & 17 deletions tests/tokenization/test_detokenize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, List, Optional

import pytest
from transformers import AutoTokenizer
Expand All @@ -7,11 +7,17 @@
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

TRUTH = [
"Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa
"我很感谢你的热情"
"我很感谢你的热情",
# Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
# for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်",
]
TOKENIZERS = [
"facebook/opt-125m",
Expand All @@ -24,6 +30,7 @@
"tiiuae/falcon-7b",
"meta-llama/Llama-2-7b-hf",
"codellama/CodeLlama-7b-hf",
"mistralai/Pixtral-12B-2409",
]


Expand All @@ -49,15 +56,55 @@ def _run_incremental_decode(tokenizer, all_input_ids,
return decoded_text


@pytest.fixture
def tokenizer(tokenizer_name):
return (MistralTokenizer.from_pretrained(tokenizer_name)
if "mistral" in tokenizer_name else
AutoTokenizer.from_pretrained(tokenizer_name))


@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
"truth",
[
# Burmese text triggers an edge-case where tokens may map to bytes with
# incomplete UTF-8 characters
"ပုံပြင်လေးပြောပြပါ",
# Using "URGENCY" since "CY" has token id 130282
"URGENCY🌶️",
])
def test_mistral_edge_case(tokenizer, truth):
"""Test for a specific edge cases with V3-Tekken MistralTokenizer.
See https://github.com/vllm-project/vllm/pull/9625
"""
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

decoded_text = _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
assert decoded_text == truth


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
if "mistral" in tokenizer_name:
yield (
bool(True) if request.param else
pytest.skip("mistral doesn't support skip_special_tokens=False"))
else:
yield bool(True) if request.param else bool(False)


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, with_prompt,
skip_special_tokens):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
all_input_ids = prompt_input_ids + generated_input_ids
Expand All @@ -68,7 +115,7 @@ def test_decode_streaming(tokenizer_id, truth, with_prompt,
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
Expand Down Expand Up @@ -98,7 +145,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
enable_lora=False,
max_num_seqs=100,
max_input_length=None,
tokenizer_mode="auto",
tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
trust_remote_code=False,
revision=None,
)
Expand All @@ -113,9 +160,8 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:

@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer_name: str) -> List[int]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
tokenizer) -> List[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids


Expand Down Expand Up @@ -150,7 +196,7 @@ def create_dummy_prompt_logprobs(

@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
Expand Down Expand Up @@ -208,9 +254,9 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],

# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenzier = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
text = text_full[len(text_first):]

# Text for logprobs for the chosen token should be the same as the
Expand Down
64 changes: 42 additions & 22 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)

from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam

logger = init_logger(__name__)


@dataclass
class Encoding:
Expand Down Expand Up @@ -72,20 +76,21 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
# Make sure special tokens will not raise
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE

self._vocab = {
token: idx
for idx, token in enumerate(tokenizer_.vocab())
}
elif isinstance(tokenizer_, SentencePieceTokenizer):
self._vocab = {
token: idx
for idx, token in enumerate(tokenizer_.vocab())
}
pass
else:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

self._vocab = tokenizer_.vocab()
# Convert to a Dict[str, int] to match protocol, but this is a lossy
# conversion. There may be multiple token ids that decode to the same
# string due to partial UTF-8 byte sequences being converted to �
self._vocab_dict = {
token: idx
for idx, token in enumerate(self._vocab)
}
self.tokenizer = tokenizer_
self._max_token_id = max(self._vocab.values())
self._max_token_id = self.vocab_size - 1

@classmethod
def from_pretrained(cls,
Expand Down Expand Up @@ -182,7 +187,9 @@ def __call__(
return Encoding(input_ids=input_ids)

def get_vocab(self) -> Dict[str, int]:
return self._vocab
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
return self._vocab_dict

def get_added_vocab(self) -> Dict[str, int]:
# Mistral tokenizers have no added vocabulary
Expand Down Expand Up @@ -220,14 +227,20 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
if any(isinstance(t, bytes) for t in tokens):
# we need to encode and decode all tokens again
shift = self.tokenizer.num_special_tokens
byte_tokens = [
t.encode("utf-8") if not isinstance(t, bytes) else t
for t in tokens
]
ids = [
self.tokenizer._tekken_token2id_nospecial[t] + shift
for t in byte_tokens
]

def _token_to_id(t: str):
t_bytes = t.encode("utf-8") \
if not isinstance(t, bytes) else t
try:
return shift + \
self.tokenizer._tekken_token2id_nospecial[t_bytes]
except KeyError:
logger.warning(
"Failed to convert token %s to id,"
" replacing with <unk>", t_bytes)
return self.tokenizer.unk_id

ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids)
else:
decoded = "".join(tokens)
Expand All @@ -236,7 +249,13 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:

return decoded

def decode(self, ids: Union[List[int], int]) -> str:
def decode(self,
ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str:
assert (
skip_special_tokens
), "Skipping special tokens is not supported for Mistral tokenizers."

if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids)
Expand All @@ -257,10 +276,11 @@ def convert_ids_to_tokens(

tokens = [self.tokenizer.id_to_piece(id) for id in ids]

if any(t.strip() == "�" for t in tokens):
# if any stripped decoded token is undefined
# because it's invalid unicode then pass bytes
if any("�" in t for t in tokens):
# if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

return tokens

0 comments on commit b119bc4

Please sign in to comment.