Inference Checkpoints #4620
Changes from 8 commits
**New file** (@@ -0,0 +1,89 @@):

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Dict

import torch

CORE_PARAM = "_ds_core_param_key"

STR_TO_DTYPE = {
    "torch.float32": torch.float32,
    "torch.float64": torch.float64,
    "torch.float16": torch.float16,
    "torch.int64": torch.int64,
    "torch.int32": torch.int32,
    "torch.int16": torch.int16,
    "torch.int8": torch.int8,
    "torch.uint8": torch.uint8,
    "torch.bool": torch.bool,
}


class InferenceParameter(torch.Tensor):
    """
    An extension of the torch.Tensor class to support our inference-focused features. One
    important thing to note here is that an InferenceParameter can be used as a torch.Tensor,
    but outputs of torch.Tensor operations will not be InferenceParameters.
    """

    @staticmethod
    def __new__(cls, tensor, *args, **kwargs):
        new_tensor = super().__new__(cls, tensor, *args, **kwargs)
        if hasattr(tensor, "_aux_attrs"):
            new_tensor._aux_attrs = tensor._aux_attrs
        return new_tensor

    def to(self, *args, **kwargs):
        new_tensor = super().to(*args, **kwargs)
        if hasattr(self, "_aux_attrs"):
            # Copy the mapping so the updates below do not mutate this tensor's dict.
            new_tensor._aux_attrs = dict(self._aux_attrs)

            try:
                _ = torch.device(args[0])
            except (IndexError, TypeError, RuntimeError):
                # No positional device argument was given, so there is nothing to migrate.
                return new_tensor

            # A device was requested: move the auxiliary tensors alongside the core parameter.
            for name, attr in new_tensor.aux_attrs.items():
                new_attr = attr.to(*args, **kwargs)
                setattr(new_tensor, name, new_attr)
                new_tensor._aux_attrs[name] = new_attr

        return new_tensor

    @classmethod
    def initialize(cls, core_param: torch.Tensor, **kwargs) -> 'InferenceParameter':
        """
        Create the inference parameter.
        """
        param = InferenceParameter(core_param)
        param._aux_attrs = kwargs
```
**Review thread** (on `param._aux_attrs = kwargs`):

> @cmikeh2, can you please clarify what the aux_attr is used for?

> Are you maybe thinking about scales that are required when adding in the quantization?

> That's exactly right: scales for quantization, or any other metadata we create when transforming the parameter, can be stored as an auxiliary attr. The declaration for something like that would be:
>
> ```python
> p = InferenceParameter.initialize(param, scales=scales)
> assert torch.equal(p, param)
> assert torch.equal(p.scales, scales)
> ```
The file continues:

```python
        for attr_name, attr in kwargs.items():
            if hasattr(param, attr_name):
                raise ValueError(f"Attribute {attr_name} already exists on param.")

            if not isinstance(attr, torch.Tensor):
                raise ValueError(f"Attribute {attr_name} must be a tensor.")

            setattr(param, attr_name, attr)

        return param

    @classmethod
    def initialize_raw(cls, **kwargs) -> 'InferenceParameter':
        """
        All kwargs must be torch.Tensors and must include the core parameter.
        """
        if CORE_PARAM not in kwargs:
            raise ValueError(f"Must provide core parameter, with key {CORE_PARAM}.")

        # Pop the core parameter so it is not also registered as an auxiliary attribute.
        core_param = kwargs.pop(CORE_PARAM)
        return InferenceParameter.initialize(core_param, **kwargs)

    @property
    def aux_attrs(self) -> Dict[str, torch.Tensor]:
        """
        Dictionary of auxiliary attributes.
        """
        return self._aux_attrs
```
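For orientation, here is a minimal usage sketch of the class above. It is not from the PR; the tensor names, shapes, and dtypes are illustrative only:

```python
import torch

# Illustrative tensors: a fp16 core weight plus fp32 quantization scales.
weight = torch.randn(128, 64, dtype=torch.float16)
scales = torch.randn(128, 1, dtype=torch.float32)

p = InferenceParameter.initialize(weight, scales=scales)
assert torch.equal(p, weight)          # usable anywhere a torch.Tensor is expected
assert torch.equal(p.scales, scales)   # auxiliary tensor reachable by name
assert "scales" in p.aux_attrs         # ...and via the aux_attrs mapping

# initialize_raw builds the same object from a flat dict keyed by CORE_PARAM.
q = InferenceParameter.initialize_raw(**{CORE_PARAM: weight, "scales": scales})

# A positional device argument to .to() migrates the auxiliary tensors as well.
if torch.cuda.is_available():
    p_gpu = p.to("cuda")
    assert p_gpu.aux_attrs["scales"].is_cuda
```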
**New file** (@@ -0,0 +1,4 @@, license header only):

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
```
**Modified file** (embedding containers):

@@ -6,7 +6,6 @@

```python
import torch

from ...model_implementations.parameter_base import ParameterBase
from ...allocator import on_device
"""
Embedding containers.
"""
```

@@ -23,7 +22,6 @@ class EmbeddingParameter(ParameterBase):

```python
    Vocabulary parameter of shape [vocab_size, model_dim].
    """

    @on_device
    def finalize(self) -> torch.Tensor:
        return self.params
        #return self.inference_model.transform_embed_param(self.params)
        print("EmbeddingParameter.finalize")
```

**Review comment** (on the `finalize` body):

> Do we want to remove the debugging code here?

In the updated diff, the method body instead reads:

```python
        return self.inference_model.transform_embedding_param(self.params)
```
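Presumably `transform_embedding_param` is the hook where a model implementation would wrap the raw weight into an `InferenceParameter`. A hypothetical sketch of such a hook follows; everything here other than the method name and the tensor argument is an assumption, not code from this PR:

```python
def transform_embedding_param(self, param: torch.Tensor) -> InferenceParameter:
    # Hypothetical: move the embedding weight to an assumed device/dtype and
    # wrap it, attaching any transformation metadata as auxiliary attributes.
    embed = param.to(device="cuda", dtype=torch.float16)  # assumed placement policy
    return InferenceParameter.initialize(embed)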
**Review comment:**

> I think this part needs to be supported in a higher abstraction layer, as it is not just related to HF models and we should be able to use it with different checkpoint formats.
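To make that suggestion concrete, one possible shape for such an abstraction layer is sketched below. This is purely illustrative; every name in it is hypothetical rather than taken from the PR:

```python
from abc import ABC, abstractmethod
from typing import Iterable, Tuple

import torch


class CheckpointEngineBase(ABC):
    """
    Hypothetical format-agnostic checkpoint interface. Each backend
    (HuggingFace, Megatron, raw state dicts, ...) would implement this,
    so the parameter-transformation code above never needs to know
    which checkpoint format the tensors came from.
    """

    @abstractmethod
    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Yield (fully qualified name, tensor) pairs from the checkpoint."""
        ...
```

Under this design, a HuggingFace-specific loader becomes just one implementation of the interface rather than the layer the transformation logic depends on.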