Skip to content

Commit

Permalink
Refactor text tokenizers (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Nov 27, 2023
1 parent e95f287 commit cd435b8
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 235 deletions.
6 changes: 6 additions & 0 deletions src/fairseq2/data/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
from fairseq2.data.text.converters import StrSplitter as StrSplitter
from fairseq2.data.text.converters import StrToIntConverter as StrToIntConverter
from fairseq2.data.text.converters import StrToTensorConverter as StrToTensorConverter
from fairseq2.data.text.sentencepiece import (
BasicSentencePieceTokenizer as BasicSentencePieceTokenizer,
)
from fairseq2.data.text.sentencepiece import (
SentencePieceDecoder as SentencePieceDecoder,
)
from fairseq2.data.text.sentencepiece import (
SentencePieceEncoder as SentencePieceEncoder,
)
from fairseq2.data.text.sentencepiece import SentencePieceModel as SentencePieceModel
from fairseq2.data.text.sentencepiece import (
SentencePieceTokenizerBase as SentencePieceTokenizerBase,
)
from fairseq2.data.text.sentencepiece import (
vocab_info_from_sentencepiece as vocab_info_from_sentencepiece,
)
Expand Down
97 changes: 96 additions & 1 deletion src/fairseq2/data/text/sentencepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from torch import Tensor

from fairseq2 import _DOC_MODE
from fairseq2.data.text.text_tokenizer import TextTokenDecoder, TextTokenEncoder
from fairseq2.data.text.text_tokenizer import (
TextTokenDecoder,
TextTokenEncoder,
TextTokenizer,
)
from fairseq2.data.typing import PathLike, StringLike
from fairseq2.data.vocabulary_info import VocabularyInfo
from fairseq2.typing import Device, finaloverride
Expand Down Expand Up @@ -121,6 +125,97 @@ def _set_module_name() -> None:
_set_module_name()


class SentencePieceTokenizerBase(TextTokenizer):
    """Abstract base class for tokenizers backed by a SentencePiece model.

    Loads the model file once and derives the vocabulary information from it;
    subclasses only need to implement :meth:`create_encoder`.
    """

    # The underlying SentencePiece model shared by all encoders/decoders.
    model: SentencePieceModel

    def __init__(
        self, pathname: PathLike, control_symbols: Optional[Sequence[StringLike]] = None
    ) -> None:
        """
        :param pathname:
            The pathname of the SentencePiece model file.
        :param control_symbols:
            The list of control symbols to add to the SentencePiece model.
        """
        model = SentencePieceModel(pathname, control_symbols)

        self.model = model

        # The vocabulary (size, UNK/BOS/EOS/PAD indices) comes straight from
        # the loaded model.
        super().__init__(vocab_info_from_sentencepiece(model))

    @finaloverride
    def create_decoder(self) -> SentencePieceDecoder:
        return SentencePieceDecoder(self.model)

    @finaloverride
    def create_raw_encoder(
        self, *, device: Optional[Device] = None, pin_memory: bool = False
    ) -> SentencePieceEncoder:
        # A raw encoder adds no prefix/suffix tokens around the encoded text.
        return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)


class BasicSentencePieceTokenizer(SentencePieceTokenizerBase):
    """Represents a SentencePiece tokenizer that encodes text with BOS and EOS."""

    def __init__(self, pathname: PathLike) -> None:
        """
        :param pathname:
            The pathname of the SentencePiece model file.
        """
        super().__init__(pathname)

    @finaloverride
    def create_encoder(
        self,
        *,
        task: Optional[str] = None,
        lang: Optional[str] = None,
        mode: Optional[str] = None,
        device: Optional[Device] = None,
        pin_memory: bool = False,
    ) -> SentencePieceEncoder:
        """Create a token encoder.

        :param task:
            Not used.
        :param lang:
            Not used.
        :param mode:
            Must be 'default' or 'prompt'. If ``None``, defaults to 'default'.
        :param device:
            The device on which to construct tensors.
        :param pin_memory:
            If ``True``, uses pinned memory while constructing tensors.
        """
        # `task` and `lang` exist only to satisfy the `TextTokenizer`
        # interface; this tokenizer supports neither.
        if task is not None:
            raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

        if lang is not None:
            raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

        if mode not in (None, "default", "prompt"):
            raise ValueError(
                f"`mode` must be 'default' or 'prompt', but is '{mode}' instead."
            )

        # Both modes prepend BOS. In prompt mode, we expect the generator to
        # finish the sequence, so no EOS token is appended.
        suffix = None if mode == "prompt" else ["</s>"]

        return SentencePieceEncoder(
            self.model,
            prefix_tokens=["<s>"],
            suffix_tokens=suffix,
            device=device,
            pin_memory=pin_memory,
        )


def vocab_info_from_sentencepiece(model: SentencePieceModel) -> VocabularyInfo:
"""Return the vocabulary information of ``model``."""
return VocabularyInfo(
Expand Down
6 changes: 3 additions & 3 deletions src/fairseq2/data/text/text_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class TextTokenizer(ABC):
"""Represents a tokenizer to encode and decode texts."""
"""Represents a tokenizer to encode and decode text."""

vocab_info: VocabularyInfo

Expand Down Expand Up @@ -80,7 +80,7 @@ def create_decoder(self) -> TextTokenDecoder:


class TextTokenEncoder(ABC):
"""Encodes texts into tokens or token indices."""
"""Encodes text into tokens or token indices."""

@abstractmethod
def __call__(self, text: StringLike) -> Tensor:
Expand Down Expand Up @@ -110,7 +110,7 @@ def suffix_indices(self) -> Optional[Tensor]:


class TextTokenDecoder(ABC):
"""Decodes texts from tokens or token indices."""
"""Decodes text from tokens or token indices."""

@abstractmethod
def __call__(self, token_indices: Tensor) -> StringLike:
Expand Down
8 changes: 4 additions & 4 deletions src/fairseq2/data/vocabulary_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class VocabularyInfo:
"""The size of the vocabulary."""

unk_idx: Optional[int]
"""The index of the symbol that represents an unknown element."""
"""The index of the symbol that represents an unknown element (UNK)."""

bos_idx: Optional[int]
"""The index of the symbol that represents the beginning of a sequence."""
"""The index of the symbol that represents the beginning of a sequence (BOS)."""

eos_idx: Optional[int]
"""The index of the symbol that represents the end of a sequence."""
"""The index of the symbol that represents the end of a sequence (EOS)."""

pad_idx: Optional[int]
"""The index of the symbol that is used to pad a sequence."""
"""The index of the symbol that is used to pad a sequence (PAD)."""
88 changes: 3 additions & 85 deletions src/fairseq2/models/llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,93 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, final
from typing import final

from fairseq2.data.text import (
SentencePieceDecoder,
SentencePieceEncoder,
SentencePieceModel,
TextTokenDecoder,
TextTokenEncoder,
TextTokenizer,
vocab_info_from_sentencepiece,
)
from fairseq2.data.typing import PathLike
from fairseq2.typing import Device, finaloverride
from fairseq2.data.text import BasicSentencePieceTokenizer


@final
class LLaMATokenizer(TextTokenizer):
class LLaMATokenizer(BasicSentencePieceTokenizer):
"""Represents the tokenizer used by LLaMA models."""

model: SentencePieceModel

def __init__(self, pathname: PathLike) -> None:
"""
:param pathname:
The pathname of the SentencePiece model file.
"""
self.model = SentencePieceModel(pathname)

vocab_info = vocab_info_from_sentencepiece(self.model)

super().__init__(vocab_info)

@finaloverride
def create_encoder(
self,
*,
task: Optional[str] = None,
lang: Optional[str] = None,
mode: Optional[str] = None,
device: Optional[Device] = None,
pin_memory: bool = False,
) -> TextTokenEncoder:
"""Create a token encoder.
:param task:
Not used.
:param lang:
Not used.
:param mode:
Must be 'default' or 'prompt'. If ``None``, defaults to 'default'.
:param device:
The device on which to construct tensors.
:param pin_memory:
If ``True``, uses pinned memory while constructing tensors.
"""
if task is not None:
raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

if lang is not None:
raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

if mode is None or mode == "default":
prefix_tokens = ["<s>"]
suffix_tokens = ["</s>"]
elif mode == "prompt":
prefix_tokens = ["<s>"]
# In prompt mode, we expect the generator to finish the sequence.
suffix_tokens = None
else:
raise ValueError(
f"`mode` must be 'default' or 'prompt', but is '{mode}' instead."
)

return SentencePieceEncoder(
self.model,
prefix_tokens=prefix_tokens,
suffix_tokens=suffix_tokens,
device=device,
pin_memory=pin_memory,
)

@finaloverride
def create_raw_encoder(
self, *, device: Optional[Device] = None, pin_memory: bool = False
) -> TextTokenEncoder:
return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)

@finaloverride
def create_decoder(self) -> TextTokenDecoder:
return SentencePieceDecoder(self.model)
88 changes: 3 additions & 85 deletions src/fairseq2/models/mistral/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,93 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, final
from typing import final

from fairseq2.data.text import (
SentencePieceDecoder,
SentencePieceEncoder,
SentencePieceModel,
TextTokenDecoder,
TextTokenEncoder,
TextTokenizer,
vocab_info_from_sentencepiece,
)
from fairseq2.data.typing import PathLike
from fairseq2.typing import Device, finaloverride
from fairseq2.data.text import BasicSentencePieceTokenizer


@final
class MistralTokenizer(TextTokenizer):
class MistralTokenizer(BasicSentencePieceTokenizer):
"""Represents the tokenizer used by Mistral models."""

model: SentencePieceModel

def __init__(self, pathname: PathLike) -> None:
"""
:param pathname:
The pathname of the SentencePiece model file.
"""
self.model = SentencePieceModel(pathname)

vocab_info = vocab_info_from_sentencepiece(self.model)

super().__init__(vocab_info)

@finaloverride
def create_encoder(
self,
*,
task: Optional[str] = None,
lang: Optional[str] = None,
mode: Optional[str] = None,
device: Optional[Device] = None,
pin_memory: bool = False,
) -> TextTokenEncoder:
"""Create a token encoder.
:param task:
Not used.
:param lang:
Not used.
:param mode:
Must be 'default' or 'prompt'. If ``None``, defaults to 'default'.
:param device:
The device on which to construct tensors.
:param pin_memory:
If ``True``, uses pinned memory while constructing tensors.
"""
if task is not None:
raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

if lang is not None:
raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

if mode is None or mode == "default":
prefix_tokens = ["<s>"]
suffix_tokens = ["</s>"]
elif mode == "prompt":
prefix_tokens = ["<s>"]
# In prompt mode, we expect the generator to finish the sequence.
suffix_tokens = None
else:
raise ValueError(
f"`mode` must be 'default' or 'prompt', but is '{mode}' instead."
)

return SentencePieceEncoder(
self.model,
prefix_tokens=prefix_tokens,
suffix_tokens=suffix_tokens,
device=device,
pin_memory=pin_memory,
)

@finaloverride
def create_raw_encoder(
self, *, device: Optional[Device] = None, pin_memory: bool = False
) -> TextTokenEncoder:
return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)

@finaloverride
def create_decoder(self) -> TextTokenDecoder:
return SentencePieceDecoder(self.model)
Loading

0 comments on commit cd435b8

Please sign in to comment.