forked from tatsu-lab/stanford_alpaca
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmezo.py
893 lines (756 loc) · 39.3 KB
/
mezo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import copy
from metrics import f1
import numpy as np
from tqdm.auto import tqdm
from transformers import Trainer
from sklearn.linear_model import LinearRegression, LogisticRegression, LogisticRegressionCV
# Integrations must be imported before ML frameworks:
from transformers.integrations import ( # isort: split
default_hp_search_backend,
get_reporting_integration_callbacks,
hp_params,
is_fairscale_available,
is_optuna_available,
is_ray_tune_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb,
)
import numpy as np
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from huggingface_hub import Repository
from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from transformers.dependency_versions_check import dep_version_check
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.optimization import Adafactor, get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
ShardSampler,
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_module_class_from_name,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_truncate,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalLoopOutput,
EvalPrediction,
FSDPOption,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
RemoveColumnsCollator,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
enable_full_determinism,
find_executable_batch_size,
get_last_checkpoint,
has_length,
number_of_arguments,
seed_worker,
set_seed,
speed_metrics,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
CONFIG_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
find_labels,
get_full_repo_name,
is_apex_available,
is_datasets_available,
is_in_notebook,
is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tensorrt_fx_available,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
)
from transformers.utils.generic import ContextManagers
# Module-level setup: default callbacks, optional backend imports and the
# file names used when writing checkpoints.
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    # Absolute import: this file imports `transformers` as an external package,
    # so a relative `.utils.notebook` import has no parent package to resolve against.
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    # Absolute import for the same reason as the notebook callback above.
    from transformers.trainer_pt_utils import (
        smp_forward_backward,
        smp_forward_only,
        smp_gather,
        smp_nested_concat,
    )
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
class OurTrainer(Trainer):
from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
def _inner_training_loop(
    self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
    """
    We overload the original training loop to add linear probing and MeZO. Search key word "MeZO added"
    for those updates.

    Two MeZO-specific paths exist:

    - ``self.args.linear_probing``: features feeding the token-prediction layer are collected over the
      training set, a logistic regression is fit on them, its coefficients are written into that layer,
      and the function returns ``None`` without running the regular loop.
    - ``args.trainer == "zo"``: each step calls ``self.zo_step`` / ``self.zo_update`` instead of
      ``self.training_step`` and the optimizer.

    Returns:
        ``None`` after linear probing, otherwise a ``TrainOutput(global_step, train_loss, metrics)``.

    Raises:
        NotImplementedError: linear probing on a model type without a known prediction layer / bias rule.
        ValueError: the dataloader has no length and ``args.max_steps`` is not positive, or
            ``--debug underflow_overflow`` is requested with more than one GPU.
    """
    self._train_batch_size = batch_size
    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()

    # MeZO added: Linear probing
    if self.args.linear_probing:

        def _get_token_prediction_layer(model):
            if model.config.model_type == "opt":
                return model.lm_head
            else:
                raise NotImplementedError(model.config.model_type)

        def _extract_features(model, *args, **kwargs):
            """Run one forward pass and return the (detached) input of the token-prediction layer."""
            captured = {}

            def _capture_input(module_, input_, output_):
                captured["features"] = input_[0].detach()

            handle = _get_token_prediction_layer(model).register_forward_hook(_capture_input)
            try:
                model.forward(*args, **kwargs)
            finally:
                # Always detach the hook: otherwise one more hook is stacked on the layer for every
                # batch, and all of them stay attached to the model after probing is done.
                handle.remove()
            return captured["features"]

        logger.info("Linear probing")
        logger.info("Starting to get features for training dataset")
        targets = []
        features = []
        with torch.inference_mode():
            for step, inputs in enumerate(tqdm(train_dataloader)):
                for k, v in inputs.items():
                    if isinstance(v, torch.Tensor):
                        inputs[k] = v.to(self.model.device)

                feature = _extract_features(self.model, **inputs)
                target = inputs["labels"]

                # Shift the target (bc it's autoregressive LM) and add the corresponding part
                assert not self.args.train_as_classification and self.args.only_train_option
                feature, target = feature[:, :-1], target[:, 1:]
                # Keep only the trailing `option_len` positions of every example
                for _i, _len in enumerate(inputs["option_len"]):
                    features.append(feature[_i, -_len:])
                    targets.append(target[_i, -_len:])

        logger.info("Finished getting features for training dataset")

        features = torch.cat(features, dim=0).cpu().numpy()
        targets = torch.cat(targets, dim=0).cpu().numpy()

        # Whether to use bias
        if self.model.config.model_type in ["opt", "gpt2"]:
            use_bias = False
        else:
            raise NotImplementedError

        # Set early stopping
        tol = 0.01 if self.args.lp_early_stopping else 1e-4 # 1e-4 is scipy default
        max_iter = 1000 if self.args.lp_early_stopping else 5000

        logger.info("Fitting logistic regression...")
        reg = LogisticRegressionCV(max_iter=max_iter, fit_intercept=use_bias, multi_class="multinomial", random_state=0, tol=tol, n_jobs=-1).fit(features, targets)
        logger.info("Done")

        logger.info("Assigning weights to model")
        decoder = _get_token_prediction_layer(self.model)
        coef_torch = torch.tensor(reg.coef_, device=decoder.weight.device, dtype=decoder.weight.dtype)
        if use_bias:
            bias_torch = torch.tensor(reg.intercept_, device=decoder.weight.device, dtype=decoder.weight.dtype)
        if coef_torch.shape[0] == 1: # The regressor only detects two classes
            assert len(reg.classes_) == 2
            coef_torch = torch.cat([-coef_torch / 2, coef_torch / 2], dim=0)
            if use_bias:
                bias_torch = torch.cat([-bias_torch / 2, bias_torch / 2], dim=0)

        # Only the rows of the token ids seen by the regressor are overwritten
        for _i, token_id in enumerate(reg.classes_):
            decoder.weight.data[token_id] = coef_torch[_i]
            if use_bias:
                decoder.bias.data[token_id] = bias_torch[_i]

        return None

    # Setting up training control variables:
    # number of training epochs: num_train_epochs
    # number of training steps per epoch: num_update_steps_per_epoch
    # total number of training steps to execute: max_steps
    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size

    len_dataloader = None
    if has_length(train_dataloader):
        len_dataloader = len(train_dataloader)
        num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        num_examples = self.num_examples(train_dataloader)
        if args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
            # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
            # the best we can do.
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)
            num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
    elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
        max_steps = args.max_steps
        # Setting a very large number of epochs so we go as many times as necessary over the iterator.
        num_train_epochs = sys.maxsize
        num_update_steps_per_epoch = max_steps
        num_examples = total_train_batch_size * args.max_steps
        num_train_samples = args.max_steps * total_train_batch_size
    else:
        raise ValueError(
            "args.max_steps must be set to a positive value if dataloader does not have a length, was"
            f" {args.max_steps}"
        )

    if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
        if self.args.n_gpu > 1:
            # nn.DataParallel(model) replicates the model, creating new variables and module
            # references registered here no longer work on other gpus, breaking the module
            raise ValueError(
                "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                " (torch.distributed.launch)."
            )
        else:
            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

    delay_optimizer_creation = (
        self.sharded_ddp is not None
        and self.sharded_ddp != ShardedDDPOption.SIMPLE
        or is_sagemaker_mp_enabled()
        or self.fsdp is not None
    )
    if args.deepspeed:
        deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
            self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
        )
        self.model = deepspeed_engine.module
        self.model_wrapped = deepspeed_engine
        self.deepspeed = deepspeed_engine
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
    elif not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    self.state = TrainerState()
    self.state.is_hyper_param_search = trial is not None

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        self.model.gradient_checkpointing_enable()

    model = self._wrap_model(self.model_wrapped)

    if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
        self._load_from_checkpoint(resume_from_checkpoint, model)

    # for the rest of this function `model` is the outside model, whether it was wrapped or not
    if model is not self.model:
        self.model_wrapped = model

    if delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    # Check if saved optimizer or scheduler states exist
    self._load_optimizer_and_scheduler(resume_from_checkpoint)

    # important: at this point:
    # self.model         is the Transformers Model
    # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

    # Train!
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {num_examples}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")
    logger.info(
        f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    self.state.epoch = 0
    start_time = time.time()
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    steps_trained_progress_bar = None

    # Check if continuing training from a checkpoint
    if resume_from_checkpoint is not None and os.path.isfile(
        os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    ):
        self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
        epochs_trained = self.state.global_step // num_update_steps_per_epoch
        if not args.ignore_data_skip:
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        else:
            steps_trained_in_current_epoch = 0

        logger.info(" Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f" Continuing training from epoch {epochs_trained}")
        logger.info(f" Continuing training from global step {self.state.global_step}")
        if not args.ignore_data_skip:
            logger.info(
                f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
                "flag to your launch command, but you will resume the training on data already seen by your model."
            )
            if self.is_local_process_zero() and not args.disable_tqdm:
                steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                steps_trained_progress_bar.set_description("Skipping the first batches")

    # Update the references
    self.callback_handler.model = self.model
    self.callback_handler.optimizer = self.optimizer
    self.callback_handler.lr_scheduler = self.lr_scheduler
    self.callback_handler.train_dataloader = train_dataloader
    if self.hp_name is not None and self._trial is not None:
        # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
        # parameter to Train when using DDP.
        self.state.trial_name = self.hp_name(self._trial)
    if trial is not None:
        assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
        self.state.trial_params = hp_params(assignments)
    else:
        self.state.trial_params = None
    # This should be the same if the state has been saved but in case the training arguments changed, it's safer
    # to set this after the load.
    self.state.max_steps = max_steps
    self.state.num_train_epochs = num_train_epochs
    self.state.is_local_process_zero = self.is_local_process_zero()
    self.state.is_world_process_zero = self.is_world_process_zero()

    # tr_loss is a tensor to avoid synchronization of TPUs through .item()
    tr_loss = torch.tensor(0.0).to(args.device)
    # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step
    model.zero_grad()

    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
    if not args.ignore_data_skip:
        for epoch in range(epochs_trained):
            is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                train_dataloader.sampler, RandomSampler
            )
            if is_torch_less_than_1_11 or not is_random_sampler:
                # We just need to begin an iteration to create the randomization of the sampler.
                # That was before PyTorch 1.11 however...
                for _ in train_dataloader:
                    break
            else:
                # Otherwise we need to call the whooooole sampler cause there is some random operation added
                # AT THE VERY END!
                _ = list(train_dataloader.sampler)

    for epoch in range(epochs_trained, num_train_epochs):
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)
        elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
            train_dataloader.dataset.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
            epoch_iterator = parallel_loader
        else:
            epoch_iterator = train_dataloader

        # Reset the past mems state at the beginning of each epoch if necessary.
        if args.past_index >= 0:
            self._past = None

        steps_in_epoch = (
            len(epoch_iterator)
            if len_dataloader is not None
            else args.max_steps * args.gradient_accumulation_steps
        )
        self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

        if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
            self._load_rng_state(resume_from_checkpoint)

        step = -1
        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                if steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.update(1)
                if steps_trained_in_current_epoch == 0:
                    self._load_rng_state(resume_from_checkpoint)
                continue
            elif steps_trained_progress_bar is not None:
                steps_trained_progress_bar.close()
                steps_trained_progress_bar = None

            if step % args.gradient_accumulation_steps == 0:
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

            # MeZO added: estimate gradient
            if args.trainer == "zo":
                tr_loss_step = self.zo_step(model, inputs)
            else:
                if (
                    ((step + 1) % args.gradient_accumulation_steps != 0)
                    and args.local_rank != -1
                    and args._no_sync_in_gradient_accumulation
                ):
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
                        tr_loss_step = self.training_step(model, inputs)
                else:
                    tr_loss_step = self.training_step(model, inputs)

            if (
                args.logging_nan_inf_filter
                and not is_torch_tpu_available()
                and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
            ):
                # if loss is nan or inf simply add the average of previous logged losses
                tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
            else:
                tr_loss += tr_loss_step

            self.current_flos += float(self.floating_point_ops(inputs))

            # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
            if self.deepspeed:
                self.deepspeed.step()

            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                steps_in_epoch <= args.gradient_accumulation_steps
                and (step + 1) == steps_in_epoch
            ):
                # MeZO added: update model with the estimated gradient
                if args.trainer == "zo":
                    self.zo_update(model)
                else:
                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping

                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    optimizer_was_run = True
                    if self.deepspeed:
                        pass  # called outside the loop
                    elif is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
                    else:
                        self.optimizer.step()

                    if optimizer_was_run and not self.deepspeed:
                        self.lr_scheduler.step()
                    model.zero_grad()

                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
            else:
                self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break
        if step < 0:
            logger.warning(
                "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                f" num_steps ({max_steps}) higher than the number of available samples."
            )
            self.control.should_training_stop = True

        self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
        self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            if is_torch_tpu_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            else:
                logger.warning(
                    "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                    "configured. Check your training configuration if this is unexpected."
                )
        if self.control.should_training_stop:
            break

    if args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        # Wait for everyone to get here so we are sure the model has been saved by process 0.
        if is_torch_tpu_available():
            xm.rendezvous("load_best_model_at_end")
        elif args.local_rank != -1:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()

        self._load_best_model()

    # add remaining tr_loss
    self._total_loss_scalar += tr_loss.item()
    train_loss = self._total_loss_scalar / self.state.global_step

    metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
    self.store_flos()
    metrics["total_flos"] = self.state.total_flos
    metrics["train_loss"] = train_loss

    self.is_in_train = False

    self._memory_tracker.stop_and_update_metrics(metrics)

    self.log(metrics)

    run_dir = self._get_output_dir(trial)
    checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

    # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint.
    if self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
        for checkpoint in checkpoints_sorted:
            if checkpoint != self.state.best_model_checkpoint:
                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                shutil.rmtree(checkpoint)

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)

    return TrainOutput(self.state.global_step, train_loss, metrics)
############## MeZO ##############
def zo_perturb_parameters(self, random_seed=None, scaling_factor=1):
    """
    Shift every parameter in `self.named_parameters_to_optim` in place along a random direction z.

    The global torch RNG is re-seeded first, so calling this again with the same seed regenerates
    exactly the same z without ever storing it.

    Input:
    - random_seed: seed used to draw z (when it is None, `self.zo_random_seed` is used)
    - scaling_factor: theta = theta + scaling_factor * z * eps, with eps = `self.args.zo_eps`
    """
    seed = self.zo_random_seed if random_seed is None else random_seed
    torch.manual_seed(seed)

    for _, weight in self.named_parameters_to_optim:
        current = weight.data
        # One standard-normal draw per parameter tensor, matching its shape, device and dtype
        noise = torch.normal(mean=0, std=1, size=current.size(), device=current.device, dtype=current.dtype)
        weight.data = current + scaling_factor * noise * self.args.zo_eps
def zo_forward(self, model, inputs):
    """
    Evaluate the loss without building a graph. The model is put in eval mode, so dropout is off too.

    When `self.args.non_diff` is set, the work is delegated to `self.zo_forward_nondiff`.
    Returns the detached loss (averaged when `self.args.n_gpu > 1`).
    """
    model.eval()

    if self.args.non_diff:
        # Non-differentiable objective (may require autoregressive generation)
        return self.zo_forward_nondiff(model, inputs)

    with torch.inference_mode():
        batch = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss_value = self.compute_loss(model, batch)
        multi_gpu = self.args.n_gpu > 1
        if multi_gpu:
            # Warning: this is copied from the original Huggingface Trainer. Untested.
            # mean() to average on multi-gpu parallel training
            loss_value = loss_value.mean()
    return loss_value.detach()
def zo_forward_nondiff(self, model, inputs):
    """
    Evaluate a non-differentiable objective: generate a continuation for each prompt, score it
    against `inputs['gold']` with `f1`, and return the negated mean F1 as a float32 tensor
    (so that a lower value is better, like a loss).
    """
    model.eval()
    assert self.args.task_name == "SQuAD", "Non differentiable objective only supports SQuAD for now."

    with torch.inference_mode():
        inputs = self._prepare_inputs(inputs)
        args = self.args
        prompt_ids = inputs["input_ids"]
        generated = self.model.generate(
            prompt_ids,
            do_sample=args.sampling,
            temperature=args.temperature,
            num_beams=args.num_beams,
            top_p=args.top_p,
            top_k=args.top_k,
            # Never generate past the overall length budget
            max_new_tokens=min(args.max_new_tokens, args.max_length - prompt_ids.size(1)),
            num_return_sequences=1,
            # Stop on either the task-specific EOS token or the tokenizer's own EOS
            eos_token_id=[
                self.tokenizer.encode(args.eos_token, add_special_tokens=False)[0],
                self.tokenizer.eos_token_id,
            ],
        )
        # Decode only what comes after the prompt
        output_text = [
            self.tokenizer.decode(generated[row][inputs["input_ids"].size(1):], skip_special_tokens=True).strip()
            for row in range(len(generated))
        ]
        f1s = [f1(prediction, inputs['gold'][idx]) for idx, prediction in enumerate(output_text)]

    return -torch.tensor(np.mean(f1s), dtype=torch.float32)
def zo_step(self, model, inputs):
    """
    Estimate gradient by MeZO. Return the loss from f(theta + z)

    Steps:
    1. collect the trainable parameters into `self.named_parameters_to_optim`;
    2. draw a fresh `self.zo_random_seed` that identifies the perturbation z;
    3. evaluate the loss at theta + eps*z and theta - eps*z;
    4. store the finite-difference estimate `(loss1 - loss2) / (2 * eps)` in `self.projected_grad`;
    5. move the parameters back to theta.

    Raises:
        AssertionError: if `gradient_accumulation_steps` is not 1 (accumulation is not supported).
    """
    args = self.args

    # No gradient accumulation support. Checked up front: failing after the perturbations below
    # would leave the weights shifted away from theta.
    assert args.gradient_accumulation_steps == 1, "MeZO does not support gradient accumulation"

    # What parameters to optimize
    self.named_parameters_to_optim = [
        (name, param) for name, param in model.named_parameters() if param.requires_grad
    ]

    # Sample the random seed for sampling z
    self.zo_random_seed = np.random.randint(1000000000)

    # First function evaluation, at theta + eps * z
    self.zo_perturb_parameters(scaling_factor=1)
    loss1 = self.zo_forward(model, inputs)

    # Second function evaluation, at theta - eps * z (hence the factor -2 from the current point)
    self.zo_perturb_parameters(scaling_factor=-2)
    loss2 = self.zo_forward(model, inputs)

    self.projected_grad = ((loss1 - loss2) / (2 * args.zo_eps)).item()

    # Reset model back to its parameters at start of step
    self.zo_perturb_parameters(scaling_factor=1)

    return loss1
def zo_update(self, model):
    """
    Update the parameters with the estimated gradients.

    The torch RNG is re-seeded with `self.zo_random_seed` so that the z drawn here, parameter by
    parameter, is the same one used for the perturbations. Each parameter then moves by
    `-lr * projected_grad * z`; weight decay (`args.weight_decay * theta`) is added for every
    parameter whose name does not contain "bias", "layer_norm" or "layernorm".
    Finally the learning-rate scheduler is advanced by one step.

    `model` is unused here; the parameters come from `self.named_parameters_to_optim`.
    """
    args = self.args
    # The learning rate is constant for the whole update, so look it up once instead of per tensor
    lr = self._get_learning_rate()
    no_decay_tags = ("bias", "layer_norm", "layernorm")

    # Reset the random seed for sampling zs
    torch.manual_seed(self.zo_random_seed)

    for name, param in self.named_parameters_to_optim:
        # Resample z
        z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
        if any(tag in name for tag in no_decay_tags):
            param.data = param.data - lr * (self.projected_grad * z)
        else:
            param.data = param.data - lr * (self.projected_grad * z + args.weight_decay * param.data)

    self.lr_scheduler.step()
############## Misc overload functions ##############
def _set_signature_columns_if_needed(self):
"""
We overload this function for non-differentiable objective training to pass "gold" -- the gold text for the task
"""
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
self._signature_columns += ["gold"]
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """
    Save the model to `output_dir` (defaults to `self.args.output_dir`), using the path that
    matches the active backend: TPU, SageMaker MP, sharded DDP / FSDP, DeepSpeed, or plain.

    We overload this function to fix an FSDP saving bug (before fix, it will likely cause OOM):
    in the sharded branch the state dict is gathered under a full-state-dict config created with
    `offload_to_cpu=True, rank0_only=True`.

    When `self.args.push_to_hub` is set and the call does not come from the trainer itself
    (`_internal_call` is False), the model is pushed to the Hub afterwards.
    """
    if output_dir is None:
        output_dir = self.args.output_dir

    if is_torch_tpu_available():
        self._save_tpu(output_dir)

    elif is_sagemaker_mp_enabled():
        # Calling the state_dict needs to be done on the wrapped model and on all processes.
        os.makedirs(output_dir, exist_ok=True)
        wrapped_weights = self.model_wrapped.state_dict()
        if self.args.should_save:
            self._save(output_dir, state_dict=wrapped_weights)
        if IS_SAGEMAKER_MP_POST_1_10:
            # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
            marker = Path(os.path.join(output_dir, "user_content.pt"))
            marker.touch()

    elif (
        ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
        or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
        or self.fsdp is not None
    ):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

        # Fix the FSDP loading bug: gather the full state dict under this explicit config
        gather_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, gather_config):
            full_weights = self.model.state_dict()

        if self.args.should_save:
            self._save(output_dir, state_dict=full_weights)

    elif self.deepspeed:
        # this takes care of everything as long as we aren't under zero3
        if self.args.should_save:
            self._save(output_dir)

        if is_deepspeed_zero3_enabled():
            # It's too complicated to try to override different places where the weights dump gets
            # saved, so since under zero3 the file is bogus, simply delete it. The user should
            # either use the deepspeed checkpoint to resume, or recover full weights with
            # zero_to_fp32.py stored in the checkpoint.
            if self.args.should_save:
                bogus_weights = os.path.join(output_dir, WEIGHTS_NAME)
                if os.path.isfile(bogus_weights):
                    os.remove(bogus_weights)

            # now save the real model if stage3_gather_16bit_weights_on_model_save=True
            # if false it will not be saved.
            # This must be called on all ranks
            saved_16bit = self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME)
            if not saved_16bit:
                logger.warning(
                    "deepspeed.save_16bit_model didn't save the model, since"
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                    " zero_to_fp32.py to recover weights"
                )
                self.deepspeed.save_checkpoint(output_dir)

    elif self.args.should_save:
        self._save(output_dir)

    # Push to the Hub when `save_model` is called by the user.
    wants_push = self.args.push_to_hub and not _internal_call
    if wants_push:
        self.push_to_hub(commit_message="Model save")