
A highly capable 2.4B lightweight LLM using only 1T pre-training data with all details.

RUC-GSAI/YuLan-Mini


YuLan-Mini: An Open Data-efficient Language Model


YuLan-Mini is a lightweight language model with 2.4 billion parameters. It achieves performance comparable to industry-leading models trained on significantly more data, despite being pre-trained on only 1.08T tokens. The model excels particularly in the domains of mathematics and code. To facilitate reproducibility, we open-source the relevant pre-training resources.


Model Downloads 🔗

| Model | Context Length | SFT | 🤗 Hugging Face | Wise Model |
|---|---|---|---|---|
| YuLan-Mini (Recommended) | 28K | ❎ | YuLan-Mini | YuLan-Mini |
| YuLan-Mini-2.4B-4K | 4K | ❎ | | |
| YuLan-Mini-Instruct | Coming soon | ✅ | | |

The intermediate checkpoints are listed in the Pre-Training Resources section below.


Features 🌟

Our pre-training methodology improves training efficiency through three key innovations:

  1. an elaborately designed data pipeline that combines data cleaning with data scheduling strategies;
  2. a systematic optimization method that effectively mitigates training instability;
  3. an effective annealing approach that integrates targeted data selection and long-context training.

Benchmarks 🌟

Note: The model size calculation includes the embedding size.

| Model | Model Size | # Train Tokens | Context Length | MATH-500 | GSM8K | HumanEval | MBPP | RACE-Middle | RACE-High | RULER |
|---|---|---|---|---|---|---|---|---|---|---|
| MiniCPM | 2.71B | 1.06T | 4K | 15.00 | 53.83 | 50.00* | 47.31 | 56.61 | 44.27 | N/A |
| Qwen2 | 1.54B | 7T | 128K | 22.60 | 46.90* | 34.80* | 46.90* | 55.77 | 43.69 | 60.16 |
| Qwen2.5 | 0.49B | 18T | 128K | 23.60 | 41.60* | 30.50* | 39.30* | 52.36 | 40.31 | 49.23 |
| Qwen2.5 | 1.54B | 18T | 128K | 45.40 | 68.50* | 37.20* | 60.20* | 58.77 | 44.33 | 68.26 |
| Gemma2 | 2.61B | 2T | 8K | 18.30* | 30.30* | 19.50* | 42.10* | - | - | N/A |
| StableLM2 | 1.64B | 2T | 4K | - | 20.62 | 8.50* | 17.50 | 56.33 | 45.06 | N/A |
| SmolLM2 | 1.71B | 11T | 8K | 11.80 | - | 23.35 | 45.00 | 55.77 | 43.06 | N/A |
| Llama3.2 | 3.21B | 9T | 128K | 7.40 | - | 29.30 | 49.70 | 55.29 | 43.34 | 77.06 |
| YuLan-Mini | 2.42B | 1.04T | 4K | 32.60 | 66.65 | 61.60 | 66.70 | 55.71 | 43.58 | N/A |
| YuLan-Mini | 2.42B | 1.08T | 28K | 37.80 | 68.46 | 64.00 | 65.90 | 57.18 | 44.57 | 51.48 |

| Model | LAMBADA | MMLU | CMMLU | CEval | HellaSwag | WinoGrande | StoryCloze | ARC-e | ARC-c |
|---|---|---|---|---|---|---|---|---|---|
| MiniCPM-2.71B | 61.91 | 53.37 | 48.97 | 48.24 | 67.92 | 65.74 | 78.51 | 55.51 | 43.86 |
| Qwen2-1.54B | 64.68 | 55.90 | 70.76 | 71.94 | 66.11 | 66.14 | 77.60 | 62.21 | 42.92 |
| Qwen2.5-0.49B | 52.00 | 47.50 | 52.17 | 54.27 | 50.54 | 55.88 | 71.67 | 56.10 | 39.51 |
| Qwen2.5-1.54B | 62.12 | 60.71 | 67.82 | 69.05 | 67.18 | 64.48 | 76.80 | 71.51 | 53.41 |
| Gemma2-2.61B | - | 52.20* | - | 28.00* | 74.60* | 71.50* | - | - | 55.70* |
| StableLM2-1.64B | 66.15 | 40.37 | 29.29 | 26.99 | 69.79 | 64.64 | 78.56 | 54.00 | 40.78 |
| SmolLM2-1.71B | 67.42 | 51.91 | 33.46 | 35.10 | 72.96 | 67.40 | 79.32 | 44.82 | 35.49 |
| Llama3.2-3.21B | 69.08 | 63.40 | 44.44 | 44.49 | 75.62 | 67.48 | 76.80 | 70.12 | 48.81 |
| YuLan-Mini-2.42B-4K | 64.72 | 51.79 | 48.35 | 51.47 | 68.65 | 67.09 | 76.37 | 69.87 | 50.51 |
| YuLan-Mini-2.42B-28K | 65.67 | 49.10 | 45.45 | 48.23 | 67.22 | 67.24 | 75.89 | 67.47 | 49.32 |

Pre-Training Resources 🔧

To enhance research transparency and reproducibility, we are open-sourcing relevant pre-training resources:

Pre-Training

1. Pre-training and Evaluation Code

The pre-training code can be found here. Note that due to subsequent code modifications, this code may not run directly and may require some adjustments.

Step 1: Modify trainer_state.json

Because of how the Hugging Face Trainer is implemented, certain parameters are stored in the checkpoint's trainer_state.json file and cannot be overridden through the Trainer's command-line arguments. You therefore need to update these parameters in trainer_state.json first, particularly:

  • save_steps: The frequency of saving intermediate checkpoints.
  • train_batch_size: The batch size per GPU (equivalent to per_device_train_batch_size in the Trainer). We used a global batch size of 1008 sequences (approximately 4M tokens) during the stable training stage; keeping the same global batch size when resuming is also important for training effectiveness.

Below is an example of a properly configured trainer_state.json file:

{
  "best_metric": null,
  "best_model_checkpoint": null,
  "epoch": 0.0,
  "eval_steps": 500,
  "global_step": 0,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [],
  "logging_steps": 3,
  "max_steps": 0,
  "num_input_tokens_seen": 0,
  "num_train_epochs": 0,
  "save_steps": 250,
  "stateful_callbacks": {
    "TrainerControl": {
      "args": {
        "should_epoch_stop": false,
        "should_evaluate": false,
        "should_log": false,
        "should_save": true,
        "should_training_stop": true
      },
      "attributes": {}
    }
  },
  "total_flos": 0,
  "train_batch_size": 3,
  "trial_name": null,
  "trial_params": null
}
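A minimal standard-library sketch of Step 1 follows. It assumes the file shown above is the Trainer state JSON stored inside a checkpoint directory (Hugging Face checkpoints name it trainer_state.json), that the global batch size follows the usual data-parallel formula (per-GPU batch × number of GPUs × gradient-accumulation steps), and a 4,096-token sequence length. The GPU count, accumulation steps, helper names, and checkpoint path are all hypothetical.

```python
import json
from pathlib import Path

TARGET_GLOBAL_BATCH = 1008  # sequences per optimizer step in the stable stage (from this README)


def global_batch_size(per_device: int, num_gpus: int, grad_accum_steps: int) -> int:
    """Sequences per optimizer step under plain data parallelism (assumed formula)."""
    return per_device * num_gpus * grad_accum_steps


def update_trainer_state(state: dict, save_steps: int, train_batch_size: int) -> dict:
    """Return a copy of the state with save_steps and train_batch_size overridden.

    Other fields (global_step, log_history, ...) are kept so training resumes at the same step.
    """
    updated = dict(state)
    updated["save_steps"] = save_steps
    updated["train_batch_size"] = train_batch_size
    return updated


def patch_trainer_state_file(path: Path, save_steps: int, train_batch_size: int) -> None:
    """Rewrite the state JSON file in place with the updated fields."""
    state = json.loads(path.read_text())
    path.write_text(json.dumps(update_trainer_state(state, save_steps, train_batch_size), indent=2))


# Hypothetical layout: 3 sequences per GPU (as in the JSON above), 56 GPUs and
# 6 gradient-accumulation steps give 3 * 56 * 6 = 1008 sequences per step.
batch = global_batch_size(per_device=3, num_gpus=56, grad_accum_steps=6)
print(batch == TARGET_GLOBAL_BATCH, batch * 4096)  # assumed 4K context: about 4M tokens

# Example call (hypothetical path, adjust to your checkpoint):
# patch_trainer_state_file(Path("output/checkpoint-XXXX/trainer_state.json"), save_steps=250, train_batch_size=3)
```

If the number of GPUs changes when you resume, adjusting the per-GPU batch size or the accumulation steps so the product stays at 1008 keeps the global batch size unchanged.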

Step 2: Enable Universal Checkpointing in the DeepSpeed Configuration

To ensure DeepSpeed Integration loads the Universal Checkpoint, you need to enable this feature in the DeepSpeed configuration JSON file.

Here is an example of a ZeRO2 configuration with Universal Checkpointing enabled:

{
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 8e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 8e8,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 16,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false,
  "dump_state": true,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "checkpoint": {
    "load_universal": true
  }
}
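The only Universal Checkpointing switch in the configuration above is the nested checkpoint.load_universal key. If you already have a ZeRO configuration, a small helper such as the hypothetical one below can turn it on without overwriting other keys. It is a plain-Python sketch, not part of DeepSpeed or this repository.

```python
import copy
import json


def enable_universal_checkpoint(ds_config: dict) -> dict:
    """Return a copy of a DeepSpeed config with checkpoint.load_universal set to true.

    Existing keys, including other entries under "checkpoint", are preserved.
    """
    cfg = copy.deepcopy(ds_config)  # never mutate the caller's config
    cfg.setdefault("checkpoint", {})["load_universal"] = True
    return cfg


base = {"zero_optimization": {"stage": 2}, "train_batch_size": "auto"}
print(json.dumps(enable_universal_checkpoint(base), indent=2))
```

The resulting dict can be saved with json.dump and passed to the Trainer through the deepspeed argument of TrainingArguments, which accepts either a file path or a dict.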

Step 3: Resume Training

When calling trainer.train, include the resume_from_checkpoint argument to load the distributed optimizer state from the Universal Checkpoint and resume training.

trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
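Here resume_from_checkpoint is the path of the checkpoint directory to resume from. If you want to resolve the most recent checkpoint-<step> directory in an output folder yourself, a minimal standard-library sketch follows; the helper name is hypothetical and it assumes the Trainer's default checkpoint-<step> naming. transformers also provides transformers.trainer_utils.get_last_checkpoint for this purpose, and passing resume_from_checkpoint=True makes the Trainer resume from the last checkpoint in output_dir.

```python
import re
from pathlib import Path
from typing import Optional

# Matches the Trainer's default checkpoint directory names, e.g. "checkpoint-1000".
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    candidates = []
    for child in Path(output_dir).iterdir():
        match = _CKPT_RE.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    return str(max(candidates)[1])
```

For example: trainer.train(resume_from_checkpoint=latest_checkpoint(training_args.output_dir)).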

We provide an internal training framework for your reference, but you are free to choose other frameworks.

2. Intermediate Stage Checkpoints

The intermediate stage checkpoints are released in YuLan-Mini.

| Stage | Curriculum Phase | 4K Context | 28K Context | Optimizer | Inference Architecture | LAMBADA Acc | GSM8K Acc | HumanEval pass@1 |
|---|---|---|---|---|---|---|---|---|
| Stable | 5 | YuLan-Mini-Phase5 | | | yulanmini | 53.85 | 3.41 | 12.26 |
| Stable | 10 | YuLan-Mini-Phase10 | | | yulanmini | 55.00 | 9.57 | 15.95 |
| Stable | 15 | YuLan-Mini-Phase15 | | | yulanmini | 55.81 | 13.81 | 16.99 |
| Stable | 20 | YuLan-Mini-Phase20 | | ✅ | yulanmini | 55.81 | 21.39 | 20.79 |
| Stable | 25 (1T tokens) | YuLan-Mini-Before-Annealing | | ✅ | yulanmini | 55.67 | 29.94 | 34.06 |
| Annealing | 26 | YuLan-Mini-4K | | | llama* | 64.72 | 66.65 | 61.60 |
| Annealing | 27 | | YuLan-Mini | | llama* | 65.67 | 68.46 | 64.00 |

*: For easier inference and deployment, we merged the additional re-parameterization parameters and scaling factors into the final released models (YuLan-Mini and YuLan-Mini-Intermediate-4K), enabling them to run on the Llama architecture. However, these parameters are still retained in the intermediate checkpoints from the training process.
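As a generic illustration of why such merging is possible (the exact form of YuLan-Mini's re-parameterization is not described here, so the setup below is an assumption), consider a linear layer whose output is multiplied by a learnable per-output-channel scale s, so y = s ⊙ (W x). Multiplying row i of W by s[i] gives a plain linear layer with the same outputs, which is the kind of folding that lets a model run as a standard architecture. Plain Python lists and made-up numbers keep the sketch dependency-free.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Compute W x for a row-major matrix."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def scaled_forward(w: Matrix, s: Vector, x: Vector) -> Vector:
    """Training-time form: per-output-channel scale applied after W x."""
    return [si * yi for si, yi in zip(s, matvec(w, x))]


def merge_scale(w: Matrix, s: Vector) -> Matrix:
    """Fold the scale into W: row i is multiplied by s[i]."""
    return [[si * wij for wij in row] for si, row in zip(s, w)]


w = [[1.0, 2.0], [3.0, 4.0]]
s = [0.5, 2.0]
x = [1.0, -1.0]
print(scaled_forward(w, s, x))        # [-0.5, -2.0]
print(matvec(merge_scale(w, s), x))   # [-0.5, -2.0]
```

After merging, only the folded matrix needs to be stored, so the extra scale parameter disappears from the inference graph; in floating point the two forms agree up to rounding.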

3. Optimizer States Before Annealing

YuLan-Mini-Before-Annealing

Datasets

4. Open-Source Datasets Used

Used-Datasets-List

5. Data Distribution for Every Phase
6. Synthetic Data

Data cleaning and synthesis pipeline:

The synthetic data we used is released in YuLan-Mini-Datasets.

What you can do with these pre-training resources

  1. Pre-train your own LLM. You can use our data and curriculum to train a model that's just as powerful as YuLan-Mini.
  2. Perform your own learning rate annealing. During the annealing phase, YuLan-Mini's learning ability is at its peak. You can resume training from the checkpoint before annealing and use your own dataset for learning rate annealing.
  3. Fine-tune the Instruct version of the LLM. You can use the YuLan-Mini base model to train your own Instruct version.
  4. Training dynamics research. You can use YuLan-Mini's intermediate checkpoints to explore internal changes during the pre-training process.
  5. Synthesize your own data. You can use YuLan-Mini's data pipeline to clean and generate your own dataset.
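Resuming from YuLan-Mini-Before-Annealing means picking up at the end of the stable stage and running your own decay. A minimal sketch of a warmup-stable-decay style learning-rate schedule follows, assuming linear warmup and linear decay; the peak learning rate, step counts, decay shape, and function name are placeholders, not the values or code YuLan-Mini used.

```python
def wsd_lr(step: int, peak_lr: float, warmup_steps: int, stable_steps: int,
           decay_steps: int, min_lr: float = 0.0) -> float:
    """Learning rate at a given optimizer step for a warmup-stable-decay schedule."""
    if step < warmup_steps:
        # Linear warmup from 0 to peak_lr.
        return peak_lr * step / warmup_steps
    if step < warmup_steps + stable_steps:
        # Stable stage: constant peak learning rate.
        return peak_lr
    # Annealing stage: linear decay from peak_lr to min_lr, then hold at min_lr.
    progress = min(1.0, (step - warmup_steps - stable_steps) / decay_steps)
    return peak_lr - progress * (peak_lr - min_lr)


# Placeholder schedule: 10 warmup, 80 stable and 10 decay steps.
for step in (0, 5, 50, 90, 95, 100, 120):
    print(step, wsd_lr(step, peak_lr=1.0, warmup_steps=10, stable_steps=80, decay_steps=10))
```

When resuming from the pre-annealing checkpoint, only the decay branch matters: the annealing stage starts at the step where that checkpoint was saved, and you choose the decay length, shape, and data mixture.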

Quick Start 💻

Below is a simple inference example using Hugging Face Transformers:

Hugging Face Inference Example

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("yulan-team/YuLan-Mini")
model = AutoModelForCausalLM.from_pretrained("yulan-team/YuLan-Mini", torch_dtype=torch.bfloat16)

# Input text
input_text = "Renmin University of China is"
inputs = tokenizer(input_text, return_tensors="pt")

# Completion
output = model.generate(inputs["input_ids"], max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))

vLLM Serve Example

vllm serve yulan-team/YuLan-Mini --dtype bfloat16

SGLang Serve Example

python -m sglang.launch_server --model-path yulan-team/YuLan-Mini --port 30000 --host 0.0.0.0
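Both servers expose an OpenAI-compatible HTTP API, so a standard-library client can query either one. The sketch below builds a request for the /v1/completions route; it assumes vLLM's default port 8000 (the vLLM command above does not set a port) and the SGLang port 30000 from the command above. The helper names, prompt, and max_tokens value are illustrative.

```python
import json
import urllib.request


def build_completion_request(base_url: str, prompt: str, max_tokens: int = 100) -> urllib.request.Request:
    """Build a POST request for an OpenAI-compatible /v1/completions endpoint."""
    payload = {
        "model": "yulan-team/YuLan-Mini",  # must match the served model name (by default the model path)
        "prompt": prompt,
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(base_url: str, prompt: str, max_tokens: int = 100) -> str:
    """Send the request and return the generated text (requires a running server)."""
    req = build_completion_request(base_url, prompt, max_tokens)
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["text"]


# Usage, with a server running:
# print(complete("http://localhost:8000", "Renmin University of China is"))   # vLLM
# print(complete("http://localhost:30000", "Renmin University of China is"))  # SGLang
```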

The Team

YuLan-Mini is developed and maintained by AI Box, Renmin University of China.

License

  • The code in this repository, the model weights, and optimizer states are released under the MIT License.
  • Policies regarding the use of model weights, intermediate optimizer states, and training data will be announced in future updates.
  • Limitations: Despite our efforts to mitigate safety concerns and encourage the generation of ethical and lawful text, the probabilistic nature of language models may still lead to unexpected outputs. For instance, responses might contain bias, discrimination, or other harmful content. Please refrain from disseminating such content. We are not liable for any consequences arising from the spread of harmful information.

Citation

If you find YuLan-Mini helpful for your research or development, please cite our technical report:

@misc{hu2024yulanmini,
      title={YuLan-Mini: An Open Data-efficient Language Model},
      author={Yiwen Hu and Huatong Song and Jia Deng and Jiapeng Wang and Jie Chen and Kun Zhou and Yutao Zhu and Jinhao Jiang and Zican Dong and Wayne Xin Zhao and Ji-Rong Wen},
      year={2024},
      eprint={2412.17743},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2412.17743},
}
