
Understanding how (hard) negatives in MNRL are used #3097

Closed
HenningDinero opened this issue Nov 28, 2024 · 4 comments

Comments

HenningDinero commented Nov 28, 2024

I'm struggling a bit with completely understanding the MNRL implementation.

I've tried looking into the source code for the MNRL loss

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        embeddings_b = torch.cat(reps[1:])

        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        # Example a[i] should match with b[i]
        range_labels = torch.arange(0, scores.size(0), device=scores.device)
        return self.cross_entropy_loss(scores, range_labels)

and following the entire training flow. I'm not that strong in the "ordinary" transformers training flow, so I might be missing some basic understanding.

With a batch size of 128 (and only using anchors/positives), embeddings_a contains 128 embeddings (for the anchors) and embeddings_b contains the embeddings for the positives (also 128).
Adding hard negatives to the dataset, embeddings_b contains the embeddings for both the positives and the negatives, i.e. it is 256 long (2 × 128).
But I can only find the negatives when providing them explicitly as hard negatives.

I have tried following the entire training flow and I can't see where the batched negatives (i.e. positives from other anchors) are fed into the loss function if I don't provide any hard negatives. The inputs (from the Trainer) consist of 4 dictionaries: sentence_input_ids, sentence_attention_mask, positive_input_ids and positive_attention_mask, i.e. no negatives. So I can only see the loss function (CE) being affected by the positives and the anchors, since embeddings_a consists of the anchor embeddings and embeddings_b consists of the positive embeddings.
Does that mean that, if we don't provide any hard negatives, MNRL only tries to minimize the distance between the anchors and the positives without also maximizing the distance to some negatives, or am I simply missing something?

I would've assumed the following (pseudo-code):

    if hard_negatives is not None:
        negatives = hard_negatives
    else:
        negatives = sample_negatives_from_other_positives(anchors, positives)

    embeddings_anchor = model(anchors)
    embeddings_positive = model(positives)
    embeddings_negatives = model(negatives)

    embeddings_target = torch.cat([embeddings_positive, embeddings_negatives])

    scores = self.similarity_fct(embeddings_anchor, embeddings_target) * self.scale
    # anchor[i] should match with embeddings_target[i]
    range_labels = torch.arange(0, scores.size(0), device=scores.device)
    return self.cross_entropy_loss(scores, range_labels)

Furthermore: shouldn't we end up with more than 128 negatives if the batch size is set to 128 (more specifically 128 × 127), since we would create 127 in-batch negatives for each sample?
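
For reference, here is a toy reproduction of the shapes I'm describing in the no-negatives case (random tensors instead of a real model, cosine similarity, and the default scale of 20; not the actual library code):

    import torch
    import torch.nn.functional as F

    batch_size, dim = 4, 8                    # stand-ins for 128 and the real embedding dim
    anchors = torch.randn(batch_size, dim)    # embeddings_a
    positives = torch.randn(batch_size, dim)  # embeddings_b when no hard negatives are provided

    # Pairwise cosine similarity of every anchor against every positive
    scores = F.normalize(anchors, dim=-1) @ F.normalize(positives, dim=-1).T * 20.0
    print(scores.shape)  # torch.Size([4, 4]) -> (batch_size, batch_size)

    labels = torch.arange(scores.size(0))
    print(F.cross_entropy(scores, labels))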

tomaarsen (Collaborator) commented:

Hello!

Your pseudocode is actually super close to what really happens. I think the variable names are a bit bad in this function, so I'll actually make a PR in a bit to resolve that. This is what I'll turn it into:

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # Compute the embeddings and distribute them to anchor, and candidates (positive and optionally negatives)
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        anchors = embeddings[0]  # (batch_size, embedding_dim)
        candidates = torch.cat(embeddings[1:])  # (batch_size * (1 + num_negatives), embedding_dim)

        # For every anchor, we compute the similarity to all other candidates (positives and negatives)
        scores = self.similarity_fct(anchors, candidates) * self.scale  # (batch_size, batch_size * (1 + num_negatives))
        
        # anchor[i] should be most similar to candidates[i], as that is the paired positive,
        # so the label for anchor[i] is i
        range_labels = torch.arange(0, scores.size(0), device=scores.device)

        return self.cross_entropy_loss(scores, range_labels)

So:

  • anchors (previously embeddings_a) is the first of the features (i.e. the first dataset column), and it has the shape (batch_size, embedding_dim)
  • candidates (previously embeddings_b) is the second of the features onwards, i.e. the second column onwards. If you don't have any negatives, then this is just one column, but this loss function also allows you to add a fixed number of extra negative columns. The shape becomes (batch_size * (1 + num_negatives), embedding_dim). So, we concatenate the positives and each of the negatives.

When we compute the scores, we get (batch_size, batch_size * (1 + num_negatives)), i.e. all anchors VS all candidates:
[Image: score matrix of anchors (rows) vs. candidates (columns); only the diagonal cells are the correct positives]
I didn't fill it out all the way, and the shapes are a bit off, but the idea here is that every row relates to an anchor, and there are batch_size * (1 + num_negatives) = 512 candidates. Only 1 is correct: the one on the "diagonal", i.e. anchor[i] matches with candidates[i], because that's the i-th positive.

So, the model has to pick that 1 out of 512 options. Out of those 512 options, only 4 options originated from this sample itself (positive[i], negative_1[i], negative_2[i], and negative_3[i]).
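
To make those shapes concrete, here is a minimal standalone sketch with toy sizes (random tensors, cosine similarity, and the default scale of 20; not the library code itself):

    import torch
    import torch.nn.functional as F

    batch_size, num_negatives, dim = 4, 3, 8
    anchors = torch.randn(batch_size, dim)
    positives = torch.randn(batch_size, dim)
    negatives = [torch.randn(batch_size, dim) for _ in range(num_negatives)]

    # Concatenate the positives and all negative columns, just like the loss does
    candidates = torch.cat([positives, *negatives])  # (batch_size * (1 + num_negatives), dim) = (16, 8)

    # Every anchor vs. every candidate: (batch_size, batch_size * (1 + num_negatives)) = (4, 16)
    scores = F.normalize(anchors, dim=-1) @ F.normalize(candidates, dim=-1).T * 20.0

    # anchor[i] must pick column i, its own positive, out of all 16 candidates
    labels = torch.arange(batch_size)
    print(scores.shape, F.cross_entropy(scores, labels).item())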

Does that make a bit more sense now?

  • Tom Aarsen

HenningDinero (Author) commented Nov 28, 2024

The variables could indeed be renamed, but they are not so much the issue - it would be a good idea nonetheless!

I thought about it a bit more after I posted the question, and I came to (almost) the same conclusion.
In my head, scores would be a pairwise scoring matrix, e.g. embeddings_a @ embeddings_b.T, and since for row i the correct column should be i as well, I now see that the other positives act as negatives, since the model has to pick i among all positives (one could see it as a classification problem with batch_size targets where the right target for row i is i).

I don't know if that is the same as what you wrote?
Also, it seems that the hard negative for i also "participates" as a negative for every other anchor j, since all negatives are concatenated to the positives - is that correct? And even though we specify negatives, we would still use the other positives as negatives (can this be avoided, so that I can control the positives/negatives for each anchor?)

One last thing: if I read your answer correctly, it is possible to have multiple hard negatives per anchor? I have tried the following (just to test it out):

    training_data = Dataset.from_dict(
        {
            "sentence": training_data["Text"],
            "positives": training_data["OtherText"],
            "neg": [[p, p] for p in training_data["OtherText"]],
        }
    )

which should add two hard negatives for each anchor, but the dimension of embeddings_b seems to be the same (2 × batch_size) whether I have one hard negative or five per anchor.

tomaarsen (Collaborator) commented:

one could see it as a classification problem with batch_size targets where the right target for row i is i

That's a great way to look at it indeed, but note that it's only batch_size targets if we only have positives (i.e. only 2 columns). If we have N negative columns, then we actually have batch_size + (batch_size * N) targets, because the negatives are also taken as candidates across the entire batch.

I see now that you write the same here:

Also, it seems that the hard negative for i also "participates" as a negative for every other anchor j, since all negatives are concatenated to the positives - is that correct?

That is indeed exactly correct.

And even though we specify negatives, we would still use the other positives as negatives (can this be avoided, so that I can control the positives/negatives for each anchor?)

No, that can't be avoided. You can, however, create a custom loss function (docs) that has this functionality. It's a bit akin to TripletLoss, which does not use in-batch negatives.
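
As a rough sketch (the class name is made up and this is not an existing loss in the library), such a custom loss could score each anchor against only its own positive and its own hard negatives:

    from collections.abc import Iterable

    import torch
    import torch.nn.functional as F
    from torch import Tensor, nn
    from sentence_transformers import SentenceTransformer


    class PerAnchorRankingLoss(nn.Module):
        """Illustrative only: cross-entropy over each anchor's own candidates, without in-batch negatives."""

        def __init__(self, model: SentenceTransformer, scale: float = 20.0) -> None:
            super().__init__()
            self.model = model
            self.scale = scale

        def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
            reps = [self.model(features)["sentence_embedding"] for features in sentence_features]
            anchors = reps[0]                      # (batch_size, dim)
            candidates = torch.stack(reps[1:], 1)  # (batch_size, 1 + num_negatives, dim)

            # Each anchor is only compared to its own positive (column 0) and its own negatives
            scores = F.cosine_similarity(anchors.unsqueeze(1), candidates, dim=-1) * self.scale
            targets = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
            return F.cross_entropy(scores, targets)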

One last thing: if I read your answer correctly, it is possible to have multiple hard negatives per anchor?

You can indeed, but I'm afraid your script is wrong. This is what it would look like:

    training_data = Dataset.from_dict(
        {
            "sentence": training_data["Text"],
            "positives": training_data["OtherText"],
            "negative_1": [p for p in training_data["OtherText"]],
            "negative_2": [p for p in training_data["OtherText"]],
            "negative_3": [p for p in training_data["OtherText"]],
        }
    )

(ignoring the fact that each negative takes data from the same source, training_data["OtherText"])

For example: https://huggingface.co/datasets/tomaarsen/gooaq-hard-negatives/viewer

The reason we don't allow a dataset column to contain lists of texts is that the lists are not necessarily the same length for all samples, but we need them to be the same length to be able to create proper matrices behind the scenes. The shapes just won't line up if we don't guarantee that.
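
For completeness, here is a rough sketch of wiring such a multi-column dataset into training (the model name and example texts are placeholders; what matters to the loss is the column order: anchor first, then positive, then the negative columns):

    from datasets import Dataset
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
    from sentence_transformers.losses import MultipleNegativesRankingLoss

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # Toy single-row dataset, just to show the column layout
    train_dataset = Dataset.from_dict(
        {
            "sentence": ["how do solar panels work?"],
            "positives": ["Solar panels convert sunlight into electricity."],
            "negative_1": ["Wind turbines convert wind into electricity."],
            "negative_2": ["A solar eclipse happens when the moon blocks the sun."],
            "negative_3": ["Panel discussions involve several speakers."],
        }
    )

    loss = MultipleNegativesRankingLoss(model)
    trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
    trainer.train()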

  • Tom Aarsen

HenningDinero (Author) commented:

Thanks a bunch for the replies and the help!
