Commit

Merge pull request karpathy#243 from gkielian/add_progress_bar
Add progress bar to train.py
klei22 authored Aug 25, 2024
2 parents c46fe8e + 270d0af commit 291eb55
Showing 1 changed file: train.py (136 additions, 126 deletions)
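The change follows rich's task-based progress API: construct a Progress, enter it as a context manager, register a single task sized to the number of iterations left, and advance that task once per pass through the training loop. A minimal standalone sketch of that pattern, where the loop body, max_iters, and iter_num are placeholders rather than the actual train.py step:

    import time
    from rich.progress import Progress

    max_iters = 100   # placeholder for args.max_iters
    iter_num = 0      # placeholder for the trainer's iteration counter

    progress = Progress()
    with progress:
        task_id = progress.add_task("[green]Training...", total=max_iters - iter_num)
        while iter_num < max_iters:
            time.sleep(0.01)                     # stand-in for one training step
            iter_num += 1
            progress.update(task_id, advance=1)  # advance the bar by one iteration

With no arguments, Progress() renders rich's default columns (task description, bar, completion percentage, and estimated time remaining), which is what the diff below relies on.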
@@ -10,6 +10,9 @@
import time

from model_info_util.model_info import print_summary, print_module_structure, print_model_blocks, print_model_tree

from rich.progress import Progress

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
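This first hunk only adds the rich import; rich is a third-party package (published on PyPI as rich), so it must be available in the training environment if it is not already among the project's dependencies.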
@@ -778,154 +781,161 @@ def train(self):
            for head in range(self.args.n_head):
                graph_y_labels.append(f"Layer {layer} Head {head}")

        # Create progress bar
        progress = Progress()
        with progress:
            task_id = progress.add_task("[green]Training...", total=(self.args.max_iters - self.iter_num))
            while True:
                lr = self.get_lr(self.iter_num) if self.args.decay_lr else self.args.learning_rate
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr

                if self.iter_num % self.args.eval_interval == 0 and self.master_process:
                    losses = self.estimate_loss()
                    print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
                    self.log_metrics(losses, lr, running_mfu, self.iter_num)

                    if math.isnan(losses["val"]):
                        checkpoint = {
                            'model': self.raw_model.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'model_args': self.model_args,
                            'iter_num': self.iter_num,
                            'best_val_loss': self.best_val_loss,
                            'nan_iter_num' : 0,
                            'nan' : True,
                            'config': vars(self.args),
                        }
                        torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
                    if losses['val'] < self.best_val_loss or self.args.always_save_checkpoint:
                        if losses['val'] < self.best_val_loss:
                            self.iter_num_best_val_loss = self.iter_num
                            self.best_val_loss = losses['val']
                            # Save best validation loss
                            with open(os.path.join(self.args.out_dir, 'best_val_loss_and_iter.txt'), "w") as best_loss_file:
                                best_loss_file.write(str(self.best_val_loss.item())+","+str(self.iter_num))
                            # Reset early exit counter
                            num_steps_with_worse_loss = 0
                        if self.iter_num > 0:
                            checkpoint = {
                                'model': self.raw_model.state_dict(),
                                'optimizer': self.optimizer.state_dict(),
                                'model_args': self.model_args,
                                'iter_num': self.iter_num,
                                'best_val_loss': self.best_val_loss,
                                'nan_iter_num' : None,
                                'nan' : None,
                                'config': vars(self.args),
                            }
                            print(f"saving checkpoint to {self.args.out_dir}")
                            # Save checkpoint
                            torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
                            # Try new checkpoint if better val loss
                            if self.args.max_sample_tokens:
                                self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)
                    elif self.args.sample_each_eval:
                        # Try model inference (e.g. exploring inference from overfitting)
                        if self.args.max_sample_tokens:
                            self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)

                    if self.args.patience is not None and num_steps_with_worse_loss >= self.args.patience:
                        print(f"Early Stopping: loss has not decreased in {self.args.patience + 1} steps")
                        break
                    if losses['val'] > self.best_val_loss:
                        num_steps_with_worse_loss += 1

                if self.iter_num == 0 and self.args.eval_only:
                    break

                for micro_step in range(self.args.gradient_accumulation_steps):
                    if self.ddp:
                        self.model.require_backward_grad_sync = (micro_step == self.args.gradient_accumulation_steps - 1)

                    with self.ctx:
                        logits, loss = self.model(self.X, self.Y)
                        loss = loss / self.args.gradient_accumulation_steps

                    self.X, self.Y = self.get_batch('train')

                    self.scaler.scale(loss).backward()

                if self.args.grad_clip != 0.0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)

                self.scaler.step(self.optimizer)
                self.scaler.update()

                self.optimizer.zero_grad(set_to_none=True)

                t1 = time.time()
                dt = t1 - t0
                t0 = t1
                if self.iter_num % self.args.log_interval == 0 and self.master_process:
                    lossf = loss.item() * self.args.gradient_accumulation_steps
                    if local_iter_num >= 5:
                        mfu = self.raw_model.estimate_mfu(self.args.batch_size * self.args.gradient_accumulation_steps, dt)
                        running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
                    print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%")
                    if math.isnan(lossf):
                        if self.args.save_nan_checkpoint:
                            checkpoint = {
                                'model': self.raw_model.state_dict(),
                                'optimizer': self.optimizer.state_dict(),
                                'model_args': self.model_args,
                                'iter_num': self.iter_num_best_val_loss,
                                'best_val_loss': self.best_val_loss,
                                'nan_iter_num' : self.iter_num,
                                'nan' : True,
                                'config': vars(self.args),
                            }
                            print(f"saving checkpoint to {self.args.out_dir}")
                            torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
                        sys.exit("Exiting training loss is NaN")
                    self.log_metrics_non_validation(lossf, running_mfu, self.iter_num)

                if self.args.create_statistics:
                    create_statistics(self, graph_y_labels)

                self.iter_num += 1
                local_iter_num += 1

                # Update progress bar
                progress.update(task_id, advance=1)

                # End of training actions
                if self.iter_num > self.args.max_iters:
                    if self.args.only_save_checkpoint_at_end:
                        checkpoint = {
                            'model': self.raw_model.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'model_args': self.model_args,
                            'iter_num': self.iter_num,
                            'best_val_loss': self.best_val_loss,
                            'nan_iter_num' : None,
                            'nan' : None,
                            'config': vars(self.args),
                        }
                        print(f"saving checkpoint to {self.args.out_dir}")
                        torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
                    # Sample if set
                    if self.args.max_sample_tokens:
                        self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)
                    break

            if self.args.plot_statistics:
                plot_statistics(self.args, self.stats, graph_y_labels)

            if self.args.tensorboard_log:
                self.writer.flush()
                self.writer.close()

            if self.args.wandb_log and self.master_process:
                import wandb
                wandb.log({"finished": True})
                wandb.finish()

def main():
    args, model_group, training_group, logging_group = parse_args()
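One detail in the hunk above: the task total is self.args.max_iters - self.iter_num rather than max_iters alone, presumably so that, when training resumes from a checkpoint with a non-zero iter_num, the bar only spans the remaining iterations. A small illustration of that arithmetic with hypothetical numbers:

    from rich.progress import Progress

    # Hypothetical resume: max_iters = 10_000 and a checkpoint restored at iter_num = 4_000.
    max_iters, iter_num = 10_000, 4_000

    with Progress() as progress:
        # total mirrors the diff's (self.args.max_iters - self.iter_num): 6_000 remaining steps
        task_id = progress.add_task("[green]Training...", total=max_iters - iter_num)
        for _ in range(iter_num, max_iters):
            progress.update(task_id, advance=1)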
