Commit: refine GLM (#187)
* refine GLM

* style

* glm: add 1x1

* add MFU

* add MFU annotation for case readme

* add e2e_time for GLM 1x1

* update 1x1 e2e_time to about 2h

---------

Co-authored-by: zhouyu <[email protected]>
yuzhou03 and zhouyu authored Aug 21, 2023
1 parent 7157a94 commit 9cb95dd
Showing 8 changed files with 98 additions and 88 deletions.
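Two of the commit message bullets add MFU (Model FLOPs Utilization) reporting. As a hedged sketch of how an MFU figure is commonly derived — achieved FLOP/s divided by the accelerator's peak FLOP/s — the function and every constant below are illustrative assumptions, not this repository's implementation:

```python
# Hypothetical sketch of an MFU calculation; all numbers are illustrative
# assumptions, not values taken from this commit.

def compute_mfu(flops_per_sample: float,
                samples_per_second: float,
                peak_flops_per_second: float) -> float:
    """Model FLOPs Utilization: achieved FLOP/s divided by hardware peak."""
    return (flops_per_sample * samples_per_second) / peak_flops_per_second


# Example with assumed numbers: a ~335M-parameter GLM step over 512 tokens
# costs roughly 6 * params * tokens FLOPs per sample; peak of 312 TFLOP/s
# is an assumed accelerator spec.
flops = 6 * 335e6 * 512
print(f"MFU ~= {compute_mfu(flops, 10.0, 312e12):.2%}")
```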
8 changes: 5 additions & 3 deletions training/benchmarks/glm/pytorch/config/_base.py
@@ -1,18 +1,20 @@
from typing import ClassVar
#from train.event.base import BaseTrainingEventInterface
# required parameters

# case info
# chip vendor: nvidia, kunlunxin, iluvatar, cambricon, etc. The vendor key is required.
vendor: str = None
# model name
name: str = "GLM"
cudnn_benchmark: bool = False
cudnn_deterministic: bool = True
data_dir: str = None

do_train = True
fp16 = True
# =========================================================
# data
# =========================================================
data_dir: str = "/mnt/data/glm/train/"

train_data: str = "ReCoRD/glm_train_eval_hdf5_sparse/train_hdf5/train_sparse.hdf5"
eval_data: str = "ReCoRD/glm_train_eval_hdf5_sparse/eval_hdf5/eval_sparse.hdf5"
output_dir: str = ""
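`_base.py` only holds the defaults; in this benchmark layout, vendor- and scale-specific config files typically override them (for instance, the new 1x1 case named in the commit message). A minimal sketch of such an override file — the file name and all values are assumptions:

```python
# Hypothetical vendor config (e.g. config_A100x1x1.py); values are assumed
# for illustration only and override the defaults in _base.py.
vendor = "nvidia"
train_batch_size = 8
gradient_accumulation_steps = 1
max_samples_termination = 240000  # assumed; per the commit message, the
                                  # 1x1 case targets an e2e_time of about 2h
```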
66 changes: 29 additions & 37 deletions training/benchmarks/glm/pytorch/run_pretraining.py
@@ -1,41 +1,39 @@
# Copyright (c) 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
"""GLM Pretraining"""

import time
import argparse
import os
import sys
import numpy as np
import time
import torch
import random

import config
from dataloaders import (WorkerInitializer, build_train_dataloader,
build_eval_dataloaders)
from train.trainer import Trainer, Evaluator
from train.training_state import TrainingState
# from train.event import TrainingEventCompose, TrainingLogger, BaseTrainingEventInterface

from train import trainer_adapter

CURR_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../")))
import driver
from driver import Driver, Event, dist_pytorch, check
from driver import Event, dist_pytorch, check
from driver.helper import InitHelper

logger = None


def main():
import config

from config import mutable_params
global logger
global config

if config.use_env and 'LOCAL_RANK' in os.environ:
config.local_rank = int(os.environ['LOCAL_RANK'])

glm_driver = Driver(config, config.mutable_params)
glm_driver.setup_config(argparse.ArgumentParser("Glm"))
glm_driver.setup_modules(globals(), locals())

init_helper = InitHelper(config)
glm_driver = init_helper.init_driver(globals(), locals())
logger = glm_driver.logger

dist_pytorch.init_dist_training_env(config)
@@ -54,12 +52,13 @@ def main():
else:
worker_seed = worker_seeds[0]

random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
init_helper.set_seed(config.seed, config.vendor)

worker_init = WorkerInitializer.default(worker_seed)
train_dataloader = build_train_dataloader(config, worker_init)
eval_dataloader = build_eval_dataloaders(config)

evaluator = Evaluator(config, None)
evaluator = Evaluator(config, eval_dataloader)
training_state = TrainingState()
trainer = Trainer(driver=glm_driver,
adapter=trainer_adapter,
@@ -72,71 +71,64 @@ def main():
dist_pytorch.barrier(config.vendor)
trainer.init()

eval_dataloader = build_eval_dataloaders(config)

dist_pytorch.barrier(config.vendor)
init_evaluation_start = time.time()
evaluator.dataloader = eval_dataloader
score = trainer.evaluator.evaluate(trainer)
training_state.eval_accuracy = score
init_evaluation_end = time.time()
init_evaluation_info = dict(eval_accuracy=score,
time=init_evaluation_end -
init_evaluation_start)
# training_event.on_init_evaluate(init_evaluation_info)
glm_driver.event(Event.INIT_EVALUATION, init_evaluation_info)

train_dataloader = build_train_dataloader(config, worker_init)

if not config.do_train:
return config, training_state

# training_event.on_init_end()
glm_driver.event(Event.INIT_END)
init_end_time = logger.previous_log_time
training_state.init_time = (init_end_time - init_start_time) / 1e+3

dist_pytorch.barrier(config.vendor)

epoch = -1
# training_event.on_train_begin()
glm_driver.event(Event.TRAIN_START)
raw_train_start_time = logger.previous_log_time
raw_train_start_time = time.time()

epoch = 0
while training_state.num_trained_samples < config.max_samples_termination and not training_state.end_training:
epoch += 1
training_state.epoch = epoch
train_dataloader.sampler.set_epoch(epoch)
trainer.train_one_epoch(train_dataloader)
epoch += 1

# training_event.on_train_end()
glm_driver.event(Event.TRAIN_END)
raw_train_end_time = logger.previous_log_time
training_state.raw_train_time = (raw_train_end_time -
raw_train_start_time) / 1e+3
training_state.raw_train_time = time.time() - raw_train_start_time

return config, training_state


if __name__ == "__main__":

now = time.time()
config, state = main()
config_updated, state = main()

if not dist_pytorch.is_main_process():
sys.exit()

e2e_time = time.time() - now
if config.do_train:
training_perf = (dist_pytorch.global_batch_size(config) *
if config_updated.do_train:
training_perf = (dist_pytorch.global_batch_size(config_updated) *
state.global_steps) / state.raw_train_time
finished_info = {
"e2e_time": e2e_time,
"global_steps": state.global_steps,
"num_trained_samples": state.num_trained_samples,
"training_sequences_per_second": training_perf,
"converged": state.converged,
"final_accuracy": state.eval_accuracy,
"raw_train_time": state.raw_train_time,
"init_time": state.init_time,
"pure_training_computing_time": state.pure_compute_time,
"throughput(ips)_raw": state.num_trained_samples / state.raw_train_time,
"throughput(ips)_no_eval": state.num_trained_samples / state.no_eval_time,
"throughput(ips)_pure_compute": state.num_trained_samples / state.pure_compute_time,
}
else:
finished_info = {"e2e_time": e2e_time}
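Above, the inline `random`/`numpy`/`torch` seeding is replaced by a single `init_helper.set_seed(config.seed, config.vendor)` call. A minimal sketch of what such a helper typically does, assuming the vendor argument merely gates device-side seeding — the body shown is an assumption, not the repository's implementation:

```python
import random

import numpy as np
import torch


def set_seed(seed: int, vendor: str = "nvidia") -> None:
    """Seed every RNG the training loop touches so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if vendor == "nvidia" and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # device-side generators
```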
7 changes: 3 additions & 4 deletions training/benchmarks/glm/pytorch/train/evaluator.py
@@ -1,4 +1,6 @@
# coding=utf-8
# Copyright (c) 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch


@@ -58,14 +60,12 @@ def multichoice_evaluate(model, dataloader, args, segment_length=10):
total_score += batch_score

model.train()
#config.training_event_instance.device_barrier()
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.all_reduce(total_sample,
op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(total_score,
op=torch.distributed.ReduceOp.SUM)

# print(f"samples:{total_sample}, score:{total_score}")
score = total_score / total_sample

return score.item()
@@ -77,5 +77,4 @@ def em_evaluate(predictions, labels):
for pred, true_list in zip(predictions, labels):
if pred in true_list:
score += 1
# score = 100.0 * score / len(predictions)
return score
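`multichoice_evaluate` all-reduces the per-rank sample and score counters before dividing, so every rank ends up with the same global accuracy even when ranks hold differently sized shards. The same pattern in isolation, with hypothetical names:

```python
import torch
import torch.distributed as dist


def global_accuracy(local_score: float, local_samples: float,
                    device: torch.device) -> float:
    """Sum numerator and denominator across ranks, then divide once."""
    stats = torch.tensor([local_score, local_samples],
                         dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    return (stats[0] / stats[1]).item()
```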
52 changes: 23 additions & 29 deletions training/benchmarks/glm/pytorch/train/trainer.py
@@ -1,27 +1,23 @@
import torch
from torch.types import Device
import os
import sys
# Copyright (c) 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import time
import math

import torch
from torch.types import Device

import config
from model import create_model
from schedulers import create_scheduler

from train.evaluator import Evaluator
from train.training_state import TrainingState

import config

CURR_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../")))
from driver import Driver, Event, dist_pytorch


def process_batch(batch, device):
"""Process batch and produce inputs for the model."""
batch = {t: batch[t].to(device) for t in batch if t != 'answer_idx'}

return batch


@@ -64,8 +60,6 @@ def _init_model(self, model, args, device):
# Load the checkpoint.
sd = torch.load(checkpoint_name, map_location='cpu')

# model = model.module

# Model.
def extend_embedding_weights(state_weights, model_weights):
original_length = state_weights.shape[0]
@@ -111,12 +105,17 @@ def extend_embedding_weights(state_weights, model_weights):
def train_one_epoch(self, dataloader):
state = self.training_state
driver = self.driver
dataloader.sampler.set_epoch(state.epoch)
driver.event(Event.EPOCH_BEGIN, state.epoch)

step_start_time = time.time()
epoch_start_num_sample = state.num_trained_samples

no_eval_start_time = time.time()
iter_end_time = no_eval_start_time

for batch_idx, batch in enumerate(dataloader):
iter_start_time = time.time()
dataload_time = iter_start_time - iter_end_time

state.global_steps += 1
# TODO: Maybe we should update num_trained_samples after all epochs.
@@ -125,6 +124,8 @@ def train_one_epoch(self, dataloader):

driver.event(Event.STEP_BEGIN, step=state.global_steps)
self.train_one_step(batch)
self.training_state.no_eval_time += (
time.time() - iter_start_time) + dataload_time

other_state = dict()
if state.global_steps % self.config.gradient_accumulation_steps == 0:
@@ -159,43 +160,36 @@
if eval_result is not None:
driver.event(Event.EVALUATE, eval_result)

iter_end_time = time.time()

if end_training:
break

epoch_start_num_sample += len(dataloader.dataset)
state.num_trained_samples = epoch_start_num_sample

driver.event(Event.EPOCH_END, state.epoch)

def train_one_step(self, batch):
data = process_batch(batch, self.config.device)
state = self.training_state

# self.training_event.on_step_begin(state.global_steps)
self.model.train()
pure_compute_start_time = time.time()

lm_loss, _ = self.forward(data)
lm_loss /= self.config.gradient_accumulation_steps
reduced_loss = lm_loss.detach().clone().view(1)
if torch.distributed.is_available(
) and torch.distributed.is_initialized():
if dist_pytorch.is_dist_avail_and_initialized():
torch.distributed.all_reduce(reduced_loss.data)
reduced_loss.data = reduced_loss.data / (dist_pytorch.get_world_size())

state.loss = lm_loss
#lm_loss.backward()
#self.optimizer.step()
self.adapter.backward(state.global_steps, lm_loss, reduced_loss,
self.optimizer, self.lr_scheduler, self.model)
#self.adapter.backward(state.global_steps, state.loss, self.optimizer)
#self.adapter.backward(state.global_steps, reduced_loss, self.optimizer)
#self.adapter.backward(state.global_steps, reduced_loss, self.optimizer, self.lr_scheduler)
# self.training_event.on_backward(
# state.global_steps, lm_loss, reduced_loss, self.optimizer, self.lr_scheduler)
#self.lr_scheduler.step()

self.training_state.pure_compute_time += time.time(
) - pure_compute_start_time

self.driver.event(Event.BACKWARD, state.global_steps, state.loss,
self.optimizer, self.grad_scaler)
#self.lr_scheduler.step()

def detect_training_status(self, state):
config = self.config
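The new `no_eval_time` and `pure_compute_time` fields decompose each iteration so that `run_pretraining.py` can report three throughputs: raw (whole wall-clock train time), no_eval (raw minus evaluation), and pure_compute (forward/backward/optimizer only). A schematic, self-contained sketch of the same bookkeeping, with stub functions standing in for the real trainer:

```python
import time


def train_step(batch):        # stand-in for forward + backward + optimizer
    time.sleep(0.01)


def maybe_evaluate(step):     # stand-in for the periodic eval pass
    if step % 10 == 0:
        time.sleep(0.02)


no_eval_time = pure_compute_time = 0.0
iter_end = time.time()
for step, batch in enumerate(range(20), start=1):
    iter_start = time.time()
    dataload_time = iter_start - iter_end        # time spent fetching the batch

    compute_start = time.time()
    train_step(batch)
    pure_compute_time += time.time() - compute_start

    # everything except evaluation: data loading plus the training step
    no_eval_time += (time.time() - iter_start) + dataload_time

    maybe_evaluate(step)                         # excluded from no_eval_time
    iter_end = time.time()

print(f"no_eval={no_eval_time:.3f}s pure_compute={pure_compute_time:.3f}s")
```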
6 changes: 0 additions & 6 deletions training/benchmarks/glm/pytorch/train/trainer_adapter.py
@@ -1,19 +1,13 @@
import os
import sys

from torch.optim import Optimizer
from torch import nn, Tensor
from typing import Tuple

import optimizers
try:
from apex.optimizers import FusedAdam as Adam
except ImportError:
from torch.optim import AdamW as Adam
from optimizers import FP16_Optimizer, get_optimizer_param_groups

CURR_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../../")))
from driver.dist_pytorch import main_proc_print


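The adapter keeps Apex's `FusedAdam` when available and silently falls back to `torch.optim.AdamW`, so the benchmark also runs on stacks without Apex. The same import-guard pattern in isolation (the toy model and hyperparameters are assumptions):

```python
import torch.nn as nn

try:
    from apex.optimizers import FusedAdam as Adam  # fused CUDA kernels if Apex is present
except ImportError:
    from torch.optim import AdamW as Adam          # portable fallback

model = nn.Linear(16, 16)                          # toy stand-in model
optimizer = Adam(model.parameters(), lr=1e-4, weight_decay=0.01)
```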
6 changes: 6 additions & 0 deletions training/benchmarks/glm/pytorch/train/training_state.py
@@ -1,3 +1,6 @@
# Copyright (c) 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
from dataclasses import dataclass
import inspect
import torch
@@ -23,6 +26,9 @@ class TrainingState:
init_time = 0
raw_train_time = 0

no_eval_time = 0
pure_compute_time = 0

def status(self):
if self.converged:
self._status = "success"