Skip to content

Commit

Permalink
[text] add tongyi unit test && change token type to T
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Dec 1, 2023
1 parent 891f8fd commit 77a265f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 22 deletions.
15 changes: 15 additions & 0 deletions test/wenet/text/test_hugging_face_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,18 @@ def test_vocab_size(hugging_face_tokenizer: HuggingFaceTokenizer):
assert hugging_face_tokenizer.vocab_size() == 28996
assert hugging_face_tokenizer.vocab_size() == len(
hugging_face_tokenizer.symbol_table)


def test_tongyi_tokenizer():
# NOTE(Mddct): tongyi need extra matplotlib package
os.system('pip install --no-input matplotlib')
model_dir = 'Qwen/Qwen-Audio-Chat'
tongyi_tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
tokenizer = HuggingFaceTokenizer(model_dir, trust_remote_code=True)
text = "from transformers import AutoModelForCausalLM, AutoTokenizer"
tongyi_result = tongyi_tokenizer.tokenize(text)
result, _ = tokenizer.tokenize(text)

assert len(result) == len(tongyi_result)
assert all(h == r for (h, r) in zip(result, tongyi_result))
18 changes: 10 additions & 8 deletions wenet/text/base_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

T = Union[str, bytes]


class BaseTokenizer(ABC):

def tokenize(self, line: str) -> Tuple[List[str], List[int]]:
def tokenize(self, line: str) -> Tuple[List[T], List[int]]:
tokens = self.text2tokens(line)
ids = self.tokens2ids(tokens)
return tokens, ids

def detokenize(self, ids: List[int]) -> Tuple[str, List[str]]:
def detokenize(self, ids: List[int]) -> Tuple[str, List[T]]:
tokens = self.ids2tokens(ids)
text = self.tokens2text(tokens)
return text, tokens

@abstractmethod
def text2tokens(self, line: str) -> List[str]:
def text2tokens(self, line: str) -> List[T]:
raise NotImplementedError("abstract method")

@abstractmethod
def tokens2text(self, tokens: List[str]) -> str:
def tokens2text(self, tokens: List[T]) -> str:
raise NotImplementedError("abstract method")

@abstractmethod
def tokens2ids(self, tokens: List[str]) -> List[int]:
def tokens2ids(self, tokens: List[T]) -> List[int]:
raise NotImplementedError("abstract method")

@abstractmethod
def ids2tokens(self, ids: List[int]) -> List[str]:
def ids2tokens(self, ids: List[int]) -> List[T]:
raise NotImplementedError("abstract method")

@abstractmethod
def vocab_size(self) -> int:
raise NotImplementedError("abstract method")

@abstractproperty
def symbol_table(self) -> Dict[str, int]:
def symbol_table(self) -> Dict[T, int]:
raise NotImplementedError("abstract method")
28 changes: 14 additions & 14 deletions wenet/text/hugging_face_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from os import PathLike
from typing import Dict, List, Union
from wenet.text.base_tokenizer import BaseTokenizer
from wenet.text.base_tokenizer import BaseTokenizer, T as Type


class HuggingFaceTokenizer(BaseTokenizer):

def __init__(self, model: Union[str, PathLike]) -> None:
def __init__(self, model: Union[str, PathLike], *args, **kwargs) -> None:
# NOTE(Mddct): don't build here, pickle issues
self.model = model
self.tokenizer = None

self.args = args
self.kwargs = kwargs

def __getstate__(self):
state = self.__dict__.copy()
del state['tokenizer']
Expand All @@ -23,27 +26,24 @@ def __setstate__(self, state):
def _build_hugging_face(self):
from transformers import AutoTokenizer
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(self.model)
self.t2i = self.tokenizer.vocab
self.i2t = {}
for (i, token) in self.t2i.items():
self.i2t[i] = token
assert len(self.t2i) == len(self.i2t)

def text2tokens(self, line: str) -> List[str]:
self.tokenizer = AutoTokenizer.from_pretrained(
self.model, **self.kwargs)
self.t2i = self.tokenizer.get_vocab()

def text2tokens(self, line: str) -> List[Type]:
self._build_hugging_face()
return self.tokenizer.tokenize(line)

def tokens2text(self, tokens: List[str]) -> str:
def tokens2text(self, tokens: List[Type]) -> str:
self._build_hugging_face()
ids = self.tokens2ids(tokens)
return self.tokenizer.decode(ids)

def tokens2ids(self, tokens: List[str]) -> List[int]:
def tokens2ids(self, tokens: List[Type]) -> List[int]:
self._build_hugging_face()
return self.tokenizer.convert_tokens_to_ids(tokens)

def ids2tokens(self, ids: List[int]) -> List[str]:
def ids2tokens(self, ids: List[int]) -> List[Type]:
self._build_hugging_face()
return self.tokenizer.convert_ids_to_tokens(ids)

Expand All @@ -53,6 +53,6 @@ def vocab_size(self) -> int:
return len(self.tokenizer)

@property
def symbol_table(self) -> Dict[str, int]:
def symbol_table(self) -> Dict[Type, int]:
self._build_hugging_face()
return self.t2i

0 comments on commit 77a265f

Please sign in to comment.