Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix edge cases for MistralTokenizer #9625

Merged
merged 10 commits into from
Nov 1, 2024
80 changes: 63 additions & 17 deletions tests/tokenization/test_detokenize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, List, Optional

import pytest
from transformers import AutoTokenizer
Expand All @@ -7,11 +7,17 @@
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

TRUTH = [
"Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa
"我很感谢你的热情"
"我很感谢你的热情",
# Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg.
# for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်",
]
TOKENIZERS = [
"facebook/opt-125m",
Expand All @@ -24,6 +30,7 @@
"tiiuae/falcon-7b",
"meta-llama/Llama-2-7b-hf",
"codellama/CodeLlama-7b-hf",
"mistralai/Pixtral-12B-2409",
]


Expand All @@ -49,15 +56,55 @@ def _run_incremental_decode(tokenizer, all_input_ids,
return decoded_text


@pytest.fixture
def tokenizer(tokenizer_name):
return (MistralTokenizer.from_pretrained(tokenizer_name)
if "mistral" in tokenizer_name else
AutoTokenizer.from_pretrained(tokenizer_name))


@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
"truth",
[
# Burmese text triggers an edge-case where tokens may map to bytes with
# incomplete UTF-8 characters
"ပုံပြင်လေးပြောပြပါ",
# Using "URGENCY" since "CY" has token id 130282
"URGENCY🌶️",
])
def test_mistral_edge_case(tokenizer, truth):
"""Test for a specific edge cases with V3-Tekken MistralTokenizer.

See https://github.com/vllm-project/vllm/pull/9625
"""
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids

decoded_text = _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
assert decoded_text == truth


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
if "mistral" in tokenizer_name:
yield (
bool(True) if request.param else
pytest.skip("mistral doesn't support skip_special_tokens=False"))
else:
yield bool(True) if request.param else bool(False)


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, with_prompt,
skip_special_tokens):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
all_input_ids = prompt_input_ids + generated_input_ids
Expand All @@ -68,7 +115,7 @@ def test_decode_streaming(tokenizer_id, truth, with_prompt,
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
Expand Down Expand Up @@ -98,7 +145,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
enable_lora=False,
max_num_seqs=100,
max_input_length=None,
tokenizer_mode="auto",
tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
trust_remote_code=False,
revision=None,
)
Expand All @@ -113,9 +160,8 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:

@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer_name: str) -> List[int]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
tokenizer) -> List[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids


Expand Down Expand Up @@ -150,7 +196,7 @@ def create_dummy_prompt_logprobs(

@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
Expand Down Expand Up @@ -208,9 +254,9 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],

# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenzier = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
text = text_full[len(text_first):]

# Text for logprobs for the chosen token should be the same as the
Expand Down
64 changes: 42 additions & 22 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)

from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam

logger = init_logger(__name__)


@dataclass
class Encoding:
Expand Down Expand Up @@ -72,20 +76,21 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
# Make sure special tokens will not raise
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE

self._vocab = {
token: idx
for idx, token in enumerate(tokenizer_.vocab())
}
elif isinstance(tokenizer_, SentencePieceTokenizer):
self._vocab = {
token: idx
for idx, token in enumerate(tokenizer_.vocab())
}
pass
else:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

self._vocab = tokenizer_.vocab()
tjohnson31415 marked this conversation as resolved.
Show resolved Hide resolved
# Convert to a Dict[str, int] to match protocol, but this is a lossy
# conversion. There may be multiple token ids that decode to the same
# string due to partial UTF-8 byte sequences being converted to �
self._vocab_dict = {
token: idx
for idx, token in enumerate(self._vocab)
}
self.tokenizer = tokenizer_
self._max_token_id = max(self._vocab.values())
self._max_token_id = self.vocab_size - 1

@classmethod
def from_pretrained(cls,
Expand Down Expand Up @@ -182,7 +187,9 @@ def __call__(
return Encoding(input_ids=input_ids)

def get_vocab(self) -> Dict[str, int]:
return self._vocab
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
return self._vocab_dict

def get_added_vocab(self) -> Dict[str, int]:
# Mistral tokenizers have no added vocabulary
Expand Down Expand Up @@ -220,14 +227,20 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
if any(isinstance(t, bytes) for t in tokens):
# we need to encode and decode all tokens again
shift = self.tokenizer.num_special_tokens
byte_tokens = [
t.encode("utf-8") if not isinstance(t, bytes) else t
for t in tokens
]
ids = [
self.tokenizer._tekken_token2id_nospecial[t] + shift
for t in byte_tokens
]

def _token_to_id(t: str):
t_bytes = t.encode("utf-8") \
if not isinstance(t, bytes) else t
try:
return shift + \
self.tokenizer._tekken_token2id_nospecial[t_bytes]
except KeyError:
logger.warning(
"Failed to convert token %s to id,"
" replacing with <unk>", t_bytes)
return self.tokenizer.unk_id

ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids)
else:
decoded = "".join(tokens)
Expand All @@ -236,7 +249,13 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:

return decoded

def decode(self, ids: Union[List[int], int]) -> str:
def decode(self,
ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str:
assert (
skip_special_tokens
), "Skipping special tokens is not supported for Mistral tokenizers."

if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids)
Expand All @@ -257,10 +276,11 @@ def convert_ids_to_tokens(

tokens = [self.tokenizer.id_to_piece(id) for id in ids]

if any(t.strip() == "�" for t in tokens):
# if any stripped decoded token is undefined
# because it's invalid unicode then pass bytes
if any("�" in t for t in tokens):
# if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

return tokens