
Commit

[text] unit pass
Mddct committed Nov 30, 2023
1 parent 60008ce commit e3d5c6e
Showing 3 changed files with 107 additions and 14 deletions.
82 changes: 82 additions & 0 deletions test/wenet/text/test_hugging_face_tokenizer.py
@@ -0,0 +1,82 @@
+from logging import exception
+import os
+import pytest
+
+from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer
+
+try:
+    import transformers # noqa
+except ImportError:
+    os.system('pip install --no-input transformers')
+    import transformers # noqa
+
+
+@pytest.fixture(params=["bert-base-cased"])
+def hugging_face_tokenizer(request):
+    return HuggingFaceTokenizer(request.param)
+
+
+def test_text2tokens(hugging_face_tokenizer: HuggingFaceTokenizer):
+    tokenizer = hugging_face_tokenizer
+    text = "hello wenet very cool!"
+    expected = ['hello', 'we', '##net', 'very', 'cool', '!']
+    assert all(h == r for h, r in zip(tokenizer.text2tokens(text), expected))
+
+
+def test_tokens2text(hugging_face_tokenizer: HuggingFaceTokenizer):
+    tokenizer = hugging_face_tokenizer
+    inputs = ['hello', 'we', '##net', 'very', 'cool', '!']
+    expected = "hello wenet very cool!"
+
+    result = tokenizer.tokens2text(inputs)
+    assert result == expected
+
+
+def test_tokens2ids(hugging_face_tokenizer: HuggingFaceTokenizer):
+    tokenizer = hugging_face_tokenizer
+    inputs = ['hello', 'we', '##net', 'very', 'cool', '!']
+    expected = [19082, 1195, 6097, 1304, 4348, 106]
+    tokens = tokenizer.tokens2ids(inputs)
+    assert len(tokens) == len(expected)
+    assert all(h == r for (h, r) in zip(tokens, expected))
+
+
+def test_ids2tokens(hugging_face_tokenizer: HuggingFaceTokenizer):
+    tokenizer = hugging_face_tokenizer
+    ids = [19082, 1195, 6097, 1304, 4348, 106]
+    expected = ['hello', 'we', '##net', 'very', 'cool', '!']
+    results = tokenizer.ids2tokens(ids)
+    assert len(results) == len(expected)
+    assert all(h == r for (h, r) in zip(results, expected))
+
+
+def test_tokenize(hugging_face_tokenizer: HuggingFaceTokenizer):
+    tokenizer = hugging_face_tokenizer
+
+    text = "hello wenet very cool!"
+    ids = [19082, 1195, 6097, 1304, 4348, 106]
+    tokens = ['hello', 'we', '##net', 'very', 'cool', '!']
+
+    r_tokens, r_ids = tokenizer.tokenize(text)
+    assert len(r_tokens) == len(tokens)
+    assert all(h == r for (h, r) in zip(r_tokens, tokens))
+    assert len(r_ids) == len(ids)
+    assert all(h == r for (h, r) in zip(r_ids, ids))
+
+
+def test_detokenize(hugging_face_tokenizer: HuggingFaceTokenizer):
+    tokenizer = hugging_face_tokenizer
+    text = "hello wenet very cool!"
+    ids = [19082, 1195, 6097, 1304, 4348, 106]
+    tokens = ['hello', 'we', '##net', 'very', 'cool', '!']
+
+    r_text, r_tokens = tokenizer.detokenize(ids)
+    assert r_text == text
+    assert len(r_tokens) == len(tokens)
+    assert all(h == r for (h, r) in zip(r_tokens, tokens))
+
+
+def test_vocab_size(hugging_face_tokenizer: HuggingFaceTokenizer):
+    assert hugging_face_tokenizer.vocab_size() == 28996
+    assert hugging_face_tokenizer.vocab_size() == len(
+        hugging_face_tokenizer.symbol_table)
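
For reference, a minimal usage sketch of the interface these tests exercise (not part of the commit; it assumes transformers is installed and that "bert-base-cased" can be downloaded, and reuses the token and id values asserted above):

# Minimal sketch: the HuggingFaceTokenizer round trip the tests above assert on.
from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer

tokenizer = HuggingFaceTokenizer("bert-base-cased")
tokens, ids = tokenizer.tokenize("hello wenet very cool!")
# tokens == ['hello', 'we', '##net', 'very', 'cool', '!']
# ids    == [19082, 1195, 6097, 1304, 4348, 106]
text, _ = tokenizer.detokenize(ids)
assert text == "hello wenet very cool!"
assert tokenizer.vocab_size() == len(tokenizer.symbol_table) == 28996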
22 changes: 20 additions & 2 deletions test/wenet/text/test_parallel.py
@@ -3,6 +3,7 @@
 from wenet.text.base_tokenizer import BaseTokenizer
 
 from wenet.text.bpe_tokenizer import BpeTokenizer
+from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer
 from wenet.text.whisper_tokenizer import WhisperTokenizer
 
 
@@ -47,7 +48,7 @@ def test_bpe_tokenzier_parallel():
     symbol_table_path = "test/resources/librispeech.words.txt"
     bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel"
 
-    inputs = ["WENR IS SIMPLE", "GOOD"]
+    inputs = ["WENT IS SIMPLE", "GOOD"]
     tokenizer = BpeTokenizer(bpe_model, symbol_table_path)
     partial_tokenize = partial(consistency, tokenizer)
     with Pool(processes=len(inputs)) as pool:
@@ -63,7 +64,7 @@ def test_bpe_tokenizer_parallel_after_property():
     symbol_table_path = "test/resources/librispeech.words.txt"
     bpe_model = "test/resources/librispeech.train_960_unigram5000.bpemodel"
 
-    inputs = ["WENR IS SIMPLE", "GOOD"]
+    inputs = ["WENT IS SIMPLE", "GOOD"]
     tokenizer = BpeTokenizer(bpe_model, symbol_table_path)
     _ = tokenizer.vocab_size
     _ = tokenizer.symbol_table
@@ -76,3 +77,20 @@ def test_bpe_tokenizer_parallel_after_property():
     results.sort()
 
     assert all(h == r for (h, r) in zip(results, inputs))
+
+
+def test_hugging_face_tokenizer():
+    tokenizer = HuggingFaceTokenizer("bert-base-cased")
+
+    _ = tokenizer.vocab_size
+    _ = tokenizer.symbol_table
+
+    inputs = ["wenet is simple", "good"]
+    partial_tokenize = partial(consistency, tokenizer)
+    with Pool(processes=len(inputs)) as pool:
+        results = pool.map(partial_tokenize, inputs)
+
+    inputs.sort()
+    results.sort()
+
+    assert all(h == r for (h, r) in zip(results, inputs))
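
The consistency helper mapped over the worker pool is defined earlier in test_parallel.py and is not part of this diff; the point of the new test is that the lazily built Hugging Face tokenizer can still be pickled into subprocesses after vocab_size and symbol_table have forced construction. A hypothetical sketch of such a round-trip helper, shown only so the test reads on its own:

# Hypothetical sketch of the round-trip check used with Pool.map above;
# the real `consistency` lives earlier in test_parallel.py and may differ.
from wenet.text.base_tokenizer import BaseTokenizer

def consistency(tokenizer: BaseTokenizer, line: str) -> str:
    _, ids = tokenizer.tokenize(line)    # text -> (tokens, ids)
    text, _ = tokenizer.detokenize(ids)  # ids -> (text, tokens)
    return text                          # equals `line` when the round trip is lossless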
17 changes: 5 additions & 12 deletions wenet/text/hugging_face_tokenizer.py
@@ -1,5 +1,5 @@
 from os import PathLike
-from typing import Union
+from typing import Dict, List, Union
 from wenet.text.base_tokenizer import BaseTokenizer
 
 
@@ -24,17 +24,10 @@ def _build_hugging_face(self):
         from transformers import AutoTokenizer
         if self.tokenizer is None:
             self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-            self.t2i = {}
+            self.t2i = self.tokenizer.vocab
             self.i2t = {}
-            for i in range(self.tokenizer.encoding.n_vocab):
-                unit = str(
-                    self.tokenizer.encoding.decode_single_token_bytes(i))
-                if len(unit) == 0:
-                    unit = str(i)
-                unit = unit.replace(" ", "<space>")
-                # unit = bytes(unit, 'utf-8')
-                self.t2i[unit] = i
-                self.i2t[i] = unit
+            for (i, token) in self.t2i.items():
+                self.i2t[i] = token
             assert len(self.t2i) == len(self.i2t)
 
     def text2tokens(self, line: str) -> List[str]:
@@ -61,5 +54,5 @@ def vocab_size(self) -> int:
 
     @property
     def symbol_table(self) -> Dict[str, int]:
-        self._build_tiktoken()
+        self._build_hugging_face()
         return self.t2i
