run_trainer_tpu.py
#!/usr/bin/env python3
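"""Run a training peer that computes gradients on TPUs and participates in hivemind collaborative training."""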

import time

import wandb
import torch
import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers import HfArgumentParser

import utils
from arguments import TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments
from lib.training.tpu import TPUManager
from callback import CollaborativeCallback
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger()

transformers.training_args.is_torch_tpu_available = lambda: False  # disable builtin TPU support to use custom code
torch.set_num_threads(min(torch.get_num_threads(), 4))  # otherwise, it becomes very slow on machines with ~100 CPUs


def main():
    parser = HfArgumentParser((TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(peer_args.initial_peers)} initial peers: {peer_args.initial_peers}")
    if len(peer_args.initial_peers) == 0:
        logger.warning("Please specify at least one network endpoint in initial peers.")

    utils.log_process_rank(trainer_args)
    task = TrainingTask(peer_args, trainer_args, collab_args)
    model = task.model

    # BEGIN init TPU
    assert trainer_args.do_train and not trainer_args.do_eval
    tpu_manager = TPUManager(model, dataset=task.training_dataset, collate_fn=task.data_collator,
                             grad_accumulation_steps=trainer_args.gradient_accumulation_steps,
                             batch_size_per_device=trainer_args.per_device_train_batch_size,
                             nprocs=trainer_args.n_tpus, start=True)
    model = task.model = tpu_manager._synchronizer.master_model

    # warm up TPUs with a few initial steps before the main training loop
    logger.info("Waiting for TPUs to warm up, this may take a minute...")
    tpu_manager.step()
    logger.info("Warmup step 1 / 3 done.")
    tpu_manager.update_model_parameters(model.parameters())
    tpu_manager.step()
    logger.info("Warmup step 2 / 3 done.")
    tpu_manager.step()
    tpu_manager.get_aggregated_gradients()
    tpu_manager.zero_grad()
    logger.info("Warmup step 3 / 3 done.")
    # END init TPU

    def push_params_onto_tpu():
        logger.info("Pushing new params onto TPU.")
        tpu_manager.update_model_parameters(model.parameters())
        tpu_manager.zero_grad()

    collaborative_optimizer = task.collaborative_optimizer
    collaborative_optimizer.callbacks.on_after_global_step.add(push_params_onto_tpu)
    collaborative_optimizer.callbacks.on_load_state_from_peers.add(push_params_onto_tpu)

    collaborative_training_callback = CollaborativeCallback(task, peer_args)

    state = transformers.TrainerState()
    control = transformers.TrainerControl()
    collaborative_training_callback.on_train_begin(trainer_args, state, control)
    tpu_manager.update_model_parameters(model.parameters())

    wandb.init(project=trainer_args.wandb_project, name=trainer_args.run_name)

    while True:
        start_time = time.perf_counter()

        loss, num_accumulated = tpu_manager.step()
        time_delta = time.perf_counter() - start_time
        logger.info(f"Accumulated {num_accumulated} gradients at {num_accumulated / time_delta:.3f} samples/second.")
        wandb.log({"train/loss": loss, "train/learning_rate": collaborative_optimizer.state_averager.scheduler.get_lr()[0]})

        # copy gradients aggregated across TPU cores into the master model before the collaborative step
        with torch.no_grad():
            for param, grad_from_tpu in zip(model.parameters(), tpu_manager.get_aggregated_gradients()):
                param.grad[...] = grad_from_tpu

        # run the collaborative optimizer step (the global update shared with other peers)
        collaborative_optimizer.step()

        state.log_history.append(dict(loss=loss))
        collaborative_training_callback.on_step_end(trainer_args, state, control)


if __name__ == "__main__":
    main()
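

# Example invocation (illustrative only; assumes the dataclasses in arguments.py expose these fields
# as CLI flags via HfArgumentParser -- exact flag names, defaults, and required values may differ):
#
#   python run_trainer_tpu.py \
#       --initial_peers /ip4/1.2.3.4/tcp/31337/p2p/Qm... \
#       --do_train \
#       --n_tpus 8 \
#       --per_device_train_batch_size 1 \
#       --gradient_accumulation_steps 1 \
#       --run_name my-tpu-peer \
#       --wandb_project my-project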