rename load_tokenizer and ensure dependencies
guipenedo committed May 22, 2024
commit f35e8e7 (parent 8a7eda5)
Showing 3 changed files with 57 additions and 24 deletions.
src/datatrove/pipeline/filters/gopher_quality_filter.py (2 additions, 2 deletions)
@@ -4,7 +4,7 @@
 from datatrove.pipeline.filters.base_filter import BaseFilter
 from datatrove.pipeline.writers.disk_base import DiskWriter
 from datatrove.utils.text import PUNCTUATION_SET
-from datatrove.utils.word_tokenizers import load_tokenizer
+from datatrove.utils.word_tokenizers import load_word_tokenizer


 STOP_WORDS = ["the", "be", "to", "of", "and", "that", "have", "with"]
@@ -70,7 +70,7 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]:
         """
         text = doc.text
-        tokenizer = load_tokenizer(self.language)
+        tokenizer = load_word_tokenizer(self.language)
         words = tokenizer.word_tokenize(text)
         n_words = len(words)

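The only change in this filter is the call to the renamed helper. For reference, a minimal usage sketch of the renamed entry point, assuming datatrove and its NLTK dependency are installed; the language code and sample text below are illustrative, not taken from the repository:

from datatrove.utils.word_tokenizers import load_word_tokenizer

# Tokenize a snippet the same way GopherQualityFilter now does internally.
tokenizer = load_word_tokenizer("en")  # illustrative language code
words = tokenizer.word_tokenize("A short sample document for word counting.")
n_words = len(words)
print(n_words, words[:5])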
src/datatrove/utils/word_tokenizers.py (50 additions, 17 deletions)
@@ -15,7 +15,7 @@ def simple_span_tokenize(text: str, sents: list[str]) -> Iterator[tuple[int, int]]:
         start_char = text.index(sent, start_index)
         end_char = start_char + len(sent)
         start_index = end_char
-        yield (start_char, end_char)
+        yield start_char, end_char


 class WordTokenizer(ABC):
@@ -67,6 +67,10 @@ class SpaCyTokenizer(WordTokenizer):
     def __init__(self, spacy_language: str, config=None):
         super().__init__()
         check_required_dependencies(f"{spacy_language} word tokenizer", ["spacy"])
+        if spacy_language == "vi":
+            check_required_dependencies(f"{spacy_language} word tokenizer", ["pyvi"])
+        elif spacy_language == "zh":
+            check_required_dependencies(f"{spacy_language} word tokenizer", ["jieba"])
         self.spacy_language = spacy_language
         self.config = config
         self._tokenizer = None
@@ -80,7 +84,7 @@ def tokenizer(self):
             self._tokenizer = spacy.blank(self.spacy_language)
         else:
             self._tokenizer = spacy.blank(self.spacy_language, config=self.config)
-        self.tokenizer.add_pipe("sentencizer")
+        self._tokenizer.add_pipe("sentencizer")
         return self._tokenizer

     def word_tokenize(self, text: str) -> list[str]:
@@ -102,29 +106,47 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:

 class StanzaTokenizer(WordTokenizer):
     def __init__(self, stanza_language: str, **stanza_kwargs):
-        import stanza
-        from stanza.pipeline.core import DownloadMethod
-
-        self._tokenizer = stanza.Pipeline(
-            stanza_language, processors="tokenize", download_method=DownloadMethod.REUSE_RESOURCES, **stanza_kwargs
-        )
+        super().__init__()
+        check_required_dependencies(f"{stanza_language} word tokenizer", ["stanza"])
+        self.stanza_language = stanza_language
+        self.stanza_kwargs = stanza_kwargs
+        self._tokenizer = None
+
+    @property
+    def tokenizer(self):
+        if not self._tokenizer:
+            import stanza
+            from stanza.pipeline.core import DownloadMethod
+
+            self._tokenizer = stanza.Pipeline(
+                self.stanza_language,
+                processors="tokenize",
+                download_method=DownloadMethod.REUSE_RESOURCES,
+                **self.stanza_kwargs,
+            )
+        return self._tokenizer

     def word_tokenize(self, text: str) -> list[str]:
-        doc = self._tokenizer(text)
+        doc = self.tokenizer(text)
         tokens = [token.text for sentence in doc.sentences for token in sentence.tokens]
         return strip_strings(tokens)

     def sent_tokenize(self, text: str) -> list[str]:
-        doc = self._tokenizer(text)
+        doc = self.tokenizer(text)
         sents = [sentence.text for sentence in doc.sentences]
         return strip_strings(sents)

     def span_tokenize(self, text: str) -> list[tuple[int, int]]:
-        doc = self._tokenizer(text)
+        doc = self.tokenizer(text)
         return [(sent.tokens[0].start_char, sent.tokens[-1].end_char) for sent in doc.sentences]


 class ThaiTokenizer(WordTokenizer):
     def __init__(self):
+        super().__init__()
         check_required_dependencies("th word tokenizer", ["pythainlp"])

     def word_tokenize(self, text: str) -> list[str]:
         from pythainlp.tokenize import word_tokenize as th_word_tokenize

@@ -144,7 +166,9 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:

 class IndicNLPTokenizer(WordTokenizer):
     def __init__(self, language: str):
+        super().__init__()
         self.language = language
+        check_required_dependencies(f"{language} word tokenizer", [("indicnlp", "indic-nlp-library")])

     def word_tokenize(self, text) -> list[str]:
         from indicnlp.tokenize.indic_tokenize import trivial_tokenize as indicnlp_trivial_tokenize
@@ -165,23 +189,32 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:

 class KiwiTokenizer(WordTokenizer):
     def __init__(self, model_type="sbg"):
-        from kiwipiepy import Kiwi
+        super().__init__()
+        check_required_dependencies("ko word tokenizer", ["kiwipiepy"])
+        self.model_type = model_type
+        self._tokenizer = None

-        self.kiwi = Kiwi(model_type=model_type)
+    @property
+    def tokenizer(self):
+        if not self._tokenizer:
+            from kiwipiepy import Kiwi
+
+            self._tokenizer = Kiwi(model_type=self.model_type)
+        return self._tokenizer

     def word_tokenize(self, text: str) -> list[str]:
-        tokens = [token.form for token in self.kiwi.tokenize(text)]
+        tokens = [token.form for token in self.tokenizer.tokenize(text)]
         return strip_strings(tokens)

     def sent_tokenize(self, text: str) -> list[str]:
-        sents = [sent.text for sent in self.kiwi.split_into_sents(text)]
+        sents = [sent.text for sent in self.tokenizer.split_into_sents(text)]
         return strip_strings(sents)

     def span_tokenize(self, text: str) -> list[tuple[int, int]]:
-        return [(sent.start, sent.end) for sent in self.kiwi.split_into_sents(text)]
+        return [(sent.start, sent.end) for sent in self.tokenizer.split_into_sents(text)]


-# If you know a better tokenizer or better proxy language, please submit a change
+# If you know a better tokenizer or better proxy language, please submit a PR
 WORD_TOKENIZER_FACTORY: dict[str, Callable[[], WordTokenizer]] = {
     Languages.english: lambda: NLTKTokenizer("english"),
     Languages.korean: lambda: KiwiTokenizer(),
@@ -287,7 +320,7 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
 WORD_TOKENIZER_CACHE: dict[str, WordTokenizer] = {}


-def load_tokenizer(language: str) -> WordTokenizer:
+def load_word_tokenizer(language: str) -> WordTokenizer:
     if language not in WORD_TOKENIZER_CACHE:
         if language not in WORD_TOKENIZER_FACTORY:
             raise ValueError(f"Language '{language}' doesn't have a tokenizer.")
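Beyond the rename, the recurring change in this file is a lazy-initialization pattern: each constructor now only records its configuration and runs check_required_dependencies, while the heavyweight backend (spacy, stanza, kiwipiepy) is built on first access through a tokenizer property and then cached. A self-contained sketch of that pattern, using a fake backend rather than any real tokenizer library:

class _FakeBackend:
    # Stand-in for an expensive backend such as stanza.Pipeline or kiwipiepy.Kiwi.
    def __init__(self, language: str):
        print(f"loading backend for {language!r}")  # expensive step, runs only once

    def tokenize(self, text: str) -> list[str]:
        return text.split()


class LazyTokenizer:
    def __init__(self, language: str):
        # Cheap: store configuration only; no heavy imports or model loading here.
        self.language = language
        self._tokenizer = None

    @property
    def tokenizer(self):
        # Built on first use, then reused by every later call.
        if not self._tokenizer:
            self._tokenizer = _FakeBackend(self.language)
        return self._tokenizer

    def word_tokenize(self, text: str) -> list[str]:
        return self.tokenizer.tokenize(text)


# First call triggers the backend load; the second call reuses the cached instance.
lazy = LazyTokenizer("en")
print(lazy.word_tokenize("lazy loading keeps optional dependencies optional"))
print(lazy.word_tokenize("second call reuses the cached backend"))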
tests/pipeline/test_word_tokenizers.py (5 additions, 5 deletions)
@@ -2,7 +2,7 @@

 from nltk.tokenize import word_tokenize

-from datatrove.utils.word_tokenizers import WORD_TOKENIZER_FACTORY, load_tokenizer
+from datatrove.utils.word_tokenizers import WORD_TOKENIZER_FACTORY, load_word_tokenizer


 SAMPLE_TEXT = (
@@ -15,23 +15,23 @@
 class TestWordTokenizers(unittest.TestCase):
     def test_word_tokenizers(self):
         for language in WORD_TOKENIZER_FACTORY.keys():
-            tokenizer = load_tokenizer(language)
+            tokenizer = load_word_tokenizer(language)
             tokens = tokenizer.word_tokenize(SAMPLE_TEXT)
             assert len(tokens) >= 1, f"'{language}' tokenizer doesn't output tokens"
             is_stripped = [token == token.strip() for token in tokens]
             assert all(is_stripped), f"'{language}' tokenizer tokens contain whitespaces"

     def test_sent_tokenizers(self):
         for language in WORD_TOKENIZER_FACTORY.keys():
-            tokenizer = load_tokenizer(language)
+            tokenizer = load_word_tokenizer(language)
             sents = tokenizer.sent_tokenize(SAMPLE_TEXT)
             assert len(sents) >= 1, f"'{language}' tokenizer doesn't output sentences"
             is_stripped = [sent == sent.strip() for sent in sents]
             assert all(is_stripped), f"'{language}' tokenizer sentences contain whitespaces"

     def test_span_tokenizers(self):
         for language in WORD_TOKENIZER_FACTORY.keys():
-            tokenizer = load_tokenizer(language)
+            tokenizer = load_word_tokenizer(language)
             sents = tokenizer.sent_tokenize(SAMPLE_TEXT)
             spans = tokenizer.span_tokenize(SAMPLE_TEXT)
             assert len(spans) >= 1, f"'{language}' tokenizer doesn't output spans"
@@ -41,7 +41,7 @@ def test_span_tokenizers(self):
     def test_english_tokenizer(self):
         nltk_words = word_tokenize(SAMPLE_TEXT, language="english")

-        en_tokenizer = load_tokenizer("en")
+        en_tokenizer = load_word_tokenizer("en")
         tokenizer_words = en_tokenizer.word_tokenize(SAMPLE_TEXT)

         self.assertEqual(nltk_words, tokenizer_words, "NLTK tokenizer and multilingual tokenizer differ")
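The tests only swap in the new name. As a quick interactive check of the renamed API outside the test suite (assuming datatrove and nltk are installed), the English tokenizer should still agree with NLTK's word_tokenize, which is what test_english_tokenizer asserts; the sample text below is illustrative, not the repo's SAMPLE_TEXT:

from nltk.tokenize import word_tokenize
from datatrove.utils.word_tokenizers import load_word_tokenizer

text = "Hello world! This is a short check."
assert load_word_tokenizer("en").word_tokenize(text) == word_tokenize(text, language="english")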
