Understanding how (hard) negatives in MNRL are used #3097
The variables could be renamed, but they are not really the issue - it would be a good idea nonetheless! I thought about it a bit more after I posted the questions, and I came to (almost) the same conclusion. I don't know if that is the same as what you wrote?

One last thing: if I read your answer correctly, it is possible to have multiple hard negatives for specific anchors? I have tried the following (just to test it out):

```python
training_data = Dataset.from_dict(
    {
        "sentence": training_data["Text"],
        "positives": training_data["OtherText"],
        "neg": [[p, p] for p in training_data["OtherText"]],
    }
)
```

which should add two hard negatives for each anchor, but the dimension of […]
That's a great way to look at it indeed, but note that it's only […]

I see now that you write the same here:
That is indeed exactly correct.
No, it cannot. You can however create a custom loss function (docs) that has this functionality. It's a bit akin to TripletLoss - which does not use in-batch negatives.
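As an illustration of that route, here is a minimal sketch of such a custom loss, assuming the usual sentence-transformers loss interface (each dataset column arrives as tokenized features and is embedded with the wrapped model). The class name, margin value, and use of cosine similarity are made up for this example rather than taken from an existing loss:

```python
import torch
from torch import nn
from sentence_transformers import SentenceTransformer


class ExplicitNegativeLoss(nn.Module):
    """Hypothetical loss that only uses the provided (anchor, positive, negative)
    columns and no in-batch negatives - similar in spirit to TripletLoss."""

    def __init__(self, model: SentenceTransformer, margin: float = 0.5):
        super().__init__()
        self.model = model
        self.margin = margin

    def forward(self, sentence_features, labels):
        # One entry per dataset column, in column order: anchor, positive, negative.
        reps = [self.model(features)["sentence_embedding"] for features in sentence_features]
        anchors, positives, negatives = reps

        pos_sim = torch.cosine_similarity(anchors, positives)  # (batch_size,)
        neg_sim = torch.cosine_similarity(anchors, negatives)  # (batch_size,)

        # Push each anchor's positive above its own negative by a margin;
        # no other samples from the batch are involved.
        return torch.relu(self.margin - pos_sim + neg_sim).mean()
```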
You can indeed, but your script is wrong, I'm afraid. This is what it would look like:

```python
training_data = Dataset.from_dict(
    {
        "sentence": training_data["Text"],
        "positives": training_data["OtherText"],
        "negative_1": [p for p in training_data["OtherText"]],
        "negative_2": [p for p in training_data["OtherText"]],
        "negative_3": [p for p in training_data["OtherText"]],
    }
)
```

(ignoring the fact that each negative takes data from the same source). For example: https://huggingface.co/datasets/tomaarsen/gooaq-hard-negatives/viewer

The reason that we don't allow a dataset to contain lists of texts is that lists are not necessarily the same length for all samples, but we need them to be the same length to be able to create proper matrices behind the scenes. The shapes just won't line up if we don't guarantee that.
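To make the shape argument concrete, here is a rough sketch with made-up sizes and random tensors in place of real embeddings, showing how fixed-width negative columns stack into one candidate matrix:

```python
import torch

batch_size, num_negatives, dim = 128, 3, 768

# One embedding matrix per column: anchors, positives, and each negative column.
anchors   = torch.randn(batch_size, dim)                # embeddings_a
positives = torch.randn(batch_size, dim)
negatives = [torch.randn(batch_size, dim) for _ in range(num_negatives)]

# Because every sample has exactly the same number of negatives, the columns can
# simply be concatenated into one candidate matrix.
candidates = torch.cat([positives, *negatives], dim=0)  # embeddings_b
print(candidates.shape)                                 # torch.Size([512, 768]) = (1 + 3) * 128

# With ragged lists of negatives, the per-column tensors would have different
# lengths and this concatenation (and the similarity matrix below) would break.
scores = anchors @ candidates.T                         # (128, 512) similarity matrix
```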
Thanks a bunch for the replies and the help!
I'm struggling a bit with completely understanding the MNRL implementation.
I've tried looking into the source code for the MNRL loss and following the entire training flow. I'm not that strong in the "ordinary" transformers training flow, so I might be missing some basic understanding.
With a batch size of 128 (and only using anchors/positives) we get that `embeddings_a` contains 128 embeddings (for the anchors) and `embeddings_b` contains the embeddings for the positives (also 128). Adding hard negatives to the dataset, we get that `embeddings_b` contains embeddings for both the positives and the negatives, i.e. it is 256 long (2x 128). But I can only find the negatives when I provide them explicitly as hard negatives.
I have tried following the entire training flow and I can't see where the batched negatives (i.e. positives from other anchors) are being fed into the loss function if I don't provide any hard negatives. The `inputs` (from the `Trainer`) consist of 4 entries: `sentence_input_ids`, `sentence_attention_mask`, `positive_input_ids` and `positive_attention_mask`, i.e. no negatives. Thus I can only see the loss function (CE) being affected by the positives and the anchors, since `embeddings_a` consists of embeddings of the anchors and `embeddings_b` consists of embeddings of the positives.

Does that mean that, if we don't provide any hard negatives, MNRL only tries to minimize the distance between the anchor and the positives without maximizing the distance to some negatives, or am I simply missing something?
I would've assumed the following (pseudo-code): […]
Furthermore: shouldn't we have more than 128 if the batch size is set to 128 (more specifically, 128*127), since we create 127 negatives for each sample?
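For context, here is a condensed sketch of the idea behind MultipleNegativesRankingLoss (simplified, with made-up dimensions, not the actual library source), showing where the in-batch negatives come from even though no negative columns are passed in:

```python
import torch
import torch.nn.functional as F

batch_size, dim, scale = 128, 768, 20.0

embeddings_a = torch.randn(batch_size, dim)  # anchors
embeddings_b = torch.randn(batch_size, dim)  # positives only (no hard negatives provided)

# Cosine similarity of every anchor against every candidate in the batch.
a = F.normalize(embeddings_a, dim=1)
b = F.normalize(embeddings_b, dim=1)
scores = a @ b.T * scale                     # shape (128, 128)

# For anchor i, column i is its own positive; the other 127 columns are the
# positives of the other anchors and act as the in-batch negatives.
labels = torch.arange(batch_size)
loss = F.cross_entropy(scores, labels)

# So nothing extra is fed into the loss: the 127 negatives per anchor are just the
# off-diagonal entries of this 128 x 128 score matrix, not 128 * 127 additional inputs.
```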