perf: use faster decoding method
winstxnhdw committed Oct 3, 2024
1 parent b00c1ef commit d9f3939
Showing 2 changed files with 19 additions and 89 deletions.
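
In short: the commit drops the pool of lock-guarded Tokeniser wrappers (cycled via itertools.cycle) in favour of one shared NllbTokenizerFast, and decodes generated tokens with convert_tokens_to_string rather than round-tripping every token through convert_tokens_to_ids and decode. Because the new translate_generator prepends the source-language token itself instead of mutating tokeniser.src_lang, a single tokeniser instance can be shared and the busy-wait over locked instances disappears. Below is a minimal sketch of the two decoding paths, not taken from the repository; the model name and sample tokens are illustrative assumptions:

# Sketch only: contrast the old and new decoding paths on a fast tokeniser.
# Assumptions: transformers is installed, 'facebook/nllb-200-distilled-600M'
# stands in for Config.translator_model_name, and `tokens` is a made-up
# example of what CTranslate2's generate_tokens might yield.
from transformers.models.nllb.tokenization_nllb_fast import NllbTokenizerFast

tokeniser = NllbTokenizerFast.from_pretrained('facebook/nllb-200-distilled-600M')
tokens = ['▁Hello', ',', '▁world', '!']

# old path: token strings -> vocabulary ids -> decoded text
ids = tokeniser.convert_tokens_to_ids(tokens)
slow = tokeniser.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

# new path: join the token strings directly in the Rust-backed decoder,
# skipping the id round-trip entirely
fast = tokeniser.convert_tokens_to_string(tokens)

print(slow)  # Hello, world!
print(fast)  # Hello, world!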
server/config.py (1 change: 0 additions & 1 deletion)
@@ -33,7 +33,6 @@ class Config(BaseSettings):
     server_root_path (str) : the root path for the server
     worker_count (int) : the number of workers to use
     translator_threads (int) : the number of threads for the translator
-    translator_beam_size (int) : the beam size for the translator
     use_cuda (bool) : whether to use CUDA for inference
     translator_model_name (str) : the name of the translator model
     language_detector_model_name (str) : the name of the language detector model
server/features/translator.py (107 changes: 19 additions & 88 deletions)
@@ -1,82 +1,14 @@
-from itertools import cycle
-from typing import Iterable, Iterator, Self
+from typing import Iterator

 from ctranslate2 import Translator as CTranslator
+from tokenizers import Encoding
 from transformers.models.nllb.tokenization_nllb_fast import NllbTokenizerFast

 from server.config import Config
 from server.types import Languages
 from server.utils import huggingface_download


-class Tokeniser:
-    """
-    Summary
-    -------
-    context manager for the NLLB tokeniser
-
-    Methods
-    -------
-    encode(text: str) -> list[str]
-        encode the input text
-
-    decode(tokens: str | list[str]) -> str
-        decode the input tokens
-    """
-
-    __slots__ = ('tokeniser', 'lock')
-
-    def __init__(self, model_path: str):
-        self.tokeniser: NllbTokenizerFast = NllbTokenizerFast.from_pretrained(model_path, local_files_only=True)
-        self.lock = False
-
-    def __call__(self, source_language: Languages) -> Self:
-        self.tokeniser.src_lang = source_language
-        return self
-
-    def __enter__(self):
-        self.lock = True
-
-    def __exit__(self, *_):
-        self.lock = False
-
-    def encode(self, text: str) -> list[str]:
-        """
-        Summary
-        -------
-        encode the input text
-
-        Parameters
-        ----------
-        text (str) : the input text
-
-        Returns
-        -------
-        tokens (list[str]) : the tokenised input text
-        """
-        return self.tokeniser(text).tokens()
-
-    def decode(self, tokens: str | Iterable[str]) -> str:
-        """
-        Summary
-        -------
-        decode the input tokens
-
-        Parameters
-        ----------
-        tokens (str | list[str]) : the input tokens
-
-        Returns
-        -------
-        text (str) : the decoded text
-        """
-        return self.tokeniser.decode(
-            self.tokeniser.convert_tokens_to_ids(tokens),  # type: ignore
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False,
-        )
-
-
 class Translator:
     """
     Summary
@@ -95,10 +27,10 @@ class Translator:
     streams the translation input from the source language to the target language using a pool of tokenisers
     """

-    __slots__ = ('translator', 'tokeniser_pool')
+    __slots__ = ('translator', 'tokeniser')

-    def __init__(self, translator: CTranslator, tokeniser_pool: Iterator[Tokeniser]):
-        self.tokeniser_pool = tokeniser_pool
+    def __init__(self, translator: CTranslator, tokeniser: NllbTokenizerFast):
+        self.tokeniser = tokeniser
         self.translator = translator

     def translate_generator(self, text: str, source_language: Languages, target_language: Languages) -> Iterator[str]:
@@ -118,17 +50,11 @@ def translate_generator(self, text: str, source_language: Languages, target_language: Languages) -> Iterator[str]:
         tokens (Iterator[str]) : the translated tokens
         """

-        while True:
-            if (tokeniser := next(self.tokeniser_pool)).lock:
-                continue
-
-            with tokeniser(source_language):
-                source_tokens = tokeniser.encode(text)
-
-            results = self.translator.generate_tokens(source_tokens, (target_language,))
-            next(results)  # skip the target language token
-
-            return (result.token for result in results if not result.is_last)
+        encoding: Encoding = self.tokeniser(text).encodings[0]  # type: ignore
+        results = self.translator.generate_tokens([source_language] + encoding.tokens, (target_language,))
+        next(results)  # skip the target language token
+
+        return (result.token for result in results if not result.is_last)

     def translate(self, text: str, source_language: Languages, target_language: Languages) -> str:
         """
@@ -146,8 +72,9 @@ def translate(self, text: str, source_language: Languages, target_language: Languages) -> str:
         -------
         translated_text (str) : the translated text
         """
-
-        return next(self.tokeniser_pool).decode(self.translate_generator(text, source_language, target_language))
+        return self.tokeniser.convert_tokens_to_string(
+            list(self.translate_generator(text, source_language, target_language))
+        )

     def translate_stream(self, text: str, source_language: Languages, target_language: Languages) -> Iterator[str]:
         """
@@ -165,7 +92,11 @@ def translate_stream(self, text: str, source_language: Languages, target_language: Languages) -> Iterator[str]:
         -------
         translated_text (Iterator[str]) : the translated text
         """
-        return map(next(self.tokeniser_pool).decode, self.translate_generator(text, source_language, target_language))
+
+        return (
+            self.tokeniser.convert_tokens_to_string((token,))  # type: ignore
+            for token in self.translate_generator(text, source_language, target_language)
+        )


 def get_translator() -> Translator:
@@ -179,12 +110,12 @@ def get_translator() -> Translator:
     translator (TranslatorPool) : the translator pool
     """
     model_path = huggingface_download(Config.translator_model_name)
-    tokeniser_pool = cycle([Tokeniser(model_path) for _ in range(Config.translator_threads)])
+    tokeniser: NllbTokenizerFast = NllbTokenizerFast.from_pretrained(model_path, local_files_only=True)
     translator = CTranslator(
         model_path,
         'cuda' if Config.use_cuda else 'cpu',
         compute_type='auto',
         inter_threads=Config.translator_threads,
     )

-    return Translator(translator, tokeniser_pool)
+    return Translator(translator, tokeniser)
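
For context, a hypothetical end-to-end use of the refactored interface. The import path, method names, and signatures come from the diff above; the NLLB language codes are example values only:

# Hypothetical usage sketch; get_translator, translate, and translate_stream
# are defined in the diff above, and the language codes are example values.
from server.features.translator import get_translator

translator = get_translator()

# one-shot: collect every generated token, then convert once to a string
print(translator.translate('Hello, world!', 'eng_Latn', 'spa_Latn'))

# streaming: each token is converted to text as soon as it arrives
for chunk in translator.translate_stream('Hello, world!', 'eng_Latn', 'spa_Latn'):
    print(chunk, end='', flush=True)

Note the split: translate materialises the whole token generator before a single conversion, while translate_stream converts token by token, which is what keeps the streaming path incremental.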
