[integration] Add support for Transformers v4.46.0 #3026

Merged
35 changes: 22 additions & 13 deletions sentence_transformers/trainer.py
@@ -7,9 +7,11 @@
 from typing import TYPE_CHECKING, Any, Callable

 import torch
+from packaging.version import parse as parse_version
 from torch import nn
 from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler
 from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
+from transformers import __version__ as transformers_version
 from transformers.data.data_collator import DataCollator
 from transformers.integrations import WandbCallback
 from transformers.trainer import TRAINING_ARGS_NAME
@@ -202,19 +204,24 @@ def __init__(
             train_dataset = DatasetDict(train_dataset)
         if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
             eval_dataset = DatasetDict(eval_dataset)
-        super().__init__(
-            model=None if self.model_init else model,
-            args=args,
-            data_collator=data_collator,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            model_init=model_init,
-            compute_metrics=compute_metrics,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        )
+        super_kwargs = {
+            "model": None if self.model_init else model,
+            "args": args,
+            "data_collator": data_collator,
+            "train_dataset": train_dataset,
+            "eval_dataset": eval_dataset,
+            "model_init": model_init,
+            "compute_metrics": compute_metrics,
+            "callbacks": callbacks,
+            "optimizers": optimizers,
+            "preprocess_logits_for_metrics": preprocess_logits_for_metrics,
+        }
+        # Transformers v4.46.0 changed the `tokenizer` argument to a more general `processing_class` argument
+        if parse_version(transformers_version) >= parse_version("4.46.0"):
+            super_kwargs["processing_class"] = tokenizer
+        else:
+            super_kwargs["tokenizer"] = tokenizer
+        super().__init__(**super_kwargs)
         # Every Sentence Transformer model can always return a loss, so we set this to True
         # to avoid having to specify it in the data collator or model's forward
         self.can_return_loss = True
@@ -311,6 +318,7 @@ def compute_loss(
         model: SentenceTransformer,
         inputs: dict[str, torch.Tensor | Any],
         return_outputs: bool = False,
+        num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
         """
         Computes the loss for the SentenceTransformer model.
@@ -325,6 +333,7 @@ def compute_loss(
             model (SentenceTransformer): The SentenceTransformer model.
             inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model.
             return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.
+            num_items_in_batch (int, optional): The number of items in the batch. Defaults to None. Unused, but required by the transformers Trainer.

         Returns:
             Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.
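For context, here is a minimal standalone sketch of the compatibility pattern this diff applies. It is not part of the PR: `SketchTrainer` and its arguments are hypothetical and only illustrate dispatching between the `tokenizer` and `processing_class` keywords and accepting the extra `num_items_in_batch` parameter.

# A standalone sketch of the version-gating pattern, assuming only that
# `packaging` and `transformers` are installed. `SketchTrainer` is a
# hypothetical subclass used purely for illustration.
from packaging.version import parse as parse_version
from transformers import Trainer
from transformers import __version__ as transformers_version


class SketchTrainer(Trainer):
    def __init__(self, model=None, args=None, tokenizer=None, **kwargs):
        super_kwargs = {"model": model, "args": args, **kwargs}
        # Transformers v4.46.0 renamed `tokenizer` to the broader `processing_class`,
        # so pick whichever keyword the installed version expects.
        if parse_version(transformers_version) >= parse_version("4.46.0"):
            super_kwargs["processing_class"] = tokenizer
        else:
            super_kwargs["tokenizer"] = tokenizer
        super().__init__(**super_kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # `num_items_in_batch` is accepted (and ignored here) so the signature stays
        # compatible with newer Trainer versions, which pass this extra argument.
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

With this shape, the same subclass can be imported under either Transformers release without touching its call sites.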