mezo_args.py (forked from tatsu-lab/stanford_alpaca)
from transformers import TrainingArguments
from dataclasses import dataclass


@dataclass
class OurArguments(TrainingArguments):
    # Dataset and sampling strategy
    task_name: str = "SST2"  # task name; must match the string before "Dataset" in the dataset class name. Supported: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP

    # Number of examples
    num_train: int = 0  # ICL mode: number of demonstrations; training mode: number of training samples
    num_dev: int = None  # (only enabled with training) number of development samples
    num_eval: int = None  # number of evaluation samples
    num_train_sets: int = None  # how many sets of training samples/demos to sample; if None and train_set_seed is None, sample one set per evaluation sample
    train_set_seed: int = None  # designated seed for sampling training samples/demos
    result_file: str = None  # file name for saving performance; if None, use the task name, model name, and config

    # Model loading
    model_name: str = "facebook/opt-125m"  # HuggingFace model name
    load_float16: bool = False  # load model parameters as float16
    load_bfloat16: bool = False  # load model parameters as bfloat16
    load_int8: bool = False  # load model parameters as int8
    max_length: int = 2048  # max length the model can take
    no_auto_device: bool = False  # do not load the model with automatic device mapping; turn this on when using FSDP

    # Calibration
    sfc: bool = False  # whether to use SFC calibration
    icl_sfc: bool = False  # whether to use SFC calibration for ICL samples

    # Training
    trainer: str = "none"
    ## options
    ## - none: no training -- for zero-shot or in-context learning (ICL)
    ## - regular: regular huggingface trainer -- for fine-tuning
    ## - zo: zeroth-order (MeZO) training
    only_train_option: bool = True  # whether to only train on the option part of the input
    train_as_classification: bool = False  # take the log likelihood of all options and train as classification

    # MeZO
    zo_eps: float = 1e-3  # eps in MeZO

    # Prefix tuning
    prefix_tuning: bool = False  # whether to use prefix tuning
    num_prefix: int = 5  # number of prefixes to use
    no_reparam: bool = True  # do not use the reparameterization trick
    prefix_init_by_real_act: bool = True  # initialize prefixes with real activations of random words

    # LoRA
    lora: bool = False  # whether to use LoRA
    lora_alpha: int = 16  # alpha in LoRA
    lora_r: int = 8  # r in LoRA

    # Generation
    sampling: bool = False  # whether to use sampling
    temperature: float = 1.0  # temperature for generation
    num_beams: int = 1  # number of beams for generation
    top_k: int = None  # top-k for generation
    top_p: float = 0.95  # top-p for generation
    max_new_tokens: int = 50  # max number of new tokens to generate
    eos_token: str = "\n"  # end-of-sentence token

    # Saving
    save_model: bool = False  # whether to save the model
    no_eval: bool = False  # whether to skip evaluation
    tag: str = ""  # saving tag

    # Linear probing
    linear_probing: bool = False  # whether to do linear probing
    lp_early_stopping: bool = False  # whether to do early stopping in linear probing
    head_tuning: bool = False  # head tuning: only tune the LM head

    # Untie emb/lm_head weights
    untie_emb: bool = False  # untie the embeddings and the LM head

    # Display
    verbose: bool = False  # verbose output

    # Non-differentiable objective
    non_diff: bool = False  # use a non-differentiable objective (only F1 for SQuAD is supported for now)

    # Auto saving when interrupted
    save_on_interrupt: bool = False  # save the model when interrupted (useful for long training)
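
Because OurArguments subclasses TrainingArguments, its fields can be exposed as command-line flags through HfArgumentParser. The sketch below is illustrative only: the driver function, script name, and flag values shown in the comments are assumptions and not part of this file, and TrainingArguments itself may still require --output_dir on the command line.

from transformers import HfArgumentParser

from mezo_args import OurArguments  # assumes this file is importable as the module `mezo_args`


def parse_our_arguments() -> OurArguments:
    # HfArgumentParser turns every dataclass field above into a CLI flag, e.g.
    #   python run.py --output_dir out --task_name SST2 --trainer zo --zo_eps 1e-3
    parser = HfArgumentParser(OurArguments)
    (args,) = parser.parse_args_into_dataclasses()
    return args


if __name__ == "__main__":
    args = parse_our_arguments()
    print(args.task_name, args.trainer, args.zo_eps)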
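
For context on how zo_eps is used, here is a minimal, self-contained sketch of MeZO's two-point loss-difference estimate. It is an assumption based on the MeZO idea of perturbing the trainable parameters by +eps*z and -eps*z with the same random direction z, not code taken from this repository; `compute_loss` is a hypothetical callable that returns a scalar loss tensor for the current weights.

import torch


def zo_projected_gradient(model, compute_loss, zo_eps: float = 1e-3, seed: int = 0) -> float:
    """Estimate (L(theta + eps*z) - L(theta - eps*z)) / (2 * eps) without backprop."""
    params = [p for p in model.parameters() if p.requires_grad]

    def perturb(scale: float) -> None:
        # Re-seeding the generator regenerates the same z, so z never has to be stored.
        gen = torch.Generator().manual_seed(seed)
        for p in params:
            z = torch.randn(p.shape, generator=gen).to(dtype=p.dtype, device=p.device)
            p.data.add_(scale * zo_eps * z)

    with torch.no_grad():
        perturb(+1.0)                      # theta -> theta + eps * z
        loss_plus = compute_loss(model).item()
        perturb(-2.0)                      # theta + eps * z -> theta - eps * z
        loss_minus = compute_loss(model).item()
        perturb(+1.0)                      # restore the original theta
    return (loss_plus - loss_minus) / (2.0 * zo_eps)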