Inference Checkpoints #4620

Merged

Conversation

@cmikeh2 (Contributor) commented Nov 4, 2023

No description provided.

Review comment on:

Create the inference parameter.
"""
param = InferenceParameter(core_param)
param._aux_attrs = kwargs
Contributor: @cmikeh2, can you please clarify what aux_attrs is used for?

Contributor: Are you maybe thinking about scales that are required when adding in quantization?

cmikeh2 (author): That's exactly right: scales for quantization, or any other metadata we create when transforming the parameter, can be stored as an auxiliary attribute. The declaration for something like that would be:

p = InferenceParameter.initialize(param, scales=scales)

assert torch.equal(p, param)
assert torch.equal(p.scales, scales)
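
For readers following along, a minimal sketch of how a torch.Tensor subclass can carry auxiliary attributes like this. The class name, the as_subclass call, and the __getattr__ fallback below are illustrative assumptions, not necessarily how this PR's InferenceParameter is implemented:

import torch

class InferenceParameterSketch(torch.Tensor):
    # Illustrative tensor subclass; not the PR's actual InferenceParameter.

    @classmethod
    def initialize(cls, core_param: torch.Tensor, **kwargs) -> "InferenceParameterSketch":
        # as_subclass re-wraps the same storage, so no copy is made.
        param = core_param.as_subclass(cls)
        param._aux_attrs = kwargs
        return param

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, e.g. for p.scales.
        aux = self.__dict__.get("_aux_attrs", {})
        if name in aux:
            return aux[name]
        raise AttributeError(name)

weight = torch.randn(8, 8)
scales = torch.ones(8)
p = InferenceParameterSketch.initialize(weight, scales=scales)
assert torch.equal(p, weight)
assert torch.equal(p.scales, scales)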

Review comment on:

def finalize(self) -> torch.Tensor:
    return self.params
    #return self.inference_model.transform_embed_param(self.params)
    print("EmbeddingParameter.finalize")
Contributor: Do we want to remove the debugging code here?

"""
dtype: str
shape: Tuple[int, ...]
strides: Tuple[int, ...]
Contributor: What is the stride used for here?

cmikeh2 (author): If we ever have some kind of non-contiguous parameter, this will make sure we return it correctly as-is rather than silently changing the underlying storage. It shouldn't be the common case, but it is low overhead to support and might remove some weird bugs in the future.
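
As a small aside (illustrative code, not from this PR), recording shape and strides is what lets a non-contiguous view be rebuilt exactly over the saved flat storage instead of being silently made contiguous:

import torch

base = torch.arange(12, dtype=torch.float32)              # the flat underlying storage
view = base.reshape(3, 4).t()                              # a non-contiguous parameter view
assert not view.is_contiguous()

shape, strides = tuple(view.shape), tuple(view.stride())   # metadata worth recording: (4, 3), (1, 4)

# Given only the flat storage plus the recorded metadata, the exact same view
# (same values, same memory layout) can be reconstructed:
restored = torch.as_strided(base, shape, strides)
assert torch.equal(restored, view)
assert restored.stride() == view.stride()

# Calling .contiguous() instead would keep the values but change the layout,
# which is the silent storage change being avoided here.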

@RezaYazdaniAminabadi (Contributor): @cmikeh2, I see that you are flattening the parameters and using one big allocated buffer for them, but I don't see where the data gets saved so that we can later load it. Is it as simple as calling torch.save on the injected model with the new inference engine and loading it back?

Review comment on:

setattr(layer_container, p_name, None)
continue

dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype])
Contributor: Can we not just pass the string to figure out the dtype in the alloc_fn rather than creating a dummy tensor with it?

cmikeh2 (author): Strings would also work. The dummy tensor has essentially no storage overhead, and I'd really prefer to just pass the torch.dtype, but I couldn't figure out the correct argument for that. The whole binding around the blob is kind of ugly, but when I tried using UntypedStorage, which is supposedly the correct way, there were phantom allocations I was never able to track down.
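
For reference, a hedged sketch of the two options under discussion. The dtype strings and the alloc_like helper below are illustrative assumptions, not the PR's actual STR_TO_DTYPE table or allocation API:

import torch

# Hypothetical string-to-dtype table; the real mapping in the PR may differ.
STR_TO_DTYPE = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "int8": torch.int8,
}

def alloc_like(dtype_str: str, numel: int) -> torch.Tensor:
    # Option A: pass the string and resolve the dtype inside the alloc_fn.
    return torch.empty(numel, dtype=STR_TO_DTYPE[dtype_str])

# Option B: a zero-dimensional placeholder tensor whose only job is to carry
# the dtype across an interface that expects a tensor argument.
dummy = torch.empty([], dtype=STR_TO_DTYPE["float16"])
buf = torch.empty(1024, dtype=dummy.dtype)

assert alloc_like("float16", 1024).dtype == buf.dtype == torch.float16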

@cmikeh2 (author) commented Nov 7, 2023

Replying to @RezaYazdaniAminabadi's question above about where the flattened parameters get saved and reloaded:

The save is exposed as a serialize method on the engine:

def serialize(self, save_path: str) -> None:

It's mostly just going to save a couple of configs and the large tensor, though.
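
Roughly speaking, a sketch of what a serialize of that shape could look like; this is not the PR's actual implementation, and only the ds_model_config.pkl filename is visible in the load path below, so the other filename here is made up:

import os
import pickle
import torch

def serialize_sketch(save_path: str, model_config, flat_param: torch.Tensor) -> None:
    os.makedirs(save_path, exist_ok=True)
    # A couple of small config files...
    with open(os.path.join(save_path, "ds_model_config.pkl"), "wb") as f:
        pickle.dump(model_config, f)
    # ...plus the single large flattened parameter tensor.
    torch.save(flat_param, os.path.join(save_path, "flat_params.pt"))  # hypothetical filename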

The load codepath starts here:

policy = MistralPolicy(checkpoint_engine, model_config)
if os.path.exists(os.path.join(path, "ds_model_config.pkl")):

Review comment on:

# Load metadata, for grabbing the policy name we'll have all ranks just check for
Contributor: I think this part needs to be supported at a higher abstraction layer, as it is not just related to HF models, and we should be able to use it with different checkpoint formats.
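
One possible shape for that higher-level abstraction, sketched only to illustrate the reviewer's suggestion; every name below is hypothetical and not taken from this PR. The idea is that format probing lives behind a small checkpoint-engine interface, so HF and non-HF checkpoints share one entry point:

import os
from abc import ABC, abstractmethod
from typing import Iterable, Tuple
import torch

class CheckpointEngineBase(ABC):
    # Hypothetical interface: yields (name, tensor) pairs regardless of format.
    @abstractmethod
    def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
        ...

class HuggingFaceCheckpointEngine(CheckpointEngineBase):
    def __init__(self, path: str) -> None:
        self.path = path

    def parameters(self):
        raise NotImplementedError("sketch only")

class DeepSpeedFlatCheckpointEngine(CheckpointEngineBase):
    def __init__(self, path: str) -> None:
        self.path = path

    def parameters(self):
        raise NotImplementedError("sketch only")

def select_checkpoint_engine(path: str) -> CheckpointEngineBase:
    # Format detection happens once, above the model policies, so adding a new
    # checkpoint format does not touch the HF-specific code path.
    if os.path.exists(os.path.join(path, "ds_model_config.pkl")):
        return DeepSpeedFlatCheckpointEngine(path)
    return HuggingFaceCheckpointEngine(path)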

@cmikeh2 changed the base branch from master to cholmes/checkpoints-inference-v2-2 on November 10, 2023 at 02:31
@cmikeh2 changed the base branch from cholmes/checkpoints-inference-v2-2 to master on November 10, 2023 at 02:32
@cmikeh2 changed the base branch from master to cholmes/checkpoints-inference-v2-2 on November 10, 2023 at 02:34
@cmikeh2 merged commit 19b2587 into cholmes/checkpoints-inference-v2-2 on November 10, 2023
15 of 16 checks passed
@cmikeh2 deleted the cholmes/checkpoints-inference-v2 branch on November 10, 2023 at 02:35
4 participants