Merge pull request #2724 from tomaarsen/improve_typing
[`typing`] Improve typing for many functions & add `py.typed` to satisfy `mypy`
tomaarsen authored Jun 6, 2024
2 parents fc1b7d0 + 936f283 commit b5e98e1
Showing 40 changed files with 216 additions and 188 deletions.
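Alongside the per-function annotations below, the PR title notes that a PEP 561 `py.typed` marker is added so that `mypy` picks up the package's inline annotations. The packaging change itself is not among the file diffs shown here; as a rough sketch only, assuming a setuptools-based build, shipping such a marker typically looks like this:

    # Illustrative setuptools configuration for a PEP 561 marker;
    # the actual packaging change in this PR may differ.
    from setuptools import find_packages, setup

    setup(
        name="sentence-transformers",
        packages=find_packages(),
        package_data={"sentence_transformers": ["py.typed"]},  # empty marker file
        zip_safe=False,  # recommended for packages that ship type information
    )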
6 changes: 3 additions & 3 deletions sentence_transformers/LoggingHandler.py
@@ -4,10 +4,10 @@


class LoggingHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
def __init__(self, level=logging.NOTSET) -> None:
super().__init__(level)

def emit(self, record):
def emit(self, record) -> None:
try:
msg = self.format(record)
tqdm.tqdm.write(msg)
@@ -18,7 +18,7 @@ def emit(self, record):
self.handleError(record)


def install_logger(given_logger, level=logging.WARNING, fmt="%(levelname)s:%(name)s:%(message)s"):
def install_logger(given_logger, level=logging.WARNING, fmt="%(levelname)s:%(name)s:%(message)s") -> None:
"""Configures the given logger; format, logging level, style, etc"""
import coloredlogs

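With `install_logger` and the `LoggingHandler` methods now annotated as returning `None`, `mypy` can flag callers that mistakenly use a return value. A minimal sketch (assuming `coloredlogs` is installed, since `install_logger` imports it):

    import logging

    from sentence_transformers.LoggingHandler import install_logger

    logger = logging.getLogger("sentence_transformers")
    install_logger(logger, level=logging.INFO)

    # With the new `-> None` annotation, mypy rejects a call like
    #     handler = install_logger(logger)
    # with: error: "install_logger" does not return a value  [func-returns-value]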
46 changes: 26 additions & 20 deletions sentence_transformers/SentenceTransformer.py
@@ -10,8 +10,9 @@
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from multiprocessing import Queue
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, overload
from typing import Any, Callable, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union, overload

import numpy as np
import torch
@@ -159,7 +160,7 @@ def __init__(
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
model_card_data: Optional[SentenceTransformerModelCardData] = None,
):
) -> None:
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
self.default_prompt_name = default_prompt_name
@@ -689,7 +690,9 @@ def similarity_pairwise(self) -> Callable[[Union[Tensor, ndarray], Union[Tensor,
self.similarity_fn_name = SimilarityFunction.COSINE
return self._similarity_pairwise

def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, Any]:
def start_multi_process_pool(
self, target_devices: List[str] = None
) -> Dict[Literal["input", "output", "processes"], Any]:
"""
Starts a multi-process pool to process the encoding with several independent processes
via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.
@@ -737,7 +740,7 @@ def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str
return {"input": input_queue, "output": output_queue, "processes": processes}

@staticmethod
def stop_multi_process_pool(pool):
def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
"""
Stops all processes started with start_multi_process_pool.
@@ -760,7 +763,7 @@
def encode_multi_process(
self,
sentences: List[str],
pool: Dict[str, object],
pool: Dict[Literal["input", "output", "processes"], Any],
prompt_name: Optional[str] = None,
prompt: Optional[str] = None,
batch_size: int = 32,
@@ -776,7 +779,8 @@
Args:
sentences (List[str]): List of sentences to encode.
pool (Dict[str, object]): A pool of workers started with SentenceTransformer.start_multi_process_pool.
pool (Dict[Literal["input", "output", "processes"], Any]): A pool of workers started with
:meth:`SentenceTransformer.start_multi_process_pool <sentence_transformers.SentenceTransformer.start_multi_process_pool>`.
prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
which is either set in the constructor or loaded from the model configuration. For example if
``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
@@ -847,7 +851,9 @@ def main():
return embeddings

@staticmethod
def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
def _encode_multi_process_worker(
target_device: str, model: "SentenceTransformer", input_queue: Queue, results_queue: Queue
) -> None:
"""
Internal working process to encode sentences in multi-process setup
"""
@@ -915,7 +921,7 @@ def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]) -
"""
return self._first_module().tokenize(texts)

def get_sentence_features(self, *features):
def get_sentence_features(self, *features) -> Dict[Literal["sentence_embedding"], torch.Tensor]:
return self._first_module().get_sentence_features(*features)

def get_sentence_embedding_dimension(self) -> Optional[int]:
@@ -938,7 +944,7 @@ def get_sentence_embedding_dimension(self) -> Optional[int]:
return output_dim

@contextmanager
def truncate_sentence_embeddings(self, truncate_dim: Optional[int]):
def truncate_sentence_embeddings(self, truncate_dim: Optional[int]) -> Iterator[None]:
"""
In this context, :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>` outputs
sentence embeddings truncated at dimension ``truncate_dim``.
@@ -967,11 +973,11 @@ def truncate_sentence_embeddings(self, truncate_dim: Optional[int]):
finally:
self.truncate_dim = original_output_dim

def _first_module(self):
def _first_module(self) -> torch.nn.Module:
"""Returns the first module of this sequential embedder"""
return self._modules[next(iter(self._modules))]

def _last_module(self):
def _last_module(self) -> torch.nn.Module:
"""Returns the last module of this sequential embedder"""
return self._modules[next(reversed(self._modules))]
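Annotating the generator-based context manager with `Iterator[None]` matches what `@contextmanager` expects: the wrapped generator yields exactly once and produces no value. A self-contained sketch of the pattern (not code from this repository):

    from contextlib import contextmanager
    from typing import Iterator

    @contextmanager
    def truncated(dim: int) -> Iterator[None]:
        # save state, apply the temporary truncation dimension
        try:
            yield  # yields None exactly once, hence Iterator[None]
        finally:
            pass  # restore the original state

    with truncated(128):
        ...  # code running under the temporary setting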

@@ -982,7 +988,7 @@ def save(
create_model_card: bool = True,
train_datasets: Optional[List[str]] = None,
safe_serialization: bool = True,
):
) -> None:
"""
Saves a model and its configuration files to a directory, so that it can be loaded
with ``SentenceTransformer(path)`` again.
@@ -1049,7 +1055,7 @@ def save_pretrained(
create_model_card: bool = True,
train_datasets: Optional[List[str]] = None,
safe_serialization: bool = True,
):
) -> None:
"""
Saves a model and its configuration files to a directory, so that it can be loaded
with ``SentenceTransformer(path)`` again.
@@ -1072,7 +1078,7 @@ def save_pretrained(

def _create_model_card(
self, path: str, model_name: Optional[str] = None, train_datasets: Optional[List[str]] = "deprecated"
):
) -> None:
"""
Create an automatic model and stores it in the specified path. If no training was done and the loaded model
was a Sentence Transformer model already, then its model card is reused.
@@ -1240,7 +1246,7 @@ def push_to_hub(
# This isn't expected to ever be reached.
return folder_url

def _text_length(self, text: Union[List[int], List[List[int]]]):
def _text_length(self, text: Union[List[int], List[List[int]]]) -> int:
"""
Help function to get the length for the input text. Text can be either
a list of ints (which means a single text as input), or a tuple of list of ints
@@ -1256,7 +1262,7 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
else:
return sum([len(t) for t in text]) # Sum of length of individual strings

def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None):
def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None) -> Union[Dict[str, float], float]:
"""
Evaluate the model based on an evaluator
@@ -1504,7 +1510,7 @@ def _load_sbert_model(
return modules

@staticmethod
def load(input_path):
def load(input_path) -> "SentenceTransformer":
return SentenceTransformer(input_path)

@property
@@ -1530,14 +1536,14 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return torch.device("cpu")

@property
def tokenizer(self):
def tokenizer(self) -> Any:
"""
Property to get the tokenizer that is used by this model
"""
return self._first_module().tokenizer

@tokenizer.setter
def tokenizer(self, value):
def tokenizer(self, value) -> None:
"""
Property to set the tokenizer that should be used by this model
"""
@@ -1563,7 +1569,7 @@ def max_seq_length(self) -> int:
return self._first_module().max_seq_length

@max_seq_length.setter
def max_seq_length(self, value):
def max_seq_length(self, value) -> None:
"""
Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
"""
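The pool returned by `start_multi_process_pool` is now typed as `Dict[Literal["input", "output", "processes"], Any]`, so type checkers know exactly which keys exist and can catch typos such as `pool["proceses"]`. A minimal usage sketch (the checkpoint name is only an example):

    from sentence_transformers import SentenceTransformer

    if __name__ == "__main__":
        model = SentenceTransformer("all-MiniLM-L6-v2")  # example checkpoint
        sentences = [f"This is sentence {i}" for i in range(10_000)]

        pool = model.start_multi_process_pool()  # Dict[Literal["input", "output", "processes"], Any]
        try:
            embeddings = model.encode_multi_process(sentences, pool, batch_size=32)
        finally:
            model.stop_multi_process_pool(pool)

        print(embeddings.shape)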
20 changes: 11 additions & 9 deletions sentence_transformers/cross_encoder/CrossEncoder.py
@@ -1,18 +1,20 @@
import logging
import os
from functools import wraps
from typing import Callable, Dict, List, Optional, Type, Union
from typing import Callable, Dict, List, Literal, Optional, Tuple, Type, Union

import numpy as np
import torch
from torch import nn
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm, trange
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, is_torch_npu_available
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import PushToHubMixin

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.util import fullname, get_device_name, import_from_string

@@ -64,7 +66,7 @@ def __init__(
local_files_only: bool = False,
default_activation_function=None,
classifier_dropout: float = None,
):
) -> None:
if tokenizer_args is None:
tokenizer_args = {}
if automodel_args is None:
@@ -125,7 +127,7 @@ def __init__(
else:
self.default_activation_function = nn.Sigmoid() if self.config.num_labels == 1 else nn.Identity()

def smart_batching_collate(self, batch):
def smart_batching_collate(self, batch: List[InputExample]) -> Tuple[BatchEncoding, Tensor]:
texts = [[] for _ in range(len(batch[0].texts))]
labels = []

@@ -147,7 +149,7 @@ def smart_batching_collate(self, batch):

return tokenized, labels

def smart_batching_collate_text_only(self, batch):
def smart_batching_collate_text_only(self, batch: List[InputExample]) -> BatchEncoding:
texts = [[] for _ in range(len(batch[0]))]

for example in batch:
@@ -182,7 +184,7 @@ def fit(
use_amp: bool = False,
callback: Callable[[float, int, int], None] = None,
show_progress_bar: bool = True,
):
) -> None:
"""
Train the model with the given training objective
Each training objective is sampled in turn for one batch.
@@ -406,7 +408,7 @@ def rank(
apply_softmax=False,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
) -> List[Dict]:
) -> List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]:
"""
Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
@@ -424,7 +426,7 @@ def rank(
convert_to_tensor (bool, optional): Convert the output to a tensor. Defaults to False.
Returns:
List[Dict]: A sorted list with the document indices and scores, and optionally also documents.
List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]: A sorted list with the "corpus_id", "score", and optionally "text" of the documents.
Example:
::
@@ -484,7 +486,7 @@ def rank(
results = sorted(results, key=lambda x: x["score"], reverse=True)
return results[:top_k]

def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback):
def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback) -> None:
"""Runs evaluation during the training"""
if evaluator is not None:
score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
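`CrossEncoder.rank` now spells out its result shape: a list of dicts keyed by the literals `"corpus_id"`, `"score"` and, optionally, `"text"`, sorted by score. A short sketch (checkpoint name, query, and documents are illustrative):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # example checkpoint
    query = "How many people live in Berlin?"
    documents = [
        "Berlin has around 3.7 million inhabitants.",
        "Berlin is well known for its museums and galleries.",
    ]

    # List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]
    for hit in model.rank(query, documents):
        print(hit["corpus_id"], hit["score"])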
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/SentenceEvaluator.py
@@ -52,7 +52,7 @@ def __call__(
"""
pass

def prefix_name_to_metrics(self, metrics: Dict[str, float], name: str):
def prefix_name_to_metrics(self, metrics: Dict[str, float], name: str) -> Dict[str, float]:
if not name:
return metrics
metrics = {name + "_" + key: value for key, value in metrics.items()}
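The newly annotated helper simply prefixes every metric key with the evaluator name, as the body above shows. For illustration only (the base class is instantiated just to call the helper, and the metric values are made up):

    from sentence_transformers.evaluation import SentenceEvaluator

    evaluator = SentenceEvaluator()
    metrics = {"pearson_cosine": 0.85, "spearman_cosine": 0.84}
    print(evaluator.prefix_name_to_metrics(metrics, "sts-dev"))
    # {'sts-dev_pearson_cosine': 0.85, 'sts-dev_spearman_cosine': 0.84}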
19 changes: 10 additions & 9 deletions sentence_transformers/fit_mixin.py
@@ -11,6 +11,7 @@
from packaging import version
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm.autonotebook import trange
from transformers import TrainerCallback, TrainerControl, TrainerState
@@ -68,7 +69,7 @@ def on_evaluate(
metrics: Dict[str, Any],
model: "SentenceTransformer",
**kwargs,
):
) -> None:
if self.evaluator is not None and self.save_best_model:
metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
for key, value in metrics.items():
@@ -84,7 +85,7 @@ def on_train_end(
control: TrainerControl,
model: "SentenceTransformer",
**kwargs,
):
) -> None:
if self.evaluator is None:
model.save(self.output_dir)

@@ -109,7 +110,7 @@ def on_epoch_end(
control: TrainerControl,
model: "SentenceTransformer",
**kwargs,
):
) -> None:
evaluator_metrics = self.evaluator(model, epoch=state.epoch)
if not isinstance(evaluator_metrics, dict):
evaluator_metrics = {"evaluator": evaluator_metrics}
@@ -141,7 +142,7 @@ def on_evaluate(
control: TrainerControl,
metrics: Dict[str, Any],
**kwargs,
):
) -> None:
metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
for key, value in metrics.items():
if key.endswith(metric_key):
@@ -172,7 +173,7 @@ def fit(
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
checkpoint_save_total_limit: int = 0,
):
) -> None:
"""
Deprecated training method from before Sentence Transformers v3.0, it is recommended to use
:class:`~sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method uses
@@ -371,7 +372,7 @@ def _default_checkpoint_dir() -> str:
trainer.train()

@staticmethod
def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int) -> LambdaLR:
"""
Returns the correct learning rate scheduler. Available scheduler:
@@ -450,7 +451,7 @@ def old_fit(
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
checkpoint_save_total_limit: int = 0,
):
) -> None:
"""
Deprecated training method from before Sentence Transformers v3.0, it is recommended to use
:class:`sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method should
@@ -658,7 +659,7 @@ def old_fit(
if checkpoint_path is not None:
self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)

def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback):
def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback) -> None:
"""Runs evaluation during the training"""
eval_path = output_path
if output_path is not None:
@@ -675,7 +676,7 @@ def _eval_during_training(self, evaluator, output_path, save_best_model, epoch,
if save_best_model:
self.save(output_path)

def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step) -> None:
# Store new checkpoint
self.save(os.path.join(checkpoint_path, str(step)))

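`_get_scheduler` is a private helper of the deprecated `fit` path, but the new `LambdaLR` annotation documents that the scheduler factories it wraps (such as the warmup schedules from `transformers`) return `torch.optim.lr_scheduler.LambdaLR`. A rough sketch, calling the private helper purely to illustrate the annotation (checkpoint name and hyperparameters are assumptions):

    import torch

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # example checkpoint
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # "WarmupLinear" is the default scheduler string used by fit()/old_fit().
    scheduler = model._get_scheduler(optimizer, scheduler="WarmupLinear", warmup_steps=100, t_total=1000)
    # mypy now infers `scheduler` as torch.optim.lr_scheduler.LambdaLR
    scheduler.step()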
(Diffs for the remaining 35 changed files are not shown.)
