-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtask.py
181 lines (154 loc) · 7.43 KB
/
task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path
import hivemind
import torch
import torch.nn as nn
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.lamb_8bit import CPULAMB8Bit
# Module-level logger for this file, obtained through hivemind's logging helper.
logger = hivemind.get_logger(__name__)
class VQGanParams(VQGanVAE):
    """Hyperparameter-only stand-in for ``VQGanVAE``.

    Only ``nn.Module.__init__`` is invoked; ``VQGanVAE.__init__`` is skipped on
    purpose, so the instance carries just the four attributes set below.
    NOTE(review): presumably this avoids building/loading the real VQGAN when
    only its configuration is needed by ``DALLE`` -- confirm against dalle_pytorch.

    Args:
        num_layers: stored as ``self.num_layers``.
        image_size: stored as ``self.image_size``.
        num_tokens: stored as ``self.num_tokens``.
        is_gumbel: stored as ``self.is_gumbel``.
    """

    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        # Deliberately bypass the parent initializer; set up the bare Module only.
        nn.Module.__init__(self)
        self.num_layers, self.image_size = num_layers, image_size
        self.num_tokens, self.is_gumbel = num_tokens, is_gumbel
class ModelWrapper(nn.Module):
    """Adapter that renames batch fields for the wrapped model and returns its loss in a dict.

    The wrapped model is stored as ``self.model`` and is called with the keyword
    arguments ``text``, ``image``, ``mask`` and ``return_loss=True``.
    """

    def __init__(self, model):
        """Store ``model`` as the ``model`` attribute of this module."""
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        """Run the wrapped model and return ``{'loss': <its return value>}``.

        Args:
            input_ids: forwarded to the wrapped model as ``text``.
            attention_mask: forwarded to the wrapped model as ``mask``.
            image: forwarded to the wrapped model as ``image``.

        Returns:
            A dict with the single key ``'loss'``.
        """
        # Invoke the module instance rather than ``.forward`` directly, so that
        # any hooks registered on the wrapped module are executed as well.
        loss = self.model(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}
class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities.

    The tokenizer and model are built eagerly in ``__init__``. The authorizer,
    DHT, collaborative optimizer and training dataset are created on first
    access of the corresponding property and cached on the instance.
    """

    # Lazily populated caches backing the properties of the same (un-prefixed) name.
    _authorizer = _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments
    ):
        """Store the argument groups, build the tokenizer and model, and restore the latest checkpoint if any.

        Args:
            peer_args: peer settings; read here for ``experiment_prefix``, ``tokenizer_path`` and ``authorize``.
            trainer_args: trainer settings; read here for ``seed``, ``text_seq_length`` and ``output_dir``.
                Its ``run_name`` is overwritten with the authorized username when authorization is enabled.
            collab_args: stored as-is; later expanded into ``hivemind.Optimizer`` keyword arguments.
        """
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args

        # ``self.authorizer`` is None when ``peer_args.authorize`` is false; in that
        # case keep whatever run_name the trainer arguments already carry.
        authorizer = self.authorizer
        if authorizer is not None:
            self.trainer_args.run_name = authorizer.username  # For wandb

        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization

        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Creating model")
        depth = 64
        # All layers except the last repeat a 4-element pattern; the last layer is special-cased.
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')
        dalle = DALLE(
            vae=VQGanParams(),
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=True,
            reversible=True,
            share_input_output_emb=True,
        )
        num_trainable = sum(param.numel() for param in dalle.parameters() if param.requires_grad)
        logger.info(f"Trainable parameters: {num_trainable}")
        self.model = ModelWrapper(dalle)

        output_dir = Path(trainer_args.output_dir)
        # Glob once and reuse the result for both the log line and the selection below.
        checkpoints = list(output_dir.glob("checkpoint*"))
        logger.info(f'Checkpoint dir {output_dir}, contents {checkpoints}')
        latest_checkpoint_dir = max(checkpoints, default=None, key=os.path.getctime)
        if latest_checkpoint_dir is not None:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            # map_location="cpu" lets a checkpoint written from GPU tensors be restored on a
            # peer without that device; load_state_dict copies values into the existing parameters.
            # NOTE(review): torch.load unpickles the file -- only point output_dir at trusted checkpoints.
            state_dict = torch.load(latest_checkpoint_dir / "model_state.pt", map_location="cpu")
            self.model.load_state_dict(state_dict)

    @property
    def authorizer(self):
        """Return the cached authorizer, creating it on first access if ``peer_args.authorize`` is set.

        Returns None when authorization is disabled.
        """
        if self._authorizer is None and self.peer_args.authorize:
            self._authorizer = authorize_with_huggingface()
        return self._authorizer

    @property
    def dht(self):
        """Return the cached ``hivemind.DHT``, starting it from ``peer_args`` on first access."""
        if self._dht is None:
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=self.authorizer,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht

    @property
    def collaborative_optimizer(self):
        """Return the cached ``hivemind.Optimizer``, building it on first access.

        The same size-adaptive compression object is used for both gradient and
        state averaging: Float16 below the threshold, 8-bit uniform quantization otherwise.
        """
        if self._collaborative_optimizer is None:
            params, opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            self._collaborative_optimizer = hivemind.Optimizer(
                dht=self.dht, run_id=self.peer_args.experiment_prefix,
                params=params, optimizer=opt, scheduler=scheduler,
                offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
                client_mode=self.peer_args.client_mode, verbose=True,
                **asdict(self.collab_args))
        return self._collaborative_optimizer

    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        """Build parameter groups plus optimizer and scheduler factories.

        Args:
            training_args: supplies the learning rate, betas, epsilon, weight decay,
                gradient-norm/clamp limits and warmup/total step counts.

        Returns:
            A tuple ``(params, opt, scheduler)`` where ``params`` is a list of two
            parameter groups (with and without weight decay), ``opt`` is a callable
            taking parameters and returning a ``CPULAMB8Bit`` optimizer, and
            ``scheduler`` is a callable taking an optimizer and returning a linear
            warmup schedule.
        """
        no_decay = ["bias", "LayerNorm.weight"]

        # Single pass over the trainable parameters, split by whether the name
        # contains one of the no-decay markers.
        decayed, undecayed = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if any(marker in name for marker in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)

        params = [
            {"params": decayed, "weight_decay": training_args.weight_decay},
            {"params": undecayed, "weight_decay": 0.0},
        ]

        def opt(params):
            return CPULAMB8Bit(
                params,
                lr=training_args.learning_rate,
                betas=(training_args.adam_beta1, training_args.adam_beta2),
                eps=training_args.adam_epsilon,
                weight_decay=training_args.weight_decay,
                max_grad_norm=training_args.max_grad_norm,
                clamp_value=training_args.clamp_value,
                reuse_grad_buffers=True,
            )

        def scheduler(opt):
            return get_linear_schedule_with_warmup(
                opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
            )

        return params, opt, scheduler

    @property
    def training_dataset(self):
        """Return the cached training dataset, creating it via ``make_dataset`` on first access."""
        if self._training_dataset is None:
            # NOTE(review): if local_public_key is bytes/str, hash() is salted per process
            # (PYTHONHASHSEED), so this seed would differ between restarts of the same peer --
            # confirm whether that is intended.
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset

    @property
    def data_collator(self):
        """Return a new collator that pads every batch to ``trainer_args.text_seq_length``."""
        return DataCollatorWithPadding(
            tokenizer=self.tokenizer, padding='max_length', max_length=self.trainer_args.text_seq_length)