[fix] Matryoshka training always patch original forward, and check matryoshka_dims
#2593
Hello! This strikes me as a good idea to prevent some very unexpected issues, even if they only occur very rarely. As for your last paragraph, perhaps we can simplify this by warning the user if one of their provided Matryoshka embedding dimensions is larger than the model's original embedding dimension. After all, in that case the truncation won't do anything.
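A minimal sketch of the suggested check, assuming a hypothetical helper named `warn_on_oversized_dims` (it is not part of sentence-transformers) that is given the requested dimensions and the model's embedding dimension:

```python
import warnings


def warn_on_oversized_dims(matryoshka_dims, model_dim):
    """Hypothetical helper: warn about dims that truncation cannot honor."""
    oversized = [dim for dim in matryoshka_dims if dim > model_dim]
    if oversized:
        warnings.warn(
            f"matryoshka_dims {oversized} exceed the model's embedding "
            f"dimension ({model_dim}); truncating to them has no effect.",
            stacklevel=2,
        )
    return oversized


# Example with a 768-dimensional model: only 1024 is flagged.
flagged = warn_on_oversized_dims([64, 256, 768, 1024], model_dim=768)
```

Whether this should stay a warning or become an error is a design choice; the sketch only shows where the comparison would sit.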
|
I verified that the bug happens under certain conditions. Here's a CPU-friendly script that reproduces it:
from typing import NoReturn
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
model = SentenceTransformer("paraphrase-albert-small-v2", device="cpu")
print(model.get_sentence_embedding_dimension())
# 768
# matryoshka_dims = [768, 10, 9, 8, 7, 6, 5, 4, 3, 2]
n_dims_per_step = -1
matryoshka_dims = [2, 3, 4, 5, 6, 7, 8, 9, 10, 768]
# Dummy data
train_examples = [
    InputExample(texts=["Anchor 1", "Positive 1"]),
    InputExample(texts=["somethin", "something else"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
dev_evaluator = EmbeddingSimilarityEvaluator(
    ["aljfad", "a;lkjdfasl;jf"],
    ["sentence3", "sentence4"],
    [0.9, 0.9],
    main_similarity=SimilarityFunction.COSINE,
    write_csv=False,
    show_progress_bar=True,
)
# Bad loss that will immediately raise an error
class MultipleNegativesRankingLossBad(torch.nn.Module):
    def __init__(self, model: SentenceTransformer) -> None:
        super().__init__()
        self.model = model

    def forward(*args, **kwargs) -> NoReturn:
        raise ValueError("Faaaaill")


train_loss = MultipleNegativesRankingLossBad(model)
train_loss = losses.MatryoshkaLoss(
    model, train_loss, matryoshka_dims=matryoshka_dims, n_dims_per_step=n_dims_per_step
)
# First attempt at training
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=2,
)
# raises ValueError: Faaaaill
# model.forward has been modified and will always truncate to 2
print(type(model.forward))
# <class 'sentence_transformers.losses.MatryoshkaLoss.ForwardDecorator'>
print(model.forward.dim)
# 2
# Correct the loss and run it again
train_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.MatryoshkaLoss(model, train_loss, matryoshka_dims=matryoshka_dims)
# Silently wrong training
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=2,
)
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
# torch.Size([2, 2])
Good idea, I'll add this to the PR. Do you think this should be a warning, a |
I think the error text should be a bit more explicit with notifying the users what they're doing wrong. Other than that, this is looking good to go.
Co-authored-by: Tom Aarsen <[email protected]>
Much appreciated! Looks good :)
This is somewhat unrelated, but it may interest you nonetheless. I also don't remember if I have written about this here before, so apologies if you've seen this already:

Instead, a potential advancement is to consider a "continuous Matryoshka loss" or "continuous MRL" simply by creating a similarity function that prioritizes information towards the start of the embedding. This similarity score function can be applied directly in other losses that accept such functions, e.g. MultipleNegativesRankingLoss. I experimented with a naive version, e.g.:

def mrl_cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Normalize the input embeddings such that matrix multiplication is cosine similarity
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    # Multiply the normalized embeddings with a decreasing multiplier to give more importance to the first dimensions
    multiplier = torch.arange(b_norm.shape[-1], 0, -1, device=a_norm.device) / b_norm.shape[-1]
    a_norm *= multiplier
    b_norm *= multiplier
    # Return the cosine similarity
    return a_norm @ b_norm.T

loss = losses.MultipleNegativesRankingLoss(model, similarity_fct=mrl_cos_sim)

I evaluated this with semantic textual similarity using STSB and the triplet evaluator using the AllNLI validation dataset, while training on the AllNLI train dataset. I used the training refactor PR which integrates with Weights and Biases to easily compare the performance with a baseline:

This figure is quite interesting. With the "matryoshka cosine similarity", the Spearman Correlation reduces very gradually when reducing the dimensionality: 768 > 512 > 256 > 128 > 64 > 32 > 16 > 8, while the baseline is very jumpy: 256 > 768 > 128 > 512 > 64 > 32 > 16 > 8. As a result, the matryoshka cosine similarity model is sometimes much better, and sometimes much worse. (w&b link)

This figure is a lot more straightforward. Perhaps that is to be expected, as this validation dataset originates from the same distribution as the training set, so this "gradual increase over time during training" is pretty normal. As could be somewhat expected, the model performs worse than the baseline at 768 (because the MRL cosine similarity can't use the last dimensions to store information as well as the baseline, so in essence it can store "less information"). This difference shrinks and eventually the continuous MRL model handily outperforms the baseline.

The full training script:

from collections import defaultdict
import datasets
from datasets import Dataset
import torch
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.models import Transformer, Pooling
"""
def mrl_cos_sim(a, b):
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    # a x d, b x d -> a x b x d
    similarity_per_dim = a_norm * b_norm
    multiplier = torch.arange(similarity_per_dim.shape[-1], 0, -1, device=similarity_per_dim.device) / 768
    similarity_per_dim *= multiplier
    return similarity_per_dim.sum(-1, keepdim=True)
"""
def mrl_cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Normalize the input embeddings such that matrix multiplication is cosine similarity
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    # Multiply the normalized embeddings with a decreasing multiplier to give more importance to the first dimensions
    multiplier = torch.arange(b_norm.shape[-1], 0, -1, device=a_norm.device) / b_norm.shape[-1]
    a_norm *= multiplier
    b_norm *= multiplier
    # Return the cosine similarity
    return a_norm @ b_norm.T
def to_triplets(dataset):
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0])  # <- entailment
            negatives.append(sentences[2])  # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })
if __name__ == "__main__":
    snli_ds = datasets.load_dataset("snli")
    snli_ds = datasets.DatasetDict({
        "train": to_triplets(snli_ds["train"]),
        "validation": to_triplets(snli_ds["validation"]),
        "test": to_triplets(snli_ds["test"]),
    })
    multi_nli_ds = datasets.load_dataset("multi_nli")
    multi_nli_ds = datasets.DatasetDict({
        "train": to_triplets(multi_nli_ds["train"]),
        "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
    })
    all_nli_ds = datasets.DatasetDict({
        "train": datasets.concatenate_datasets([snli_ds["train"], snli_ds["train"]]),
        "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
        "test": snli_ds["test"]
    })
    stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
    stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")
    training_args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        run_name="mpnet-base-allnli-baseline",
        # report_to="none",
        num_train_epochs=1,
        seed=33,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=100,
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        metric_for_best_model="eval_sts-dev-768_spearman_cosine",
        greater_is_better=True,
    )
    transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[transformer, pooling])
    loss = losses.MultipleNegativesRankingLoss(model)  # , similarity_fct=mrl_cos_sim)
    dev_evaluators = []
    for matryoshka_dim in [768, 512, 256, 128, 64, 32, 16, 8]:
        dev_evaluators.append(evaluation.EmbeddingSimilarityEvaluator(
            stsb_dev["sentence1"],
            stsb_dev["sentence2"],
            [score / 5 for score in stsb_dev["score"]],
            main_similarity=evaluation.SimilarityFunction.COSINE,
            name=f"sts-dev-{matryoshka_dim}",
            truncate_dim=matryoshka_dim,
        ))
        dev_evaluators.append(evaluation.TripletEvaluator(
            anchors=all_nli_ds["validation"]["anchor"],
            positives=all_nli_ds["validation"]["positive"],
            negatives=all_nli_ds["validation"]["negative"],
            name=f"allnli-validation-{matryoshka_dim}",
            main_distance_function=evaluation.SimilarityFunction.COSINE,
            truncate_dim=matryoshka_dim,
        ))
    dev_evaluator = evaluation.SequentialEvaluator(dev_evaluators)
    # dev_evaluator(model)
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=all_nli_ds["train"],
        eval_dataset=all_nli_ds["validation"],
        evaluator=dev_evaluator,
        loss=loss,
    )
    trainer.train()
    test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
        stsb_test["sentence1"],
        stsb_test["sentence2"],
        [score / 5 for score in stsb_test["score"]],
        main_similarity=evaluation.SimilarityFunction.COSINE,
        name="sts-test",
    )
    results = test_evaluator(model)
    print(results)
    model.save("mpnet-base-allnli-baseline")
# Continuous-mrl-linear:
# {'sts-test_pearson_cosine': 0.8175790710661428, 'sts-test_spearman_cosine': 0.830859176653514, 'sts-test_pearson_manhattan': 0.8362810843054216, 'sts-test_spearman_manhattan': 0.8284792813481868, 'sts-test_pearson_euclidean': 0.8261216518822675, 'sts-test_spearman_euclidean': 0.8181971965432933, 'sts-test_pearson_dot': 0.7944856099219624, 'sts-test_spearman_dot': 0.7834936544677008, 'sts-test_pearson_max': 0.8362810843054216, 'sts-test_spearman_max': 0.830859176653514}
# Base:
# {'sts-test_pearson_cosine': 0.8101063845818969, 'sts-test_spearman_cosine': 0.8319477736976867, 'sts-test_pearson_manhattan': 0.8438869778745631, 'sts-test_spearman_manhattan': 0.8365807837093596, 'sts-test_pearson_euclidean': 0.8418259869573095, 'sts-test_spearman_euclidean': 0.8352941277766923, 'sts-test_pearson_dot': 0.6568623097846887, 'sts-test_spearman_dot': 0.6654033827828304, 'sts-test_pearson_max': 0.8438869778745631, 'sts-test_spearman_max': 0.8365807837093596}

Anyways, I thought these experiments might interest you! I think I'll leave these experiments as-is for now, but perhaps a great solution is possible here (e.g. like how the original MRL is often on par with baselines, but it ALSO does pretty well at lower dimensions).
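To make the linear multiplier in mrl_cos_sim concrete, here is a small pure-Python sketch for a 4-dimensional embedding. The helper names `linear_multiplier` and `weighted_similarity` are illustrative only and do not appear in the script. Because both embeddings are scaled by the same multiplier before the matrix product, each coordinate's contribution ends up scaled by the multiplier squared:

```python
def linear_multiplier(dim):
    """Decreasing weights dim/dim, (dim-1)/dim, ..., 1/dim, as in mrl_cos_sim."""
    return [(dim - i) / dim for i in range(dim)]


def weighted_similarity(a, b):
    """Dot product of two already-normalized vectors after scaling both."""
    m = linear_multiplier(len(a))
    return sum((x * w) * (y * w) for x, y, w in zip(a, b, m))


multiplier = linear_multiplier(4)
per_dim_weight = [w * w for w in multiplier]
print(multiplier)      # [1.0, 0.75, 0.5, 0.25]
print(per_dim_weight)  # [1.0, 0.5625, 0.25, 0.0625]
```

Under this reading, agreement in the last coordinate counts 16 times less than agreement in the first, and the score is no longer a true cosine similarity because the scaled vectors are not re-normalized.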
|
Hey Tom, this is an extremely cool idea! We will look into that more closely in the coming days. |
lol @tomaarsen we thought almost identically about Matryoshka training. I called your "continuous MRL" loss "diagonaloss", and had recently spent some time playing around w/ different versions of the "multiply each vector by a decaying list of weights and compute a distance metric on them" idea. It's interesting to see that level of jumpiness in STSB within each dimension and across dimensions, I wouldn't have guessed that. The AllNLI charts are super clear and cool to see. Though good to know that for mpnet, the clarity of the effect could be due to continued training on NLI.
One thing I'm hoping for is that a clearer geometric interpretation will help in getting to a Matryoshka-like loss. Here's some dirty work I recently did on a few variants of continuous MRL / diagonaloss:
The Matryoshka eval plot (from this notebook) shows that MRL is significantly better at 64 dimensions, and slightly better at the rest. (I realize that I should've trained the Matryoshka model down to 8 dimensions and compared all the models at 8, 16, and 32 as well.) And I wish I'd known about your refactor PR earlier so I'd have those nice w&b charts! :-)
I was reflecting and thinking that on one hand, MRL does work in an odd way. But on the other hand, it is pretty well-motivated: if we want a model to do well in many lower-dim spaces, then just directly do that / penalize in lower-dim spaces. In light of the results, especially your AllNLI results, I would be slightly surprised to see that the Matryoshka effect can be reproduced by re-scaling alone. The gradients of plain MRL vs continuous MRL / diagonaloss are different enough that I couldn't come up w/ a way to reproduce the gradient via re-scaling alone. |
Interesting results! It indeed seems like MRL is rather challenging to beat.
|
Hello,
TLDR: patch back the original forward method, even after an error, to avoid silent problems during training.
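The TLDR can be sketched with try/finally. This is a simplified illustration under stated assumptions, not the library's implementation: `ToyModel`, `loss_with_patched_forward`, and `failing_loss` are hypothetical names, and list slicing stands in for the truncating decorator.

```python
class ToyModel:
    """Hypothetical stand-in for a model with a patchable forward method."""

    def forward(self, features):
        return features


def loss_with_patched_forward(model, compute_loss):
    """Temporarily replace model.forward, restoring it even on error."""
    original_forward = model.forward
    try:
        # Stand-in for the truncating decorator used during Matryoshka training.
        model.forward = lambda features: original_forward(features)[:2]
        return compute_loss(model)
    finally:
        # Runs on success and on failure, so the patch never leaks.
        model.forward = original_forward


def failing_loss(model):
    raise ValueError("Faaaaill")


model = ToyModel()
try:
    loss_with_patched_forward(model, failing_loss)
except ValueError:
    pass

# The original forward is back in place after the failed call.
restored = model.forward([1, 2, 3, 4]) == [1, 2, 3, 4]
```

Without the finally clause, the exception would skip the restore line and the model would keep truncating on every later call.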
In case there's an error in, e.g., self.loss(sentence_features, labels), the closing self.model.forward = original_forward line won't get hit. So if a user is training a model in an interactive session (e.g., a notebook) and doesn't re-create the self.model object (maybe b/c they know the error came before the optimizer ever stepped, so the model's parameters didn't change), then on a second .fit run, I think the forward method will start by effectively doing:

Whether or not this is a problem depends on the user's input to MatryoshkaLoss. If they don't set n_dims_per_step, they set matryoshka_dims to a list such that matryoshka_dims[0] == max(matryoshka_dims), and the error is raised at the first dimension in the list, then there's no problem. Otherwise, there is a problem and it's silent.

The problem occurs if and only if the self.model.forward.dim attribute ends up being set to something less than max(matryoshka_dims). The result will be that the model doesn't actually Matryoshka-train; it will only train up to the last dimension that was set before it errored out in the first run. Call this dimension err_dim.

Reasoning: we'll always have

b/c self.fn is the last ForwardDecorator whose dimension was last set before erroring out. Next,

will give back the exact same tensor if self.dim >= err_dim, as this slicing style doesn't raise an error if self.dim > tensor.shape[-1]. In other words, self.shrink gives back tensors truncated at err_dim when self.dim > err_dim.

The downstream result is that, for example, if err_dim = 32, matryoshka_dims = [16, 32, 64, 128], and matryoshka_weights = [1, 1, 1, 1], then the user's second attempt at training effectively makes matryoshka_weights look something like [1, 3, 0, 0].

This is also making me wonder if some input checking should be done on matryoshka_dims. If self.model.get_sentence_embedding_dimension() == d but the user sets matryoshka_dims=[d/2, d, 2*d, 4*d], they should know that they're up-weighting the loss at dimension d, which might result in not-so-Matryoshka-like properties at inference time.
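The re-weighting described above can be reproduced with a short sketch. It assumes a hypothetical helper `effective_weights` and uses plain lists in place of tensors; the only property relied on is that slicing past the end of a sequence silently returns the whole sequence.

```python
from collections import Counter


def effective_weights(matryoshka_dims, matryoshka_weights, err_dim):
    """Sum the loss weight that actually lands on each embedding size.

    Assumes outputs are already capped at err_dim, so slicing to any
    larger dim returns the same err_dim-sized embedding.
    """
    totals = Counter()
    for dim, weight in zip(matryoshka_dims, matryoshka_weights):
        totals[min(dim, err_dim)] += weight
    return [totals.get(dim, 0) for dim in matryoshka_dims]


# Slicing past the end is silent, for lists as for tensors.
embedding = list(range(32))          # already truncated to err_dim = 32
assert embedding[:128] == embedding  # no error, nothing changes

weights = effective_weights([16, 32, 64, 128], [1, 1, 1, 1], err_dim=32)
print(weights)  # [1, 3, 0, 0]
```

The same arithmetic covers the input-checking case: with a model of dimension d and dims [d/2, d, 2*d, 4*d], passing err_dim=d gives the same tripled weight at d.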