Hi folks,
I am trying to perform text embedding model training for my own base model. However, I am running into an issue where my global step was reset to 0 around step 127421 and epoch 0.17 (line 509); the relevant log lines are quoted in the training script below. There is no other log indicating an error.
The wandb plot below shows some interesting evidence of the issue: the first part of the training-loss curve is actually the loss after the reset, while the latter half is from before the reset. You can also see train/global_step changing around 126500. The "train/learning_rate" and "train/epoch" curves look unchanged because the pipeline is overwriting the existing data.
I have a few questions:
Will auto_find_batch_size=True reset the epoch?
Does anyone know how to add more logging if I can reproduce the issue? (See the callback sketch after the script below.)
Is there a bug in the standard SentenceTransformerTrainer code?

Python version: 3.10.14
OS: Linux-6.8.0-48-generic-x86_64-with-glibc2.35 (Ubuntu 22)
GPU: NVIDIA GeForce RTX 4090
Package versions:
sentence-transformers 3.3.1
torch 2.1.0+cu121

Code Reference:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses, models, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss, SoftmaxLoss
import random
import numpy as np

# 0. Setting seeds
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(42)  # You can choose any seed value

# 1. Model setup
# Setup tokenizer with extra token for padding
tokenizer = AutoTokenizer.from_pretrained("keeeeenw/MicroLlama")
special_tokens_dict = {'pad_token': '[PAD]'}
tokenizer.add_special_tokens(special_tokens_dict)
# Load the model
base_model = models.Transformer("keeeeenw/MicroLlama", tokenizer_args={'pad_token': '[PAD]'})

# Check tokenizer and model vocab sizes before resizing
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Model vocab size before resize with padding: {base_model.auto_model.config.vocab_size}")

# Resize model embeddings to match the tokenizer
base_model.auto_model.resize_token_embeddings(len(tokenizer))

# Check model vocab size after resizing
print(f"Model vocab size after resize with padding token: {base_model.auto_model.config.vocab_size}")

# Pooling layer setup
pooling_model = models.Pooling(
base_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True
)
# Construct SentenceTransformer model
model = SentenceTransformer(modules=[base_model, pooling_model])

# 2. Load several Datasets to train with
# (anchor, positive)
all_nli_pair_train = load_dataset("sentence-transformers/all-nli", "pair", split="train")
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
# (sentence1, sentence2) + score
all_nli_pair_score_train = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train")
# (anchor, positive)
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:139000]")
# (query, answer)
natural_questions_train = load_dataset("sentence-transformers/natural-questions", split="train[:90000]")

# We can combine all datasets into a dictionary with dataset names to datasets
train_dataset = {
"all-nli-pair": all_nli_pair_train,
"all-nli-pair-class": all_nli_pair_class_train,
"all-nli-pair-score": all_nli_pair_score_train,
"all-nli-triplet": all_nli_triplet_train,
"stsb": stsb_pair_score_train,
"quora": quora_pair_train,
"natural-questions": natural_questions_train,
}
# 3. Load several Datasets to evaluate with
# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
# (sentence1, sentence2, score)
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
# (anchor, positive)
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[139000:]")
# (query, answer)
natural_questions_dev = load_dataset("sentence-transformers/natural-questions", split="train[90000:]")

# We can use a dictionary for the evaluation dataset too, but we don't have to. We could also just use
# no evaluation dataset, or one dataset.
eval_dataset = {
"all-nli-triplet": all_nli_triplet_dev,
"stsb": stsb_pair_score_dev,
"quora": quora_pair_dev,
"natural-questions": natural_questions_dev,
}
# 4. Load several loss functions to train with
# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)
# (sentence_A, sentence_B) + class
softmax_loss = SoftmaxLoss(model, model.get_sentence_embedding_dimension(), 3)
# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)

# Create a mapping with dataset names to loss functions, so the trainer knows which loss to apply where.
# Note that you can also just use one loss if all of your training/evaluation datasets use the same loss
losses = {
"all-nli-pair": mnrl_loss,
"all-nli-pair-class": softmax_loss,
"all-nli-pair-score": cosent_loss,
"all-nli-triplet": mnrl_loss,
"stsb": cosent_loss,
"quora": mnrl_loss,
"natural-questions": mnrl_loss,
}
# 5. Training arguments
train_batch_size = 4  # Batch size
eval_batch_size = 4   # Batch size
num_epochs = 3        # Number of epochs

train_args = SentenceTransformerTrainingArguments(
    "tmp_trainer-full-3-epoch-linear-lr-1e-5-batch-4",
    per_device_train_batch_size=train_batch_size,  # Batch size per GPU (or CPU if no GPU is used)
    per_device_eval_batch_size=eval_batch_size,    # Evaluation batch size if you have an eval dataset
    num_train_epochs=num_epochs,                   # Number of epochs
    learning_rate=1e-5,                            # Use a smaller LR
    auto_find_batch_size=True,                     # Auto adjust batch size
    evaluation_strategy="steps",                   # Evaluate every N steps
    eval_steps=2500,                               # Perform evaluation every 2500 steps
)

# Default is to save every 500 steps
# https://sbert.net/docs/package_reference/sentence_transformer/training_args.html#sentence_transformers.training_args.SentenceTransformerTrainingArguments.set_save
train_args.set_save(strategy="steps", steps=5000)

# Issue: for some reason it restarted training around here
"""
{'loss': 0.7495, 'grad_norm': 0.0019107084954157472, 'learning_rate': 9.433869240733956e-06, 'epoch': 0.17}   6%|█████     | 127421/2243298 [6:34:58<109:18:46, 5.38it/s]
{'loss': 0.7143, 'grad_norm': 0.016145436093211174, 'learning_rate': 9.997771138743048e-06, 'epoch': 0.0}   | 0/2243298 [00:00<?, ?it/s]
{'loss': 0.6057, 'grad_norm': 0.34214434027671814, 'learning_rate': 9.995542277486094e-06, 'epoch': 0.0}
"""
trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=losses,
tokenizer=tokenizer,
args=train_args
)
trainer.train()
# 6. Save the trained model and optionally push it to the Hugging Face Hub
model.save_pretrained("microllama300m-base-all-nli-stsb-quora-nq")
# model.push_to_hub("microllama300m-base-all-nli-stsb-quora-nq")
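One way to add more logging around the reset (my second question above) would be a small transformers TrainerCallback that tracks train/global_step and warns whenever it moves backwards. This is only a minimal sketch of that idea; the class name GlobalStepWatchdog and the warning format are placeholders of mine, not part of the sentence-transformers API:

from transformers import TrainerCallback

class GlobalStepWatchdog(TrainerCallback):
    """Track train/global_step across logging events and warn if it ever decreases."""
    def __init__(self):
        self.last_global_step = -1

    def on_log(self, args, state, control, logs=None, **kwargs):
        # state is a transformers TrainerState; global_step should only ever increase
        if state.global_step < self.last_global_step:
            print(f"WARNING: global_step went backwards: {self.last_global_step} -> {state.global_step} "
                  f"(epoch={state.epoch}, max_steps={state.max_steps})")
        self.last_global_step = state.global_step

# Usage (hypothetical): pass the callback when constructing the trainer, e.g.
# trainer = SentenceTransformerTrainer(..., callbacks=[GlobalStepWatchdog()])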
@nguyenvannghiem0312 thanks for responding. I ran a few more tests and it looks like the issue is related to auto_find_batch_size=True. When the training pipeline runs into an OOM error, it resets the global step with no obvious warning message. You can reproduce this by using a slightly larger batch size than the one supported by your GPU (see the sketch at the end of this comment). Interestingly, if the batch size is really big, it will just give up and print a warning message after a few attempts at resetting the global step.
I would like to ask for the following improvements:
Print a clear warning message / log entry when auto_find_batch_size kicks in, and indicate the new batch size.
Try not to reset the global step. If the global step must be reset, start a completely new run and wandb session.
I am happy to open a pull request for the improvements above if you agree with my suggestions.
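For reference, a rough sketch of how I reproduced the behaviour above. It assumes a per-device batch size (512 here) that is deliberately larger than the GPU can handle; the exact value and the output directory name are placeholders:

repro_args = SentenceTransformerTrainingArguments(
    "tmp_trainer-oom-repro",          # placeholder output directory
    per_device_train_batch_size=512,  # assumed to be larger than the GPU supports
    auto_find_batch_size=True,        # on OOM, retry the training loop with a smaller batch size
    num_train_epochs=1,
)
# In my tests, the first steps hit CUDA OOM, auto_find_batch_size retries with a
# smaller batch size, and the global step / progress bar restart from 0 with no
# obvious warning message.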