Commit
Showing 3 changed files with 107 additions and 14 deletions.
@@ -0,0 +1,82 @@
import os

import pytest

from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer

# The tokenizer wraps Hugging Face `transformers`; install it on the fly
# so the tests can run in a bare environment.
try:
    import transformers  # noqa
except ImportError:
    os.system('pip install --no-input transformers')
    import transformers  # noqa


@pytest.fixture(params=["bert-base-cased"])
def hugging_face_tokenizer(request):
    return HuggingFaceTokenizer(request.param)


def test_text2tokens(hugging_face_tokenizer: HuggingFaceTokenizer):
    tokenizer = hugging_face_tokenizer
    text = "hello wenet very cool!"
    expected = ['hello', 'we', '##net', 'very', 'cool', '!']
    results = tokenizer.text2tokens(text)
    # Check the length explicitly: zip() alone would silently truncate.
    assert len(results) == len(expected)
    assert all(h == r for h, r in zip(results, expected))


def test_tokens2text(hugging_face_tokenizer: HuggingFaceTokenizer):
    tokenizer = hugging_face_tokenizer
    inputs = ['hello', 'we', '##net', 'very', 'cool', '!']
    expected = "hello wenet very cool!"

    result = tokenizer.tokens2text(inputs)
    assert result == expected


def test_tokens2ids(hugging_face_tokenizer: HuggingFaceTokenizer):
    tokenizer = hugging_face_tokenizer
    inputs = ['hello', 'we', '##net', 'very', 'cool', '!']
    expected = [19082, 1195, 6097, 1304, 4348, 106]
    tokens = tokenizer.tokens2ids(inputs)
    assert len(tokens) == len(expected)
    assert all(h == r for (h, r) in zip(tokens, expected))


def test_ids2tokens(hugging_face_tokenizer: HuggingFaceTokenizer):
    tokenizer = hugging_face_tokenizer
    ids = [19082, 1195, 6097, 1304, 4348, 106]
    expected = ['hello', 'we', '##net', 'very', 'cool', '!']
    results = tokenizer.ids2tokens(ids)
    assert len(results) == len(expected)
    assert all(h == r for (h, r) in zip(results, expected))


def test_tokenize(hugging_face_tokenizer: HuggingFaceTokenizer):
    tokenizer = hugging_face_tokenizer

    text = "hello wenet very cool!"
    ids = [19082, 1195, 6097, 1304, 4348, 106]
    tokens = ['hello', 'we', '##net', 'very', 'cool', '!']

    # tokenize() returns the token strings and their ids in one call.
    r_tokens, r_ids = tokenizer.tokenize(text)
    assert len(r_tokens) == len(tokens)
    assert all(h == r for (h, r) in zip(r_tokens, tokens))
    assert len(r_ids) == len(ids)
    assert all(h == r for (h, r) in zip(r_ids, ids))


def test_detokenize(hugging_face_tokenizer: HuggingFaceTokenizer):
    tokenizer = hugging_face_tokenizer
    text = "hello wenet very cool!"
    ids = [19082, 1195, 6097, 1304, 4348, 106]
    tokens = ['hello', 'we', '##net', 'very', 'cool', '!']

    # detokenize() maps ids back to both the text and the token strings.
    r_text, r_tokens = tokenizer.detokenize(ids)
    assert r_text == text
    assert len(r_tokens) == len(tokens)
    assert all(h == r for (h, r) in zip(r_tokens, tokens))


def test_vocab_size(hugging_face_tokenizer: HuggingFaceTokenizer):
    # 28996 is the vocabulary size of bert-base-cased; the symbol table
    # exposed by the tokenizer must agree with it.
    assert hugging_face_tokenizer.vocab_size() == 28996
    assert hugging_face_tokenizer.vocab_size() == len(
        hugging_face_tokenizer.symbol_table)
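Taken together, the tests pin down a round-trip contract: tokenize() and detokenize() are inverses for this vocabulary. A minimal usage sketch, not part of the commit, assuming transformers is installed and the "bert-base-cased" weights can be fetched; only methods exercised by the tests above are used:

# Round-trip sketch for HuggingFaceTokenizer (illustrative, not committed code).
from wenet.text.hugging_face_tokenizer import HuggingFaceTokenizer

tokenizer = HuggingFaceTokenizer("bert-base-cased")
tokens, ids = tokenizer.tokenize("hello wenet very cool!")
# tokens -> ['hello', 'we', '##net', 'very', 'cool', '!']
# ids    -> [19082, 1195, 6097, 1304, 4348, 106]
text, round_trip_tokens = tokenizer.detokenize(ids)
assert text == "hello wenet very cool!"
assert round_trip_tokens == tokens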