Inference Checkpoints #4620

Merged
2 changes: 2 additions & 0 deletions deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -90,6 +90,8 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
param = checkpoint_sd[param_name]
yield param_name, param

del checkpoint_sd


if __name__ == "__main__":
# To test, add your auth_token here and run `python huggingface_engine.py`
66 changes: 44 additions & 22 deletions deepspeed/inference/v2/engine_factory.py
@@ -3,44 +3,66 @@

# DeepSpeed Team

import json
import logging
from typing import Any
import os
import pickle

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
from .checkpoint import HuggingFaceCheckpointEngine
from .logging import inference_logger
from .model_implementations import (
OPTPolicy,
Llama2Policy,
MistralPolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata


def build_hf_engine(path: str,
engine_config: RaggedInferenceEngineConfig,
debug_level: int = logging.INFO,
random_weights_config: Any = None,
fill_random: bool = False) -> InferenceEngineV2:
debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
Build an InferenceV2 engine for HuggingFace models.
"""
# Set up logging
inference_logger(level=debug_level)

# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)

# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config

# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
from .model_implementations.opt.policy import OPTPolicy
policy = OPTPolicy(checkpoint_engine, model_config)
elif model_config.model_type == "llama":
from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy
policy = Llama2Policy(checkpoint_engine, model_config)
elif model_config.model_type == "mistral":
from .model_implementations.mistral.policy import MistralPolicy
policy = MistralPolicy(checkpoint_engine, model_config)
if os.path.exists(os.path.join(path, "ds_model_config.pkl")):

# Load metadata, for grabbing the policy name we'll have all ranks just check for
# rank 0.

Contributor: I think this part needs to be supported in a higher abstraction layer, as it is not just related to HF models and we should be able to use it with different checkpoint formats.

metadata_filename = make_metadata_filename(path, 0, engine_config.tensor_parallel.tp_size)
metadata = json.load(open(metadata_filename, "r"))
metadata = ModelMetadata.parse_raw(metadata)

# Get the policy
try:
policy_cls: InferenceV2Policy = POLICIES[metadata.policy]
except KeyError:
raise ValueError(f"Unknown policy {metadata.policy} for model {path}")

# Load the model config
model_config = pickle.load(open(os.path.join(path, "ds_model_config.pkl"), "rb"))
policy = policy_cls(model_config, inf_checkpoint_path=path)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")
# get HF checkpoint engine
checkpoint_engine = HuggingFaceCheckpointEngine(path)

# get model config from HF AutoConfig
model_config = checkpoint_engine.model_config

# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "llama":
policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "mistral":
policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

return InferenceEngineV2(policy, engine_config)
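
To make the new checkpoint branch concrete, here is a minimal sketch of reading the saved policy name outside the factory. The checkpoint directory and tensor-parallel size are illustrative assumptions; only make_metadata_filename, ModelMetadata, and the double-encoded JSON format come from the code in this PR.

# Hedged sketch: inspect the policy recorded in a serialized inference checkpoint.
# "/tmp/ds_inference_ckpt" and tp_size=2 are assumptions, not values from this PR.
import json

from deepspeed.inference.v2.model_implementations.flat_model_helpers import (
    ModelMetadata,
    make_metadata_filename,
)

ckpt_dir = "/tmp/ds_inference_ckpt"
metadata_file = make_metadata_filename(ckpt_dir, 0, 2)  # rank 0's metadata, tp_size = 2

# serialize() below json.dump()s the string returned by .json(), so json.load()
# yields a JSON string that ModelMetadata.parse_raw() can consume, mirroring
# build_hf_engine above.
with open(metadata_file, "r") as f:
    metadata = ModelMetadata.parse_raw(json.load(f))

print(metadata.policy)  # the key used to look up the policy class in POLICIES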
22 changes: 22 additions & 0 deletions deepspeed/inference/v2/engine_v2.py
@@ -4,6 +4,8 @@
# DeepSpeed Team

import os
import json
import pickle
from typing import Iterable, Tuple

import torch
@@ -17,6 +19,7 @@
from .logging import inference_logger
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
from .scheduling_utils import SchedulingError, SchedulingResult
from .model_implementations.flat_model_helpers import make_param_filename, make_metadata_filename

from .config_v2 import RaggedInferenceEngineConfig

@@ -215,3 +218,22 @@ def flush(self, uid: int) -> None:
uid (int): The UID of the sequence to flush.
"""
self._state_manager.flush_sequence(uid)

def serialize(self, save_path: str) -> None:
"""
Serialize the model to the given directory.

Arguments:
save_path (str): Path to the directory to serialize into.
"""
param_file_name = make_param_filename(save_path, self._model.tp_rank, self._model.tp_size)
metadata_file_name = make_metadata_filename(save_path, self._model.tp_rank, self._model.tp_size)

# Save the flattened parameters

torch.save(self._model.flattened_params, param_file_name)

json.dump(self._model.flattened_param_metadata.json(), open(metadata_file_name, "w"))

if self._model.tp_rank == 0:
pickle.dump(self._model._config, open(os.path.join(save_path, "ds_model_config.pkl"), "wb"))
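
As a rough illustration of the intended round trip, a sketch under stated assumptions: the HF model id, the save directory, and a default-constructed RaggedInferenceEngineConfig are placeholders, not values taken from this PR.

# Hedged sketch of serialize/restore using the factory changes above.
import os

from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

save_path = "/tmp/ds_inference_ckpt"  # illustrative output directory
os.makedirs(save_path, exist_ok=True)

engine_config = RaggedInferenceEngineConfig()  # assumed to be constructible with defaults

# First load takes the HuggingFace branch of build_hf_engine.
engine = build_hf_engine("mistralai/Mistral-7B-v0.1", engine_config)

# Writes per-rank flattened parameters and metadata; rank 0 also writes ds_model_config.pkl.
engine.serialize(save_path)

# A later load of the same directory takes the checkpoint branch, because
# ds_model_config.pkl now exists there.
restored = build_hf_engine(save_path, engine_config)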
89 changes: 89 additions & 0 deletions deepspeed/inference/v2/inference_parameter.py
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Dict

import torch

CORE_PARAM = "_ds_core_param_key"

STR_TO_DTYPE = {
"torch.float32": torch.float32,
"torch.float64": torch.float64,
"torch.float16": torch.float16,
"torch.int64": torch.int64,
"torch.int32": torch.int32,
"torch.int16": torch.int16,
"torch.int8": torch.int8,
"torch.uint8": torch.uint8,
"torch.bool": torch.bool,
}


class InferenceParameter(torch.Tensor):
"""
An extension of the torch.Tensor class to support our inference focused features. One important
thing to note here is that an InferenceParam can be used as a torch.Tensor, but outputs of
torch.Tensor operations will not be InferenceParams.
"""

@staticmethod
def __new__(cls, tensor, *args, **kwargs):
new_tensor = super().__new__(cls, tensor, *args, **kwargs)
if hasattr(tensor, "_aux_attrs"):
new_tensor._aux_attrs = tensor.aux_attrs
return new_tensor

def to(self, *args, **kwargs):
new_tensor = super().to(*args, **kwargs)
if hasattr(self, "_aux_attrs"):
new_tensor._aux_attrs = self.aux_attrs

try:
_ = torch.device(args[0])
for name, attr in new_tensor.aux_attrs.items():
new_attr = attr.to(*args, **kwargs)
setattr(new_tensor, name, new_attr)
new_tensor._aux_attrs[name] = new_attr
except:
pass

return new_tensor

@classmethod
def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter':
"""
Create the inference parameter.
"""
param = InferenceParameter(core_param)
param._aux_attrs = kwargs
Contributor: @cmikeh2, can you please clarify what the aux_attrs are used for?

Contributor: Are you maybe thinking about scales that are required when adding in the quantization?

Contributor (author): That's exactly right. Scales for quantization, or any other metadata we create when transforming the parameter, can be stored as an auxiliary attr. The declaration for something like that would be:

p = InferenceParameter.initialize(param, scales=scales)

assert torch.equal(p, param)
assert torch.equal(p.scales, scales)


for attr_name, attr in kwargs.items():
if hasattr(param, attr_name):
raise ValueError(f"Attribute {attr_name} already exists on param.")

if not isinstance(attr, torch.Tensor):
raise ValueError(f"Attribute {attr_name} must be a tensor.")

setattr(param, attr_name, attr)

return param

@classmethod
def initialize_raw(self, **kwargs) -> 'InferenceParameter':
"""
All kwargs must be torch.Tensors and must include the core parameter.
"""
if CORE_PARAM not in kwargs:
raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.")

return InferenceParameter.initialize(kwargs[CORE_PARAM], **kwargs)

@property
def aux_attrs(self) -> Dict[str, torch.Tensor]:
"""
Dictionary of auxiliary attributes.
"""
return self._aux_attrs
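
Expanding slightly on the author's inline example above, a minimal usage sketch; the shapes, dtypes, and the "scales" attribute name are illustrative assumptions.

# Hedged sketch of InferenceParameter with an auxiliary quantization-scale tensor.
import torch

from deepspeed.inference.v2.inference_parameter import CORE_PARAM, InferenceParameter

weight = torch.randn(4096, 4096, dtype=torch.float16)
scales = torch.randn(4096, dtype=torch.float16)

p = InferenceParameter.initialize(weight, scales=scales)
assert torch.equal(p, weight)
assert torch.equal(p.scales, scales)
assert "scales" in p.aux_attrs

# initialize_raw builds the same thing from a flat dict keyed by CORE_PARAM.
p2 = InferenceParameter.initialize_raw(**{CORE_PARAM: weight, "scales": scales})

# to() also moves the auxiliary tensors when the first argument is a device.
if torch.cuda.is_available():
    p_cuda = p.to("cuda")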
5 changes: 5 additions & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -7,3 +7,8 @@
from .inference_transformer_base import DSTransformerModelBase, DSMoETransformerModelBase
from .inference_policy_base import InferenceV2Policy, ContainerMap
from .sharding import *

# Model Implementations
from .llama_v2 import *
from .opt import *
from .mistral import *
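
These re-exports are presumably what lets engine_factory.py above import OPTPolicy, Llama2Policy, and MistralPolicy at module level instead of lazily inside each branch. A short sketch of the resulting import path (package layout assumed from this PR):

# Assumed import path based on the package layout in this PR.
from deepspeed.inference.v2.model_implementations import Llama2Policy, MistralPolicy, OPTPolicy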
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
@@ -6,7 +6,6 @@
import torch

from ...model_implementations.parameter_base import ParameterBase
from ...allocator import on_device
"""
Embedding containers.
"""
@@ -23,7 +22,6 @@ class EmbeddingParameter(ParameterBase):
Vocabulary parameter of shape [vocab_size, model_dim].
"""

@on_device
def finalize(self) -> torch.Tensor:
return self.params
#return self.inference_model.transform_embed_param(self.params)
print("EmbeddingParameter.finalize")

Contributor: Do we want to remove the debugging code here?

return self.inference_model.transform_embedding_param(self.params)
@@ -6,7 +6,6 @@
import torch

from ...model_implementations.parameter_base import ParameterBase
from ...allocator import on_device
"""
Common InvFreq Parameter Patterns
"""
@@ -16,6 +15,5 @@ class InvFreqParameter(ParameterBase):

params: torch.Tensor

@on_device
def finalize(self) -> torch.Tensor:
return self.params.to(self.inference_model.activation_dtype.value)
@@ -5,7 +5,6 @@

import torch

from ...allocator import on_device
from ...model_implementations.parameter_base import ParameterBase, ParamList
"""
Moe Parameters
@@ -24,7 +23,6 @@ class MoEGatingWeightParameter(ParameterBase):
Projection matrix from the input activations to the gate logits.
"""

@on_device
def finalize(self) -> torch.Tensor:
return self.inference_model.transform_moe_gate_param(self.params)
