diff --git a/server/config.py b/server/config.py
index 2384a26..3cafac6 100644
--- a/server/config.py
+++ b/server/config.py
@@ -33,7 +33,6 @@ class Config(BaseSettings):
     server_root_path (str) : the root path for the server
     worker_count (int) : the number of workers to use
     translator_threads (int) : the number of threads for the translator
-    translator_beam_size (int) : the beam size for the translator
     use_cuda (bool) : whether to use CUDA for inference
     translator_model_name (str) : the name of the translator model
     language_detector_model_name (str) : the name of the language detector model
diff --git a/server/features/translator.py b/server/features/translator.py
index 0b49931..1f84287 100644
--- a/server/features/translator.py
+++ b/server/features/translator.py
@@ -1,7 +1,7 @@
-from itertools import cycle
-from typing import Iterable, Iterator, Self
+from typing import Iterator
 
 from ctranslate2 import Translator as CTranslator
+from tokenizers import Encoding
 from transformers.models.nllb.tokenization_nllb_fast import NllbTokenizerFast
 
 from server.config import Config
@@ -9,74 +9,6 @@
 from server.utils import huggingface_download
 
 
-class Tokeniser:
-    """
-    Summary
-    -------
-    context manager for the NLLB tokeniser
-
-    Methods
-    -------
-    encode(text: str) -> list[str]
-        encode the input text
-
-    decode(tokens: str | list[str]) -> str
-        decode the input tokens
-    """
-
-    __slots__ = ('tokeniser', 'lock')
-
-    def __init__(self, model_path: str):
-        self.tokeniser: NllbTokenizerFast = NllbTokenizerFast.from_pretrained(model_path, local_files_only=True)
-        self.lock = False
-
-    def __call__(self, source_language: Languages) -> Self:
-        self.tokeniser.src_lang = source_language
-        return self
-
-    def __enter__(self):
-        self.lock = True
-
-    def __exit__(self, *_):
-        self.lock = False
-
-    def encode(self, text: str) -> list[str]:
-        """
-        Summary
-        -------
-        encode the input text
-
-        Parameters
-        ----------
-        text (str) : the input text
-
-        Returns
-        -------
-        tokens (list[str]) : the tokenised input text
-        """
-        return self.tokeniser(text).tokens()
-
-    def decode(self, tokens: str | Iterable[str]) -> str:
-        """
-        Summary
-        -------
-        decode the input tokens
-
-        Parameters
-        ----------
-        tokens (str | list[str]) : the input tokens
-
-        Returns
-        -------
-        text (str) : the decoded text
-        """
-        return self.tokeniser.decode(
-            self.tokeniser.convert_tokens_to_ids(tokens),  # type: ignore
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False,
-        )
-
-
 class Translator:
     """
     Summary
@@ -95,10 +27,10 @@ class Translator:
-        streams the translation input from the source language to the target language using a pool of tokenisers
+        streams the translation input from the source language to the target language using a single tokeniser
     """
 
-    __slots__ = ('translator', 'tokeniser_pool')
+    __slots__ = ('translator', 'tokeniser')
 
-    def __init__(self, translator: CTranslator, tokeniser_pool: Iterator[Tokeniser]):
-        self.tokeniser_pool = tokeniser_pool
+    def __init__(self, translator: CTranslator, tokeniser: NllbTokenizerFast):
+        self.tokeniser = tokeniser
         self.translator = translator
 
     def translate_generator(self, text: str, source_language: Languages, target_language: Languages) -> Iterator[str]:
@@ -118,17 +50,11 @@ def translate_generator(self, text: str, source_language: Languages, target_lang
         tokens (Iterator[str]) : the translated tokens
         """
 
-        while True:
-            if (tokeniser := next(self.tokeniser_pool)).lock:
-                continue
-
-            with tokeniser(source_language):
-                source_tokens = tokeniser.encode(text)
+        encoding: Encoding = self.tokeniser(text).encodings[0]  # type: ignore
+        results = self.translator.generate_tokens([source_language] + encoding.tokens, (target_language,))
+        next(results)  # skip the target language token
 
-            results = self.translator.generate_tokens(source_tokens, (target_language,))
-            next(results)  # skip the target language token
-
-            return (result.token for result in results if not result.is_last)
+        return (result.token for result in results if not result.is_last)
 
     def translate(self, text: str, source_language: Languages, target_language: Languages) -> str:
         """
@@ -146,8 +72,9 @@ def translate(self, text: str, source_language: Languages, target_language: Lang
         -------
         translated_text (str) : the translated text
         """
-
-        return next(self.tokeniser_pool).decode(self.translate_generator(text, source_language, target_language))
+        return self.tokeniser.convert_tokens_to_string(
+            list(self.translate_generator(text, source_language, target_language))
+        )
 
     def translate_stream(self, text: str, source_language: Languages, target_language: Languages) -> Iterator[str]:
         """
@@ -165,7 +92,11 @@ def translate_stream(self, text: str, source_language: Languages, target_languag
         -------
         translated_text (Iterator[str]) : the translated text
         """
-        return map(next(self.tokeniser_pool).decode, self.translate_generator(text, source_language, target_language))
+
+        return (
+            self.tokeniser.convert_tokens_to_string((token,))  # type: ignore
+            for token in self.translate_generator(text, source_language, target_language)
+        )
 
 
 def get_translator() -> Translator:
@@ -179,7 +110,7 @@ def get_translator() -> Translator:
-    translator (TranslatorPool) : the translator pool
+    translator (Translator) : the translator
     """
     model_path = huggingface_download(Config.translator_model_name)
-    tokeniser_pool = cycle([Tokeniser(model_path) for _ in range(Config.translator_threads)])
+    tokeniser: NllbTokenizerFast = NllbTokenizerFast.from_pretrained(model_path, local_files_only=True)
     translator = CTranslator(
         model_path,
         'cuda' if Config.use_cuda else 'cpu',
@@ -187,4 +118,4 @@ def get_translator() -> Translator:
         inter_threads=Config.translator_threads,
     )
 
-    return Translator(translator, tokeniser_pool)
+    return Translator(translator, tokeniser)