Merge pull request #113 from ver217/feature/ckpt
refactor load_checkpoint
dujiangsu authored Aug 22, 2022
2 parents 0ddbb40 + 8669003 commit 495b45b
Showing 5 changed files with 51 additions and 558 deletions.
energonai/model/model_factory.py (13 changes: 8 additions & 5 deletions)
```diff
@@ -14,6 +14,9 @@
 from colossalai.context import ParallelMode
 from energonai.utils import is_using_pp, get_current_device
 from energonai.logging import get_dist_logger
+from energonai.utils.checkpointing import load_checkpoint
+from energonai.utils.checkpointing_hf_gpt2 import processing_HF_GPT
+from energonai.utils.checkpointing_opt import processing_OPT
 
 
 def gelu_impl(x):
```
```diff
@@ -160,12 +163,12 @@ def create_pipeline_model(depth: int = 48,
     if "checkpoint" in model_kwargs.keys() and "model_name" in model_kwargs.keys():
         start = time.time()
         assert os.path.exists(model_kwargs["checkpoint"]), "Checkpoint file not found"
+        preprocess_fn = None
         if model_kwargs["model_name"] == "hf_gpt2":
-            from energonai.utils.checkpointing_hf_gpt2 import load_checkpoint
-            load_checkpoint(model_kwargs["checkpoint"], model, **model_kwargs)
-        if model_kwargs["model_name"] == "opt":
-            from energonai.utils.checkpointing_opt import load_checkpoint
-            load_checkpoint(model_kwargs["checkpoint"], model, **model_kwargs)
+            preprocess_fn = processing_HF_GPT
+        elif model_kwargs["model_name"] == "opt":
+            preprocess_fn = processing_OPT
+        load_checkpoint(model_kwargs["checkpoint"], model, preprocess_fn=preprocess_fn, **model_kwargs)
         logger.info(f'Load time: {time.time() - start:.3f} s')
 
     return model
```
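The refactor routes every model through the shared load_checkpoint and only swaps the optional preprocess_fn (processing_HF_GPT or processing_OPT). Those functions are not part of this diff; as a rough, hypothetical sketch of the idea, a preprocess_fn takes the raw state dict read from disk and returns one whose keys match the target model:

```python
# Hypothetical preprocess_fn; the real processing_HF_GPT / processing_OPT live in
# energonai/utils/checkpointing_*.py and are not shown in this diff.
def example_preprocess(state_dict: dict) -> dict:
    """Rename raw checkpoint keys so they match the target model's parameters."""
    processed = {}
    for name, tensor in state_dict.items():
        # Example transformation only: drop a hypothetical "transformer." prefix.
        if name.startswith("transformer."):
            name = name[len("transformer."):]
        processed[name] = tensor
    return processed
```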
energonai/utils/checkpointing.py (51 changes: 36 additions & 15 deletions)
```diff
@@ -9,7 +9,7 @@
 from ..communication.collective import scatter_object_list
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-
+from typing import Optional, Callable
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
```
```diff
@@ -19,6 +19,29 @@
     "partition_tensor_parallel_state_dict", "load_checkpoint", "gather_tensor_parallel_state_dict", "save_checkpoint"
 ]
 
+import os
+from multiprocessing import Pool
+from time import time
+
+
+def load_state_dict(path: str):
+    if os.path.isfile(path):
+        return torch.load(path)
+    assert os.path.isdir(path)
+    state_dict = {}
+    files = []
+    for filename in os.listdir(path):
+        filepath = os.path.join(path, filename)
+        if os.path.isfile(filepath):
+            files.append(filepath)
+    threads = torch.get_num_threads()
+    print(f'load {len(files)} files using {threads} threads')
+    with Pool(threads) as pool:
+        state_dicts = pool.map(torch.load, files)
+    for sd in state_dicts:
+        state_dict.update(sd)
+    return state_dict
+
 
 def broadcast_state_dict(state_dict, parallel_mode):
     """
```
```diff
@@ -170,9 +193,8 @@ def remove_prefix(state_dict, prefix):
 
 
 def load_checkpoint(file,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     strict: bool = True,
+                    preprocess_fn: Optional[Callable[[dict], dict]] = None,
                     **kwargs):
     """Loads training states from a checkpoint file.
```
```diff
@@ -192,17 +214,22 @@ def load_checkpoint(file,
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (torch.load(file, map_location=torch.device("cpu"))
-                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
-
-    # model states
-    model_state = state_dict.pop("model") if state_dict is not None else dict()
+    start = time()
+    if gpc.get_local_rank(ParallelMode.MODEL) == 0:
+        model_state = load_state_dict(file)
+        if preprocess_fn:
+            model_state = preprocess_fn(model_state)
+    else:
+        model_state = dict()
+    dist.barrier()
+    print(f'Load file time: {time()-start:.3f} s')
     # pipeline
     if is_using_pp():
         model_state = partition_pipeline_parallel_state_dict(model, model_state, **kwargs)
     if "prefix" in kwargs.keys():
         if kwargs['prefix'] != '':
             model_state = remove_prefix(model_state, kwargs["prefix"])
+
     try:
         model.load_state_dict(model_state, strict=strict)
     except RuntimeError as e:
```
```diff
@@ -219,13 +246,7 @@
         else:
             raise e
 
-    # broadcast the rest states
-    state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
-
-    # last epoch
-    last_epoch = state_dict.pop("epoch", -1)
-
-    return last_epoch
+    return -1
 
 
 def save_checkpoint(file,
```
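With this commit, callers no longer import a model-specific load_checkpoint; they pass the checkpoint path, the model, and an optional preprocess_fn. A hypothetical call site (placeholder path and model; assumes the Colossal-AI parallel context is already initialized):

```python
import torch

from energonai.utils.checkpointing import load_checkpoint
from energonai.utils.checkpointing_opt import processing_OPT

# Placeholder: in practice `model` is the OPT module built by the model factory,
# and the Colossal-AI parallel context has already been launched.
model: torch.nn.Module = ...

load_checkpoint("/path/to/opt_checkpoint",      # a .pt file or a directory of shards
                model,
                preprocess_fn=processing_OPT,   # applied to the raw state dict on rank 0
                prefix="")                      # optional key prefix to strip before loading
```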