This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

refactor load_checkpoint #113

Merged
merged 1 commit on Aug 22, 2022
13 changes: 8 additions & 5 deletions energonai/model/model_factory.py
@@ -14,6 +14,9 @@
 from colossalai.context import ParallelMode
 from energonai.utils import is_using_pp, get_current_device
 from energonai.logging import get_dist_logger
+from energonai.utils.checkpointing import load_checkpoint
+from energonai.utils.checkpointing_hf_gpt2 import processing_HF_GPT
+from energonai.utils.checkpointing_opt import processing_OPT
 
 
 def gelu_impl(x):
@@ -160,12 +163,12 @@ def create_pipeline_model(depth: int = 48,
if "checkpoint" in model_kwargs.keys() and "model_name" in model_kwargs.keys():
start = time.time()
assert os.path.exists(model_kwargs["checkpoint"]), "Checkpoint file not found"
preprocess_fn = None
if model_kwargs["model_name"] == "hf_gpt2":
from energonai.utils.checkpointing_hf_gpt2 import load_checkpoint
load_checkpoint(model_kwargs["checkpoint"], model, **model_kwargs)
if model_kwargs["model_name"] == "opt":
from energonai.utils.checkpointing_opt import load_checkpoint
load_checkpoint(model_kwargs["checkpoint"], model, **model_kwargs)
preprocess_fn = processing_HF_GPT
elif model_kwargs["model_name"] == "opt":
preprocess_fn = processing_OPT
load_checkpoint(model_kwargs["checkpoint"], model, preprocess_fn=preprocess_fn, **model_kwargs)
logger.info(f'Load time: {time.time() - start:.3f} s')

return model
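
Note: with this refactor every supported model flows through the shared load_checkpoint entry point and only contributes a model-specific preprocess_fn. Below is a minimal sketch of the resulting call pattern; the checkpoint path and wrapper function are made up for illustration, model stands for an already constructed energonai model, and an initialized distributed context is assumed.

# Sketch only, not code from this PR: the path and wrapper name are placeholders.
from energonai.utils.checkpointing import load_checkpoint
from energonai.utils.checkpointing_opt import processing_OPT

def load_opt_weights(model, checkpoint_path='/path/to/opt_checkpoint'):
    # Model-specific key conversion is injected via preprocess_fn instead of
    # each model shipping its own load_checkpoint variant.
    load_checkpoint(checkpoint_path, model, preprocess_fn=processing_OPT)
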
51 changes: 36 additions & 15 deletions energonai/utils/checkpointing.py
@@ -9,7 +9,7 @@
 from ..communication.collective import scatter_object_list
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-
+from typing import Optional, Callable
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
@@ -19,6 +19,29 @@
"partition_tensor_parallel_state_dict", "load_checkpoint", "gather_tensor_parallel_state_dict", "save_checkpoint"
]

import os
from multiprocessing import Pool
from time import time


def load_state_dict(path: str):
if os.path.isfile(path):
return torch.load(path)
assert os.path.isdir(path)
state_dict = {}
files = []
for filename in os.listdir(path):
filepath = os.path.join(path, filename)
if os.path.isfile(filepath):
files.append(filepath)
threads = torch.get_num_threads()
print(f'load {len(files)} files using {threads} threads')
with Pool(threads) as pool:
state_dicts = pool.map(torch.load, files)
for sd in state_dicts:
state_dict.update(sd)
return state_dict


def broadcast_state_dict(state_dict, parallel_mode):
"""
@@ -170,9 +193,8 @@ def remove_prefix(state_dict, prefix):

 def load_checkpoint(file,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     strict: bool = True,
+                    preprocess_fn: Optional[Callable[[dict], dict]] = None,
                     **kwargs):
     """Loads training states from a checkpoint file.
 
@@ -192,17 +214,22 @@
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (torch.load(file, map_location=torch.device("cpu"))
-                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
-
-    # model states
-    model_state = state_dict.pop("model") if state_dict is not None else dict()
+    start = time()
+    if gpc.get_local_rank(ParallelMode.MODEL) == 0:
+        model_state = load_state_dict(file)
+        if preprocess_fn:
+            model_state = preprocess_fn(model_state)
+    else:
+        model_state = dict()
+    dist.barrier()
+    print(f'Load file time: {time()-start:.3f} s')
     # pipeline
     if is_using_pp():
         model_state = partition_pipeline_parallel_state_dict(model, model_state, **kwargs)
     if "prefix" in kwargs.keys():
         if kwargs['prefix'] != '':
             model_state = remove_prefix(model_state, kwargs["prefix"])
+
     try:
         model.load_state_dict(model_state, strict=strict)
     except RuntimeError as e:
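
preprocess_fn receives the raw merged state dict and must return the dict the model can actually consume; processing_HF_GPT and processing_OPT play that role for the two checkpoints supported here. A hypothetical example of a function with the expected shape (not the implementation used in this repository), which only strips a 'module.' prefix left by DataParallel-style wrappers:

# Sketch only: an illustrative preprocess_fn, not energonai's processing_OPT or processing_HF_GPT.
from typing import Dict
import torch

def strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Rename 'module.xxx' keys to 'xxx'; all other keys pass through unchanged.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# It would be wired in through the new keyword argument:
# load_checkpoint(checkpoint_path, model, preprocess_fn=strip_module_prefix)
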
@@ -219,13 +246,7 @@
         else:
             raise e
 
-    # broadcast the rest states
-    state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
-
-    # last epoch
-    last_epoch = state_dict.pop("epoch", -1)
-
-    return last_epoch
+    return -1
 
 
 def save_checkpoint(file,