sft.py
"""
Script for SFT training
Script adapted from the TRL library:
https://github.com/huggingface/trl/blob/v0.13.0/trl/scripts/sft.py
"""
import logging
import time

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    setup_chat_format,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

from unpaired_rlhf.utils.runtime import set_seed

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
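
# Example invocation (illustrative sketch; the model and dataset names below are
# placeholders, while the flags map to the ModelConfig, ScriptArguments, and
# SFTConfig fields parsed by TrlParser):
#
#   python sft.py \
#       --model_name_or_path Qwen/Qwen2.5-0.5B \
#       --dataset_name trl-lib/Capybara \
#       --output_dir outputs/sft \
#       --learning_rate 2.0e-5 \
#       --per_device_train_batch_size 4 \
#       --gradient_checkpointing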

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    # Append the model name, dataset name, and a timestamp to the output directory
    training_args.output_dir += f"-{model_args.model_name_or_path.split('/')[-1]}-{script_args.dataset_name.split('/')[-1]}-{time.strftime('%Y%m%d_%H%M%S')}"
    training_args.output_dir = training_args.output_dir.replace("_", "-")
    training_args.run_name = training_args.output_dir
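
    # For example (hypothetical values): output_dir "outputs/sft" with model
    # "meta-llama/Llama-3.2-1B", dataset "trl-lib/Capybara", and a run started at
    # 2025-01-01 12:00:00 yields "outputs/sft-Llama-3.2-1B-Capybara-20250101-120000"
    # once underscores are replaced with hyphens.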

    # Set seed everywhere
    set_seed(training_args.seed)

    ################
    # Model init kwargs & Tokenizer
    ################
    logger.info("Loading the pretrained model...")
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        **model_kwargs,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Align padding tokens between tokenizer and model
    model.config.pad_token_id = tokenizer.pad_token_id

    # If post-training a base model, use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)
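
    # Note: setup_chat_format applies the ChatML template, adds its special tokens
    # (<|im_start|>, <|im_end|>), resizes the token embeddings to match, and
    # re-aligns the model config's pad/eos token ids with the tokenizer.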

    ################
    # Dataset
    ################
    logger.info("Loading the dataset...")
    dataset = load_dataset(script_args.dataset_name)
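
    # load_dataset returns a DatasetDict keyed by split name; the split names used
    # below come from ScriptArguments (dataset_train_split / dataset_test_split,
    # which default to "train" and "test" in TRL).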

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )
    logger.info("Starting training...")
    trainer.train()
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()
    logger.info("Training complete.")

    ################
    # Evaluation
    ################
    logger.info("Evaluating the model...")
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
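
    # A minimal follow-up sketch for sanity-checking the saved checkpoint, assuming
    # transformers' text-generation pipeline with chat-style input; the prompt and
    # generation settings are placeholders, not part of the original script:
    #
    # from transformers import pipeline
    #
    # pipe = pipeline("text-generation", model=training_args.output_dir)
    # out = pipe(
    #     [{"role": "user", "content": "Say hello in one sentence."}],
    #     max_new_tokens=64,
    # )
    # print(out)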