SentenceTransformerTrainer reset global step around step 127421 with no warning message #3145

Open
keeeeenw opened this issue Dec 24, 2024 · 2 comments

Comments

@keeeeenw

keeeeenw commented Dec 24, 2024

Hi folks,

I am trying to perform text embedding model training for my own base model.

However, I am running into an issue where the global step was reset to 0 around step 127421 at epoch 0.17 (see line 509 below):

```
490 {'loss': 0.6453, 'grad_norm': 0.37768709659576416, 'learning_rate': 9.453928992046532e-06, 'epoch': 0.16}
491 {'eval_all-nli-triplet_loss': 0.24458371102809906, 'eval_all-nli-triplet_runtime': 43.8751, 'eval_all-nli-triplet_samples_per_second': 150.062, 'eval_all-nli-triplet_steps_per_second': 37.516, 'epoch': 0.16}
492 {'eval_stsb_loss': 2.918501615524292, 'eval_stsb_runtime': 6.7477, 'eval_stsb_samples_per_second': 222.299, 'eval_stsb_steps_per_second': 55.575, 'epoch': 0.16}
493 {'eval_quora_loss': 0.005668556783348322, 'eval_quora_runtime': 45.834, 'eval_quora_samples_per_second': 223.917, 'eval_quora_steps_per_second': 55.985, 'epoch': 0.16}
494 {'eval_natural-questions_loss': 0.01890234276652336, 'eval_natural-questions_runtime': 69.9569, 'eval_natural-questions_samples_per_second': 146.247, 'eval_natural-questions_steps_per_second': 36.565, 'epoch': 0.16}
495 {'loss': 0.7351, 'grad_norm': 0.020067762583494186, 'learning_rate': 9.45170013078958e-06, 'epoch': 0.16}
496 {'loss': 0.8327, 'grad_norm': 7.6738362312316895, 'learning_rate': 9.449471269532627e-06, 'epoch': 0.17}
497 {'loss': 0.9185, 'grad_norm': 35.78461837768555, 'learning_rate': 9.447242408275672e-06, 'epoch': 0.17}
498 {'loss': 0.735, 'grad_norm': 0.007549208588898182, 'learning_rate': 9.445013547018721e-06, 'epoch': 0.17}
499   6%|█████                                                                                    | 127421/2243298 [6:34:58<109:18:46,  5.38it/s]
500 {'loss': 0.8274, 'grad_norm': 5.161532878875732, 'learning_rate': 9.442784685761767e-06, 'epoch': 0.17}
501 {'eval_all-nli-triplet_loss': 0.2401724010705948, 'eval_all-nli-triplet_runtime': 43.3488, 'eval_all-nli-triplet_samples_per_second': 151.884, 'eval_all-nli-triplet_steps_per_second': 37.971, 'epoch': 0.17}
502 {'eval_stsb_loss': 2.8828203678131104, 'eval_stsb_runtime': 6.6191, 'eval_stsb_samples_per_second': 226.616, 'eval_stsb_steps_per_second': 56.654, 'epoch': 0.17}
503 {'eval_quora_loss': 0.0062046111561357975, 'eval_quora_runtime': 45.1054, 'eval_quora_samples_per_second': 227.534, 'eval_quora_steps_per_second': 56.889, 'epoch': 0.17}
504 {'eval_natural-questions_loss': 0.020038239657878876, 'eval_natural-questions_runtime': 70.2651, 'eval_natural-questions_samples_per_second': 145.606, 'eval_natural-questions_steps_per_second': 36.405, 'epoch': 0.17}
505 {'loss': 0.8205, 'grad_norm': 0.004999153316020966, 'learning_rate': 9.440555824504814e-06, 'epoch': 0.17}
506 {'loss': 0.7468, 'grad_norm': 0.021909011527895927, 'learning_rate': 9.438326963247862e-06, 'epoch': 0.17}
507 {'loss': 0.7893, 'grad_norm': 113.3169174194336, 'learning_rate': 9.436098101990909e-06, 'epoch': 0.17}
508 {'loss': 0.7495, 'grad_norm': 0.0019107084954157472, 'learning_rate': 9.433869240733956e-06, 'epoch': 0.17}
509   0%|                                                                                                            | 0/2243298 [00:00<?, ?it/s]
510   File "/home/ken/workspace/microllama_embedding/train.py", line 208, in <module>               | 67500/2243298 [3:28:05<71:08:44,  8.50it/s]
511 {'loss': 0.7143, 'grad_norm': 0.016145436093211174, 'learning_rate': 9.997771138743048e-06, 'epoch': 0.0}
512 {'loss': 0.6057, 'grad_norm': 0.34214434027671814, 'learning_rate': 9.995542277486094e-06, 'epoch': 0.0}
513 {'loss': 0.6817, 'grad_norm': 5.778988838195801, 'learning_rate': 9.993313416229143e-06, 'epoch': 0.0}
514 {'loss': 0.5854, 'grad_norm': 1.0194756984710693, 'learning_rate': 9.991084554972188e-06, 'epoch': 0.0}
515 {'loss': 0.6466, 'grad_norm': 0.1206388995051384, 'learning_rate': 9.988855693715236e-06, 'epoch': 0.0}
516 {'eval_all-nli-triplet_loss': 0.21341252326965332, 'eval_all-nli-triplet_runtime': 42.8829, 'eval_all-nli-triplet_samples_per_second': 153.534, 'eval_all-nli-triplet_steps_per_second': 38.384, 'epoch': 0.0}
517 {'eval_stsb_loss': 2.605201005935669, 'eval_stsb_runtime': 6.808, 'eval_stsb_samples_per_second': 220.328, 'eval_stsb_steps_per_second': 55.082, 'epoch': 0.0}
518 {'eval_quora_loss': 0.005764608271420002, 'eval_quora_runtime': 45.3652, 'eval_quora_samples_per_second': 226.231, 'eval_quora_steps_per_second': 56.563, 'epoch': 0.0}
519 {'eval_natural-questions_loss': 0.016714297235012054, 'eval_natural-questions_runtime': 70.5634, 'eval_natural-questions_samples_per_second': 144.99, 'eval_natural-questions_steps_per_second': 36.251, 'epoch': 0.0}
520 {'loss': 0.6766, 'grad_norm': 0.0839328020811081, 'learning_rate': 9.986626832458283e-06, 'epoch': 0.0}
521 {'loss': 0.6579, 'grad_norm': 108.81422424316406, 'learning_rate': 9.98439797120133e-06, 'epoch': 0.0}
522 {'loss': 0.6551, 'grad_norm': 11.56656265258789, 'learning_rate': 9.982169109944378e-06, 'epoch': 0.01}
523 {'loss': 0.6071, 'grad_norm': 0.0011169251520186663, 'learning_rate': 9.979940248687423e-06, 'epoch': 0.01}
524 {'loss': 0.6588, 'grad_norm': 6.516519546508789, 'learning_rate': 9.97771138743047e-06, 'epoch': 0.01}
```

There is no other log line indicating an error.

The wandb plot below shows some interesting evidence of the issue: the first part of the training-loss curve belongs to the loss after the reset, while the latter half is from before the reset. You can also see train/global_step changing around 126500. The "train/learning_rate" and "train/epoch" curves look unchanged because the pipeline is overwriting the existing data.

[wandb screenshot: Screenshot 2024-12-24 at 12 23 47 AM]

I have a few questions:

  1. Will auto_find_batch_size=True reset the epoch?
  2. Does anyone know how to add more logging if I can reproduce the issue? (A small callback sketch I am considering follows this list.)
  3. Is there a bug in the standard SentenceTransformerTrainer code?
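
For question 2, one thing I am considering (just a sketch I put together, not something from the sentence-transformers docs) is raising the transformers log level and attaching a small TrainerCallback that prints the trainer state at every evaluation, so a global-step reset would show up explicitly in the logs:

```python
# Hypothetical debugging aid: bump the transformers log level and attach a
# callback that prints the trainer state at every evaluation. If the global
# step ever resets, it should be visible here.
import transformers
from transformers import TrainerCallback

transformers.logging.set_verbosity_info()

class StateLogger(TrainerCallback):
    def on_evaluate(self, args, state, control, **kwargs):
        print(
            f"[StateLogger] global_step={state.global_step}, "
            f"epoch={state.epoch}, max_steps={state.max_steps}"
        )

# Attach to the trainer built in the script below:
# trainer.add_callback(StateLogger())
```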

Python version: 3.10.14
OS: Linux-6.8.0-48-generic-x86_64-with-glibc2.35 (Ubuntu 22)
GPU: NVIDIA GeForce RTX 4090
Package versions:
sentence-transformers 3.3.1
torch 2.1.0+cu121

Code Reference:

```python
import torch

from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses, models, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss, SoftmaxLoss

import random
import numpy as np

# 0. Setting seeds
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(42)  # You can choose any seed value

# 1. Model setup
# Setup tokenizer with extra token for padding
tokenizer = AutoTokenizer.from_pretrained("keeeeenw/MicroLlama")
special_tokens_dict = {'pad_token': '[PAD]'}
tokenizer.add_special_tokens(special_tokens_dict)

# Load the model
base_model = models.Transformer("keeeeenw/MicroLlama", tokenizer_args={'pad_token': '[PAD]'})

# Check tokenizer and model vocab sizes before resizing
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Model vocab size before resize with padding: {base_model.auto_model.config.vocab_size}")

# Resize model embeddings to match the tokenizer
base_model.auto_model.resize_token_embeddings(len(tokenizer))

# Check model vocab size after resizing
print(f"Model vocab size after resize with padding token: {base_model.auto_model.config.vocab_size}")

# Pooling layer setup
pooling_model = models.Pooling(
    base_model.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

# Construct SentenceTransformer model
model = SentenceTransformer(modules=[base_model, pooling_model])

# 2. Load several Datasets to train with
# (anchor, positive)
all_nli_pair_train = load_dataset("sentence-transformers/all-nli", "pair", split="train")
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
# (sentence1, sentence2) + score
all_nli_pair_score_train = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train")
# (anchor, positive)
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:139000]")
# (query, answer)
natural_questions_train = load_dataset("sentence-transformers/natural-questions", split="train[:90000]")

# We can combine all datasets into a dictionary with dataset names to datasets
train_dataset = {
    "all-nli-pair": all_nli_pair_train,
    "all-nli-pair-class": all_nli_pair_class_train,
    "all-nli-pair-score": all_nli_pair_score_train,
    "all-nli-triplet": all_nli_triplet_train,
    "stsb": stsb_pair_score_train,
    "quora": quora_pair_train,
    "natural-questions": natural_questions_train,
}

# 3. Load several Datasets to evaluate with
# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
# (sentence1, sentence2, score)
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
# (anchor, positive)
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[139000:]")
# (query, answer)
natural_questions_dev = load_dataset("sentence-transformers/natural-questions", split="train[90000:]")

# We can use a dictionary for the evaluation dataset too, but we don't have to. We could also just use
# no evaluation dataset, or one dataset.
eval_dataset = {
    "all-nli-triplet": all_nli_triplet_dev,
    "stsb": stsb_pair_score_dev,
    "quora": quora_pair_dev,
    "natural-questions": natural_questions_dev,
}

# 4. Load several loss functions to train with
# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)
# (sentence_A, sentence_B) + class
softmax_loss = SoftmaxLoss(model, model.get_sentence_embedding_dimension(), 3)
# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)

# Create a mapping with dataset names to loss functions, so the trainer knows which loss to apply where.
# Note that you can also just use one loss if all of your training/evaluation datasets use the same loss
losses = {
    "all-nli-pair": mnrl_loss,
    "all-nli-pair-class": softmax_loss,
    "all-nli-pair-score": cosent_loss,
    "all-nli-triplet": mnrl_loss,
    "stsb": cosent_loss,
    "quora": mnrl_loss,
    "natural-questions": mnrl_loss,
}

train_batch_size = 4  # Batch size
eval_batch_size = 4  # Batch size
num_epochs = 3  # Number of epochs
train_args = SentenceTransformerTrainingArguments(
    "tmp_trainer-full-3-epoch-linear-lr-1e-5-batch-4",
    per_device_train_batch_size=train_batch_size,  # Batch size per GPU (or CPU if no GPU is used)
    per_device_eval_batch_size=eval_batch_size,   # Evaluation batch size if you have eval dataset
    num_train_epochs=num_epochs,  # Number of epochs
    learning_rate=1e-5,            # use a smaller LR
    auto_find_batch_size=True,        # auto adjust batch size
    evaluation_strategy="steps",      # Evaluate every N steps
    eval_steps=2500,                  # Perform evaluation every 2500 steps
)
# Default is save for every 500 steps
# https://sbert.net/docs/package_reference/sentence_transformer/training_args.html#sentence_transformers.training_args.SentenceTransformerTrainingArguments.set_save
train_args.set_save(strategy="steps", steps=5000)

# Issue: for some reason training restarted around this point:
"""
{'loss': 0.7495, 'grad_norm': 0.0019107084954157472, 'learning_rate': 9.433869240733956e-06, 'epoch': 0.17}
  6%|█████                                                                                    | 127421/2243298 [6:34:58<109:18:46,  5.38it/s]
{'loss': 0.7143, 'grad_norm': 0.016145436093211174, 'learning_rate': 9.997771138743048e-06, 'epoch': 0.0}        | 0/2243298 [00:00<?, ?it/s]
{'loss': 0.6057, 'grad_norm': 0.34214434027671814, 'learning_rate': 9.995542277486094e-06, 'epoch': 0.0}
"""
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=losses,
    tokenizer=tokenizer,
    args=train_args,
)
trainer.train()

# 6. save the trained model and optionally push it to the Hugging Face Hub
model.save_pretrained("microllama300m-base-all-nli-stsb-quora-nq")
# model.push_to_hub("microllama300m-base-all-nli-stsb-quora-nq")
```

@nguyenvannghiem0312

I think I also have a similar problem here: link

@keeeeenw
Author

@nguyenvannghiem0312 thanks for responding. I ran a few more tests and it looks like the issue is related to auto_find_batch_size=True. When the training pipeline runs into an OOM issue, it resets the global step with no obvious warning message. You can reproduce this by using a slightly larger batch size than the one supported by your GPU. Interestingly, if the batch size is really big, it will just give up and print a warning message after a few tries of resetting the global step.
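
For context, my understanding (please correct me if I am wrong) is that auto_find_batch_size=True makes the Trainer wrap its inner training loop with accelerate's find_executable_batch_size, which simply re-runs the whole loop with a smaller batch size whenever it hits a CUDA OOM. A rough illustration of that mechanism (assumed behavior, not the actual Trainer code):

```python
# Rough illustration of the assumed auto_find_batch_size retry mechanism, not
# the actual Trainer implementation: the decorated function is re-run from
# scratch with a reduced batch size whenever it raises a CUDA OOM, which would
# explain the global step and the progress bar starting over at 0.
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=4)
def inner_training_loop(batch_size):
    print(f"(Re)starting the training loop with batch_size={batch_size}")
    # ... build dataloaders with this batch_size and run every step from 0 ...
    # If a torch.cuda.OutOfMemoryError escapes from here, the decorator catches
    # it, lowers batch_size, and calls this function again from the beginning.

inner_training_loop()
```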

I would like to ask for the following improvements.

  1. Print a clear warning message / log entry when auto_find_batch_size kicks in, and indicate the new batch size.
  2. Try not to reset the global step. If the global step must be reset, start a completely new run and wandb session.

I am happy to do a pull request on the improvements above if you agree with my suggestions.
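
Until then, a workaround I am experimenting with (a hypothetical sketch, not part of sentence-transformers) is a callback that at least makes the restart visible by warning whenever the global step goes backwards:

```python
# Hypothetical workaround: warn loudly when the training loop restarts.
# A global_step that goes backwards is the observable symptom; reporting the
# new batch size would require reading trainer internals, so it is skipped.
import logging
from transformers import TrainerCallback

logger = logging.getLogger(__name__)

class GlobalStepResetWarning(TrainerCallback):
    def __init__(self):
        self.last_global_step = -1

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step < self.last_global_step:
            logger.warning(
                "global_step went backwards (%d -> %d); the training loop was "
                "restarted, most likely by auto_find_batch_size after an OOM.",
                self.last_global_step,
                state.global_step,
            )
        self.last_global_step = state.global_step

# trainer.add_callback(GlobalStepResetWarning())
```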
