From 46566f7a33cd86812e5b52c7ef6f9e78d6ad0bf7 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Tue, 21 Jan 2025 09:34:27 +0800 Subject: [PATCH 01/10] all bfloat16 training --- train_gpt.py | 299 ++++++++++++++++++++++++++------------------------- 1 file changed, 151 insertions(+), 148 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 6b7d46f8..0e7ae01f 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -467,153 +467,156 @@ class Hyperparameters: # implementation seq_len = 64*1024 # FlexAttention sequence length save_checkpoint = False -args = Hyperparameters() - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. - -# begin logging -logfile = None -if master_process: - run_id = uuid.uuid4() - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): + +def train(args: Hyperparameters): + # torchrun sets these env variables + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + assert torch.cuda.is_available() + device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) + torch.cuda.set_device(device) + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = (rank == 0) # this process will do logging, checkpointing etc. + + # begin logging + logfile = None if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -# load data -train_loader = distributed_data_generator(args.train_files, args.batch_size, rank, world_size) - -model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim >= 2] -embed_params = [model.embed.weight, *model.value_embeds.parameters()] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -adam_params = [dict(params=head_params, lr=0.008), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)] -# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True, eps=1e-10) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size) -optimizers = [optimizer1, optimizer2] - -# learning rate schedule: stable then decay -def get_lr(it: int): - t = 1 - it / args.num_iterations # time remaining in training - assert 1 >= t >= 0 - w = min(t / args.cooldown_frac, 1.0) # 1 -> 0 - return w * 1.0 + (1 - w) * 0.1 -schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] -@lru_cache(1) -def sw_num_blks(window_size: int): - return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - -model: nn.Module = torch.compile(model) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - # This effectively ignores timing first 10 steps, which are slower for weird reasons. - # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 - # steps with dummy data first, and then re-initialize the model and reset the loader. - if step == 10: - training_time_ms = 0 - t0 = time.perf_counter() - timed_steps = float("nan") if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val - - # Linearly increase the block-wise sliding window size over training 128 -> 1792: - # increase by @fernbear.bsky.social; block-wise by @YouJiacheng - window_size = next_multiple_of_n(1728 * step / train_steps, n=128) - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - val_bs = world_size * args.seq_len - assert args.val_tokens % val_bs == 0 - val_steps = args.val_tokens // val_bs - val_loader = distributed_data_generator(args.val_files, val_bs, rank, world_size) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - x, y = next(val_loader) - val_loss += model(x, y, sw_num_blks(window_size)) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION BEGIN ----------------- - inputs, targets = next(train_loader) - for input_seq, target_seq in zip(inputs.split(args.seq_len), targets.split(args.seq_len)): - model(input_seq, target_seq, sw_num_blks(window_size)).backward() + run_id = uuid.uuid4() + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) + def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + 
print(s) + print(s, file=f) + + # begin by printing this file (the Python code) + print0(code) + print0("="*100) + # log information about the hardware/software environment this is running on + print0(f"Running Python {sys.version}") + print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") + def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout + print0(nvidia_smi()) + print0("="*100) + + # load data + train_loader = distributed_data_generator(args.train_files, args.batch_size, rank, world_size) + + model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768).cuda().bfloat16() for param in model.parameters(): - dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) - # momentum warmup for Muon - frac = min(step / 300, 1) - for group in optimizer2.param_groups: - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers and schedulers - for opt, sched in zip(optimizers, schedulers): - opt.step() - sched.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_time = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms", console=True) - -print0( - f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" -) -dist.destroy_process_group() + dist.broadcast(param.detach(), 0) + + # collect the parameters to optimize + hidden_matrix_params = [p for p in model.blocks.parameters() if p.ndim >= 2] + embed_params = [model.embed.weight, *model.value_embeds.parameters()] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + head_params = [model.lm_head.weight] + + # init the optimizer(s) + adam_params = [dict(params=head_params, lr=0.008), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)] + # small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence + # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 + optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), fused=True, eps=1e-10) + optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size) + optimizers = [optimizer1, optimizer2] + + # learning rate schedule: stable then decay + def get_lr(it: int): + t = 1 - it / args.num_iterations # time remaining in training + assert 1 >= t >= 0 + w = min(t / args.cooldown_frac, 1.0) # 1 -> 0 + return w * 1.0 + (1 - w) * 0.1 + schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] + @lru_cache(1) + def sw_num_blks(window_size: int): + return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + model: nn.Module = torch.compile(model) + training_time_ms = 0 + # start the clock + torch.cuda.synchronize() + t0 = time.perf_counter() + # begin training + train_steps = args.num_iterations + for step in range(train_steps + 1): + last_step = (step == train_steps) + # This effectively ignores timing first 10 steps, which are slower for weird reasons. + # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 + # steps with dummy data first, and then re-initialize the model and reset the loader. 
+ if step == 10: + training_time_ms = 0 + t0 = time.perf_counter() + timed_steps = float("nan") if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val + + # Linearly increase the block-wise sliding window size over training 128 -> 1792: + # increase by @fernbear.bsky.social; block-wise by @YouJiacheng + window_size = next_multiple_of_n(1728 * step / train_steps, n=128) + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + val_bs = world_size * args.seq_len + assert args.val_tokens % val_bs == 0 + val_steps = args.val_tokens // val_bs + val_loader = distributed_data_generator(args.val_files, val_bs, rank, world_size) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + x, y = next(val_loader) + val_loss += model(x, y, sw_num_blks(window_size)) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION BEGIN ----------------- + inputs, targets = next(train_loader) + for input_seq, target_seq in zip(inputs.split(args.seq_len), targets.split(args.seq_len)): + loss = model(input_seq, target_seq, sw_num_blks(window_size)) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss.backward() + for param in model.parameters(): + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) + # momentum warmup for Muon + frac = min(step / 300, 1) + for group in optimizer2.param_groups: + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers and schedulers + for opt, sched in zip(optimizers, schedulers): + opt.step() + sched.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_time = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms", console=True) + + print0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + dist.destroy_process_group() + +if __name__ == "__main__": + args = Hyperparameters() + train(args) From 8099ee3242a56da3b776f487ebd91bf5333d1cc3 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Wed, 22 Jan 2025 06:44:24 +0800 Subject: [PATCH 02/10] bfloat16 in custom op --- train_gpt.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 0e7ae01f..ee0a0555 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -32,8 +32,8 @@ def impl(x: Tensor, w: Tensor): x_f8, w_f8.t(), out_dtype=torch.bfloat16, - scale_a=x.new_tensor(1 / x_s, dtype=torch.float32), - scale_b=x.new_tensor(1 / w_s, dtype=torch.float32), + scale_a=x.new_tensor(1 / x_s, 
dtype=torch.bfloat16), + scale_b=x.new_tensor(1 / w_s, dtype=torch.bfloat16), use_fast_accum=True, ) return out, x_f8, w_f8 @@ -53,9 +53,9 @@ def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float @torch.compile def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): assert grad.is_contiguous() - x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.float32) + x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.bfloat16) + w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.bfloat16) + grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.bfloat16) grad_f8 = grad.mul(grad_s).to(torch.float8_e5m2) grad_x = torch._scaled_mm( grad_f8, @@ -69,7 +69,7 @@ def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): grad_w = torch._scaled_mm( x_f8.t().contiguous(), grad_f8.t().contiguous().t(), - out_dtype=torch.float32, + out_dtype=torch.bfloat16, scale_a=x_inv_s, scale_b=grad_inv_s, use_fast_accum=False, @@ -80,7 +80,7 @@ def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): @mm_backward_op.register_fake def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.to(torch.float32) + return x_f8.to(torch.bfloat16), w_f8.to(torch.bfloat16) def backward(ctx, grad_out: Tensor, *_): x_f8, w_f8 = ctx.saved_tensors @@ -249,7 +249,7 @@ def __init__(self, dim: int, max_seq_len=65536): def forward(self, x_BTHD: Tensor): assert self.cos.size(0) >= x_BTHD.size(-3) cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + x1, x2 = x_BTHD.to(dtype=torch.bfloat16).chunk(2, dim=-1) y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat((y1, y2), 3).type_as(x_BTHD) From 34c6cf6b6ac61011f3c2495b14dd25f9576a3ff0 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Wed, 22 Jan 2025 07:38:08 +0800 Subject: [PATCH 03/10] bfloat16 in custom op --- train_gpt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index ee0a0555..4949646e 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -32,8 +32,8 @@ def impl(x: Tensor, w: Tensor): x_f8, w_f8.t(), out_dtype=torch.bfloat16, - scale_a=x.new_tensor(1 / x_s, dtype=torch.bfloat16), - scale_b=x.new_tensor(1 / w_s, dtype=torch.bfloat16), + scale_a=x.new_tensor(1 / x_s, dtype=torch.float32), + scale_b=x.new_tensor(1 / w_s, dtype=torch.float32), use_fast_accum=True, ) return out, x_f8, w_f8 @@ -53,9 +53,9 @@ def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float @torch.compile def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): assert grad.is_contiguous() - x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.bfloat16) - w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.bfloat16) - grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.bfloat16) + x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.float32) grad_f8 = grad.mul(grad_s).to(torch.float8_e5m2) grad_x = torch._scaled_mm( grad_f8, From 3f96031ed6578bcbd5b25c5ed586d73dc02ee189 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Wed, 22 Jan 2025 07:46:52 +0800 Subject: [PATCH 04/10] increase num iterations --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index 4949646e..e6fff3c8 100644 --- a/train_gpt.py +++ 
b/train_gpt.py @@ -460,7 +460,7 @@ class Hyperparameters: val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons # optimization batch_size = 8*64*1024 # batch size in tokens - num_iterations = 1393 # number of iterations to run + num_iterations = 1395 # number of iterations to run cooldown_frac = 0.4 # fraction of training spent cooling down the learning rate # evaluation and logging val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end From a3c83bae0f8c44f50eee6eed4b6d0ad7bdc9cbe6 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Thu, 23 Jan 2025 16:16:12 +0800 Subject: [PATCH 05/10] increase num_iterations & prepare for fp8 ops on other layers --- train_gpt.py | 71 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e6fff3c8..f1b9a19d 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -99,11 +99,6 @@ def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): mm_op.register_autograd(backward, setup_context=setup_context) -def lm_head_fp8(x: Tensor, w: Tensor) -> Tensor: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, w, x_s=2.0, w_s=32.0, grad_s=2.0**29)[0] - return out.reshape(*x.shape[:-1], -1) - # ----------------------------------------------------------------------------- # Muon optimizer @@ -219,12 +214,16 @@ def update_prev(): # optimized Muon implementation contributed by @YouJiacheng # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the model -def norm(x): +def norm(x: Tensor): return F.rms_norm(x, (x.size(-1),)) class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int): + def __init__(self, in_features: int, out_features: int, use_fp8: bool = False, x_scale: float = 1.0, w_scale: float = 1.0, grad_scale: float = 1.0): super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_scale = x_scale + self.w_scale = w_scale + self.grad_scale = grad_scale def reset_parameters(self) -> None: std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) @@ -232,11 +231,16 @@ def reset_parameters(self) -> None: with torch.no_grad(): self.weight.uniform_(-bound, bound) - def forward(self, x): - return F.linear(x, self.weight.type_as(x)) + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_scale, w_s=self.w_scale, grad_s=self.grad_scale)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len=65536): + def __init__(self, dim: int, max_seq_len: int): super().__init__() # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) @@ -255,7 +259,7 @@ def forward(self, x_BTHD: Tensor): return torch.cat((y1, y2), 3).type_as(x_BTHD) class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, layer_idx: int): + def __init__(self, dim: int, num_heads: int, max_seq_len: int): super().__init__() assert dim % num_heads == 0 self.num_heads = num_heads @@ -265,7 +269,7 @@ def __init__(self, dim: int, num_heads: int, layer_idx: int): # https://x.com/hi_tysam/status/1879699187107033311 self.qkv_w = 
nn.Parameter(torch.empty(3, dim, dim).uniform_(-bound, bound)) self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5])) - self.rotary = Rotary(dim // num_heads) # dim // num_heads = head_dim + self.rotary = Rotary(dim // num_heads, max_seq_len) self.c_proj = CastedLinear(dim, dim) self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977 # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun @@ -288,27 +292,27 @@ def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask): return y class MLP(nn.Module): - def __init__(self, dim): + def __init__(self, dim: int): super().__init__() self.c_fc = CastedLinear(dim, 4 * dim) self.c_proj = CastedLinear(4 * dim, dim) self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977 - def forward(self, x): + def forward(self, x: Tensor): x = self.c_fc(x) x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 x = self.c_proj(x) return x class Block(nn.Module): - def __init__(self, model_dim: int, num_heads: int, layer_idx: int): + def __init__(self, model_dim: int, num_heads: int, layer_idx: int, max_seq_len: int): super().__init__() # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(model_dim, num_heads, layer_idx) if layer_idx != 7 else None + self.attn = CausalSelfAttention(model_dim, num_heads, max_seq_len) if layer_idx != 7 else None self.mlp = MLP(model_dim) self.lambdas = nn.Parameter(torch.tensor([1., 0.])) - def forward(self, x, ve, x0, block_mask): + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask): x = self.lambdas[0] * x + self.lambdas[1] * x0 if self.attn is not None: x = x + self.attn(norm(x), ve, block_mask) @@ -316,14 +320,16 @@ def forward(self, x, ve, x0, block_mask): return x class ValueEmbedding(nn.Module): - def __init__(self, num_embeddings: int, embedding_dim: int): + def __init__(self, vocab_size: int, embedding_dim: int, num_layers: int, num_embeddings: int = 3): super().__init__() - self.embed = nn.ModuleList([nn.Embedding(num_embeddings, embedding_dim) for _ in range(3)]) + self.num_layers = num_layers + self.num_embeddings = num_embeddings + self.embed = nn.ModuleList([nn.Embedding(vocab_size, embedding_dim) for _ in range(num_embeddings)]) - def forward(self, input_seq) -> list[Tensor | None]: - ve = [emb(input_seq) for emb in self.embed] + def forward(self, x: Tensor) -> list[Tensor | None]: + ve = [emb(x) for emb in self.embed] # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, ve[0], ve[1], ve[2]] + ve = [ve[0], ve[1], ve[2]] + [None] * (self.num_layers - 2 * self.num_embeddings) + [ve[0], ve[1], ve[2]] return ve # ----------------------------------------------------------------------------- @@ -333,12 +339,13 @@ def next_multiple_of_n(v: float | int, *, n: int): return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, enable_lm_head_fp8: bool = True): super().__init__() + self.enable_lm_head_fp8 = enable_lm_head_fp8 self.embed = nn.Embedding(vocab_size, model_dim) # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - self.value_embeds = ValueEmbedding(vocab_size, model_dim) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, layer_idx) for layer_idx in range(num_layers)]) + self.value_embeds = ValueEmbedding(vocab_size, model_dim, num_layers) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, layer_idx, max_seq_len) for layer_idx in range(num_layers)]) # U-net design by @brendanh0gan self.num_encoder_layers = num_layers // 2 # Half of the layers for encoder self.num_decoder_layers = num_layers - self.num_encoder_layers # Remaining for decoder @@ -346,7 +353,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers)) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128)) + self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_scale=2.0, w_scale=2.0**5, grad_scale=2.0**29) self.lm_head.weight.detach().zero_() # @Grad62304977 def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): @@ -404,16 +411,18 @@ def build_bm(sw_num_blocks: Tensor) -> BlockMask: skip_connections = [] # Encoder pass - process only the first half of the blocks block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm] + assert len(block_masks) == self.num_encoder_layers for i in range(self.num_encoder_layers): x = self.blocks[i](x, ve_enc[i], x0, block_masks[i]) skip_connections.append(x) # Decoder pass - process the remaining blocks with weighted skip connections block_masks.reverse() + assert len(block_masks) == self.num_decoder_layers for i in range(self.num_decoder_layers): x = x + self.skip_weights[i] * skip_connections.pop() x = self.blocks[self.num_encoder_layers + i](x, ve_dec[i], x0, block_masks[i]) x = norm(x) - logits = lm_head_fp8(x, self.lm_head.weight) if self.training else self.lm_head(x) + logits = self.lm_head(x) # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) logits = 30 * torch.sigmoid(logits.float() / 7.5) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq) @@ -460,13 +469,15 @@ class Hyperparameters: val_tokens = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons # optimization batch_size = 8*64*1024 # batch size in tokens - num_iterations = 1395 # number of iterations to run + num_iterations = 1400 # number of iterations to run cooldown_frac = 0.4 # fraction of training spent cooling down the learning rate # evaluation and logging val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end # implementation seq_len = 64*1024 # FlexAttention sequence length + val_seq_len = 64*1024 # FlexAttention sequence length for validation save_checkpoint = False + enable_lm_head_fp8 = True def train(args: Hyperparameters): # torchrun sets these env variables @@ -508,7 +519,7 @@ def nvidia_smi(): # load data train_loader = distributed_data_generator(args.train_files, args.batch_size, rank, world_size) - model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768).cuda().bfloat16() + model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=args.seq_len, enable_lm_head_fp8=args.enable_lm_head_fp8).cuda().bfloat16() for param in model.parameters(): dist.broadcast(param.detach(), 0) @@ -563,7 +574,7 @@ def sw_num_blks(window_size: int): torch.cuda.synchronize() training_time_ms += 1000 * (time.perf_counter() - t0) model.eval() - val_bs = world_size * args.seq_len + val_bs = world_size * args.val_seq_len assert args.val_tokens % val_bs == 0 val_steps = args.val_tokens // val_bs val_loader = distributed_data_generator(args.val_files, val_bs, rank, world_size) From d9303252b6760aa824e395f9d08e22f27d924cba Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Thu, 23 Jan 2025 16:45:50 +0800 Subject: [PATCH 06/10] . --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index f1b9a19d..71fac9d3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -470,7 +470,7 @@ class Hyperparameters: # optimization batch_size = 8*64*1024 # batch size in tokens num_iterations = 1400 # number of iterations to run - cooldown_frac = 0.4 # fraction of training spent cooling down the learning rate + cooldown_frac = 0.35 # fraction of training spent cooling down the learning rate # evaluation and logging val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end # implementation From 58c99314d05cbbe42e1f5d436e2e3afc04b4839a Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Thu, 23 Jan 2025 16:51:41 +0800 Subject: [PATCH 07/10] . --- train_gpt.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 71fac9d3..babaab4d 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -339,9 +339,8 @@ def next_multiple_of_n(v: float | int, *, n: int): return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, enable_lm_head_fp8: bool = True): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): super().__init__() - self.enable_lm_head_fp8 = enable_lm_head_fp8 self.embed = nn.Embedding(vocab_size, model_dim) # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 self.value_embeds = ValueEmbedding(vocab_size, model_dim, num_layers) @@ -469,15 +468,14 @@ class Hyperparameters: val_tokens = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons # optimization batch_size = 8*64*1024 # batch size in tokens - num_iterations = 1400 # number of iterations to run - cooldown_frac = 0.35 # fraction of training spent cooling down the learning rate + num_iterations = 1405 # number of iterations to run + cooldown_frac = 0.4 # fraction of training spent cooling down the learning rate # evaluation and logging val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end # implementation seq_len = 64*1024 # FlexAttention sequence length val_seq_len = 64*1024 # FlexAttention sequence length for validation save_checkpoint = False - enable_lm_head_fp8 = True def train(args: Hyperparameters): # torchrun sets these env variables @@ -519,7 +517,7 @@ def nvidia_smi(): # load data train_loader = distributed_data_generator(args.train_files, args.batch_size, rank, world_size) - model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=args.seq_len, enable_lm_head_fp8=args.enable_lm_head_fp8).cuda().bfloat16() + model = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=args.seq_len).cuda().bfloat16() for param in model.parameters(): dist.broadcast(param.detach(), 0) From b77c1233182b70361b859857a406b8bef333a73e Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Thu, 23 Jan 2025 16:53:08 +0800 Subject: [PATCH 08/10] . --- train_gpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index babaab4d..5a2bb836 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -622,7 +622,8 @@ def sw_num_blks(window_size: int): print0( f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", + console=True, ) dist.destroy_process_group() From 63af732095700d5e7c2b5511adafc43bdb92362a Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Thu, 23 Jan 2025 19:36:44 +0800 Subject: [PATCH 09/10] adjust fp8 scales --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index 5a2bb836..e82a4829 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -352,7 +352,7 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers)) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
- self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_scale=2.0, w_scale=2.0**5, grad_scale=2.0**29) + self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_scale=2.0, w_scale=2.0, grad_scale=2.0**9) self.lm_head.weight.detach().zero_() # @Grad62304977 def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): From 7904ab102282892e6cb5126c0da2f186a15bd476 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Thu, 23 Jan 2025 19:38:32 +0800 Subject: [PATCH 10/10] fix Rotary dtype computation --- train_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt.py b/train_gpt.py index e82a4829..ba9ff7c9 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -253,7 +253,7 @@ def __init__(self, dim: int, max_seq_len: int): def forward(self, x_BTHD: Tensor): assert self.cos.size(0) >= x_BTHD.size(-3) cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.bfloat16).chunk(2, dim=-1) + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat((y1, y2), 3).type_as(x_BTHD)
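Note on the fp8 matmul path that patches 02, 03, 05, and 09 keep adjusting: below is a minimal, standalone sketch of the forward half of the nanogpt.mm custom op that CastedLinear(use_fp8=True) routes through. It is illustrative only, not part of the patches. Assumptions: PyTorch >= 2.4 on an fp8-capable GPU (e.g. H100); torch._scaled_mm is a private API whose signature has shifted across releases; and the e4m3 quantization step is reconstructed from the visible scale handling rather than copied from the op itself. The default scales mirror the forward scales after patch 09 (x_s=2.0, w_s=2.0), and the inverse scales stay in float32, the dtype patch 03 reverts to.

import torch

def fp8_linear_forward(x: torch.Tensor, w: torch.Tensor, x_s: float = 2.0, w_s: float = 2.0) -> torch.Tensor:
    # x: (tokens, in_features) bf16 activations; w: (out_features, in_features) bf16 weight
    x_f8 = x.mul(x_s).to(torch.float8_e4m3fn)  # scale up, then quantize to e4m3 (assumed to match the op)
    w_f8 = w.mul(w_s).to(torch.float8_e4m3fn)
    return torch._scaled_mm(
        x_f8,
        w_f8.t(),  # second operand must be column-major
        out_dtype=torch.bfloat16,
        scale_a=x.new_tensor(1 / x_s, dtype=torch.float32),  # undo the scaling in the epilogue; scales must be fp32 tensors
        scale_b=x.new_tensor(1 / w_s, dtype=torch.float32),
        use_fast_accum=True,
    )

# Hypothetical smoke test; 50304 is the padded lm_head width (next multiple of 128 above the 50257-token vocab).
if __name__ == "__main__":
    x = torch.randn(8 * 1024, 768, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(50304, 768, device="cuda", dtype=torch.bfloat16)
    print(fp8_linear_forward(x, w).shape)  # torch.Size([8192, 50304])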