
Unstable Training Loss and Model Evaluation #3041

Open · nguyenvannghiem0312 opened this issue Nov 7, 2024 · 7 comments

Comments

@nguyenvannghiem0312

[screenshots: training loss curves showing the instability]

While training the multilingual-E5-base model, I encountered an unstable loss pattern. In previous runs the training loss was stable. I tried changing the model but ran into the same issue. Could you help me understand what might be causing this?

@pesuchin
Contributor

@nguyenvannghiem0312
Thanks for posting the issue!
For reference, could you let me know which loss class caused the instability?

@nguyenvannghiem0312
Author

Hello @pesuchin, sorry I missed the notification. I used the CachedMNRL and MNRL loss functions and encountered issues with both. I found that when there are only anchor and positive samples, training is stable; however, when I add negatives with the following code, it becomes unstable, as shown in the images above:

# Build the dataset column-wise: one list per column.
train_datasets = read_json_or_dataset(config["train_path"])
train_datasets = process_data(train=train_datasets, number_negatives=config['number_negatives'])
datasets = {}
anchor = [config["query_prompt"] + item["anchor"] for item in train_datasets]
positive = [config["corpus_prompt"] + item["positive"] for item in train_datasets]

datasets["anchor"] = anchor
datasets["positive"] = positive
# Note: only a single "negative" column is built here, even when number_negatives > 1.
if "negative" in train_datasets[0] and config["is_triplet"]:
    negative = [config["corpus_prompt"] + item["negative"] for item in train_datasets]
    datasets["negative"] = negative
return Dataset.from_dict(datasets)

When I adjusted the code to add negatives as shown below, the training process stabilized:

# Build the dataset row-wise: one dict per example, with a separate
# negative_{idx} column for each mined negative.
datasets = []
for item in train_datasets:
    sample = {
        'anchor': config["query_prompt"] + item["anchor"],
        'positive': config["corpus_prompt"] + item["positive"]
    }
    if config["is_triplet"]:
        for idx in range(config['number_negatives']):
            sample[f'negative_{idx}'] = config["corpus_prompt"] + item[f'negative_{idx}']

    datasets.append(sample)
datasets = Dataset.from_list(datasets)
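
For context, a minimal sketch of how a dataset built this way is passed to the trainer (the model name, output directory, and epoch count here are placeholders, not the exact values from my script); as I understand it, with (Cached)MNRL every column after "anchor" and "positive" is treated as an additional hard negative:

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

model = SentenceTransformer("intfloat/multilingual-e5-base")  # placeholder checkpoint
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)

args = SentenceTransformerTrainingArguments(
    output_dir="output",  # placeholder
    per_device_train_batch_size=512,
    num_train_epochs=10,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=datasets,  # the Dataset built above
    loss=loss,
)
trainer.train()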

I'm not sure if there was an error in my initial code, but I've resolved the issue with the updated code.

@nguyenvannghiem0312
Author

Oh @pesuchin, I just realized that the issue is not only from the negative samples, but also from the batch size of CachedMNRL. When I train with a batch size of 128 (with 3 negative samples per query), the loss is stable.
[screenshot: stable loss curve, batch size 128]

But when using a batch size of 512 (with 3 negative samples per query), it becomes unstable, as shown in the image below.
[screenshot: unstable loss curve, batch size 512]

However, when I train with a batch size of 4096 (without using negatives), the process is mostly stable, though I did notice a slight instability here:
[screenshot: mostly stable loss curve, batch size 4096, no negatives]

So, the issue comes from both the negative samples and the batch size.
I also noticed one more thing: each time training becomes unstable, the "train/epoch" graph also becomes unstable (it’s not a straight line).

@nguyenvannghiem0312
Author

Although there is some instability here, the results at the best checkpoint are still quite good. However, this loss graph is not ideal for including in the report.

@tomaarsen
Collaborator

Wow, those are very fascinating results.
Could you perhaps use batch_sampler="no_duplicates"? This prevents identical texts within a batch, which helps remove false negatives among the in-batch negatives. That might be a cause: otherwise a specific text can be both a positive and a negative for the same anchor, which is confusing for the model.
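
A minimal sketch of enabling this (the output directory and batch size below are placeholders):

from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",                        # placeholder
    per_device_train_batch_size=512,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # no identical texts within a batch
)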

Additionally, would you describe your negatives as very hard? I can imagine that very hard negatives might result in odd behaviour, as the model has trouble distinguishing between the positive and the very hard negatives.
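
If they are very hard, one option is to drop any negative that scores almost as high as the positive. A rough, hypothetical sketch (the model name, helper, and margin are illustrative, not from this thread):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-base")  # illustrative

def keep_negative(query: str, positive: str, negative: str, margin: float = 0.05) -> bool:
    # Discard negatives whose similarity to the query is within `margin` of the positive's,
    # since these are often false negatives or too hard to be useful.
    emb = model.encode([query, positive, negative], normalize_embeddings=True)
    pos_score = float(emb[0] @ emb[1])
    neg_score = float(emb[0] @ emb[2])
    return neg_score < pos_score - margin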


Oh! I just noticed the issue, I believe: the epoch figure is odd.
[screenshots: train/epoch curves, which are not straight lines]

It looks like, in the 512 case (with 3 negative samples per query), you're training for about 900 steps in the first epoch, but the trainer thinks that should correspond to ~3.5 epochs. Once the epoch is actually completed, the trainer sets the epoch to 1 and keeps going. This results in a very problematic epoch figure, which presumably affects the learning rate. Could you verify whether the learning rate curve looks odd as well?

I think the 0 loss that you're seeing sometimes is because the learning rate has been reduced to 0, and then the loss spikes a ton when the learning rate is suddenly a normal number again.
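
For context, a small sketch of the mechanism I suspect, assuming the default linear warmup/decay schedule: the trainer derives both the epoch figure and the scheduler's total step count from the dataloader length, so if that length is underestimated, the learning rate decays to 0 and is then clamped there while training keeps running. The step counts below are illustrative only:

import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
# Suppose the trainer computed 900 total steps (with 5% warmup) but actually runs longer.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=45, num_training_steps=900)

for step in range(1, 1201):
    optimizer.step()
    scheduler.step()
    if step in (45, 450, 900, 1100):
        print(step, scheduler.get_last_lr())  # LR peaks after warmup, decays to 0 by step 900, then stays at 0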

Could you share a bit more about your hyperparameters/training arguments?

  • Tom Aarsen

@nguyenvannghiem0312
Author

nguyenvannghiem0312 commented Nov 14, 2024

Thanks @tomaarsen for your response.
Here are the hyperparameters I used:

{
    "model": model,
    "guide_model": guide_model,
    "max_length": 1022,
    "query_prompt": "query: ",
    "corpus_prompt": "passage: ",
    "is_triplet": true,
    "number_negatives": 3,
    "loss": "CachedMultipleNegativesRankingLoss",
    "batch_size": 512,
    "mini_batch_size": 16,
    "num_train_epochs": 10,
    "warmup_ratio": 0.05,
    "fp16": false,
    "bf16": true,
    "batch_sampler": "NO_DUPLICATES",
    "eval_strategy": "steps",
    "eval_steps": 10,
    "save_strategy": "steps",
    "save_steps": 10,
    "save_total_limit": 2,
    "logging_steps": 1,
    "load_best_model_at_end": true,
    "metric_for_best_model": "eval_cosine_mrr@10",
    "learning_rate": 5e-5,
}
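
Roughly, these settings map onto the sentence-transformers v3 trainer as below (the output directory and model name are placeholders; max_length and the prompts are handled separately, as in the dataset code above):

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

model = SentenceTransformer("intfloat/multilingual-e5-base")  # placeholder
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)

args = SentenceTransformerTrainingArguments(
    output_dir="output",  # placeholder
    num_train_epochs=10,
    per_device_train_batch_size=512,
    learning_rate=5e-5,
    warmup_ratio=0.05,
    fp16=False,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,
    save_total_limit=2,
    logging_steps=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_cosine_mrr@10",
)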

I also considered the possibility of false negatives (I used BM25 to mine the negative samples) and used the GIST method to filter them out (v7 is the run where I used GIST with the same parameters), which seemed to have a slight impact. Although the loss is still not fully stable, the instability has lessened somewhat.
[screenshot: loss curve for v7 (GIST), less unstable]
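
Concretely, the GIST setup amounts to swapping the loss for the guided cached variant; a minimal sketch (both model names are placeholders, not the ones I actually used):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedGISTEmbedLoss

model = SentenceTransformer("intfloat/multilingual-e5-base")            # placeholder
guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")   # placeholder guide used to flag likely false negatives

loss = CachedGISTEmbedLoss(model, guide=guide, mini_batch_size=16)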

And the learning rate curve is very stable:
[screenshot: learning rate curve]

@tomaarsen
Collaborator

Hmm, that learning rate looks correct indeed.
