Inference Checkpoints #4620
Conversation
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Connor Holmes <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Create the inference parameter.
"""
param = InferenceParameter(core_param)
param._aux_attrs = kwargs
@cmikeh2, can you please clarify what the aux_attr is used for?
Are you maybe thinking about the scales that are required when adding quantization?
That's exactly right: scales for quantization, or any other metadata we create when transforming the parameter, can be stored as an auxiliary attribute. The declaration for something like that would be:
p = InferenceParameter.initialize(param, scales=scales)
assert torch.equal(p, param)
assert torch.equal(p.scales, scales)
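As a rough sketch (not the actual DeepSpeed implementation), the wrapper can be thought of as a torch.Tensor subclass that stashes those keyword arguments and re-exposes them as attributes; the class name and helper below are illustrative only:

import torch

class InferenceParameterSketch(torch.Tensor):
    # Illustrative tensor wrapper that carries auxiliary metadata such as scales.

    @classmethod
    def initialize(cls, core_param: torch.Tensor, **kwargs) -> "InferenceParameterSketch":
        param = core_param.as_subclass(cls)   # share storage, change only the Python type
        param._aux_attrs = kwargs             # keep the metadata alongside the tensor
        for name, value in kwargs.items():
            setattr(param, name, value)       # expose e.g. p.scales directly
        return param

weight = torch.randn(4, 4)
scales = torch.ones(4)
p = InferenceParameterSketch.initialize(weight, scales=scales)
assert torch.equal(p, weight)
assert torch.equal(p.scales, scales)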
def finalize(self) -> torch.Tensor:
    return self.params
    # return self.inference_model.transform_embed_param(self.params)
    print("EmbeddingParameter.finalize")
Do we want to remove the debugging code here?
""" | ||
dtype: str | ||
shape: Tuple[int, ...] | ||
strides: Tuple[int, ...] |
What is the stride used for here?
If we ever have some kind of non-contiguous parameter, this will make sure we return it correctly as-is rather than silently changing the underlying storage. It shouldn't be the common case, but it's low overhead to support and might prevent some weird bugs in the future.
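For a concrete picture of why the recorded strides matter, here is an illustrative example (the buffer and layout values below are made up): the metadata lets a non-contiguous parameter be rebuilt as a view over the flat checkpoint buffer instead of being silently made contiguous.

import torch

# One large flat allocation, standing in for the checkpoint buffer.
flat = torch.arange(12, dtype=torch.float32)

# Recorded metadata for a 3x4 parameter stored column-major, i.e. non-contiguous.
shape, strides = (3, 4), (1, 3)

# Rebuild a view with the recorded layout: no copy, no change of storage order.
param = torch.as_strided(flat, size=shape, stride=strides)
assert not param.is_contiguous()
assert param.shape == torch.Size(shape) and param.stride() == strides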
@cmikeh2, I see that you are flattening the parameters into one big allocated buffer, but I don't see where the saving of the data happens so that we can later load it! Is it as simple as ...
setattr(layer_container, p_name, None)
continue

dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype])
Can we not just pass the string to figure out the dtype in the alloc_fn, rather than creating a dummy tensor with it?
Strings would also work. The dummy tensor has no storage, and I'd really prefer to just pass the torch.dtype, but I couldn't figure out the correct argument for that. The whole binding around blob is kind of ugly, but when I tried using UntypedStorage, which is supposedly the correct way, there were phantom allocations I was never able to track down.
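To make the trade-off concrete, here is a hedged sketch of the pattern being discussed; STR_TO_DTYPE and alloc_fn below are illustrative stand-ins, not the exact DeepSpeed definitions:

import torch

# Illustrative string-to-dtype table.
STR_TO_DTYPE = {
    "torch.float16": torch.float16,
    "torch.bfloat16": torch.bfloat16,
    "torch.float32": torch.float32,
}

def alloc_fn(num_elements: int, like: torch.Tensor) -> torch.Tensor:
    # Hypothetical allocation hook: the dummy tensor only communicates the dtype.
    return torch.empty(num_elements, dtype=like.dtype)

dummy = torch.empty([], dtype=STR_TO_DTYPE["torch.float16"])  # 0-dim, carries just the dtype
buffer = alloc_fn(1024, dummy)
assert buffer.dtype == torch.float16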
The save is exposed in deepspeed/inference/v2/engine_v2.py (line 222 at commit 0faea35).
The load codepath is here:
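Since both of those links point into the engine itself, here is a hedged sketch of the overall flatten-then-save / load-then-reslice flow they implement; the file names (params.pt, metadata.pkl) and metadata keys below are invented for illustration and are not the engine's actual on-disk format.

import os
import pickle
import torch

def save_checkpoint(path, flat_buffer, metadata):
    # Persist the single flattened parameter buffer plus per-parameter layout metadata.
    os.makedirs(path, exist_ok=True)
    torch.save(flat_buffer, os.path.join(path, "params.pt"))
    with open(os.path.join(path, "metadata.pkl"), "wb") as f:
        pickle.dump(metadata, f)

def load_checkpoint(path):
    # Reload the flat buffer, then rebuild each parameter as a strided view into it.
    flat_buffer = torch.load(os.path.join(path, "params.pt"))
    with open(os.path.join(path, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)
    return {
        name: torch.as_strided(flat_buffer[m["offset"]:], m["shape"], m["strides"])
        for name, m in metadata.items()
    }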
policy = MistralPolicy(checkpoint_engine, model_config)
if os.path.exists(os.path.join(path, "ds_model_config.pkl")):

# Load metadata, for grabbing the policy name we'll have all ranks just check for
I think this part needs to be supported at a higher abstraction layer, as it is not related only to HF models and we should be able to use it with different checkpoint formats.
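One hedged way to read that suggestion (the class and method names below are illustrative, not existing DeepSpeed classes): hide the format behind a small checkpoint-engine interface so the policy only ever sees named tensors, regardless of whether they came from an HF directory or another checkpoint format.

from abc import ABC, abstractmethod
from typing import Iterable, Tuple
import torch

class CheckpointEngineSketch(ABC):
    # Illustrative interface: each on-disk format provides its own implementation.
    @abstractmethod
    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Yield (name, tensor) pairs regardless of the underlying checkpoint format."""

class HFCheckpointEngineSketch(CheckpointEngineSketch):
    def __init__(self, model_dir: str) -> None:
        self.model_dir = model_dir

    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        # Placeholder: read HF-format weights from self.model_dir and yield them.
        yield from ()

# A policy (e.g. MistralPolicy) would then accept any CheckpointEngineSketch, so
# restoring from a DeepSpeed-native snapshot or an HF directory looks the same.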
Merged 19b2587 into cholmes/checkpoints-inference-v2-2