
[speculator training] Speculator training #35

Closed · daviswer wants to merge 40 commits

Conversation

@daviswer (Collaborator) commented Mar 1, 2024

Add support for speculator training, piggybacking off the existing training utilities.

Training script and speculator-specific utilities are inside the new speculator subfolder.

Uses distributed setup, checkpointing, and dataloaders from this repo. Adds speculator-specific fields to the training config file (to be ignored during non-speculator training). It might make more sense to pull these new fields out into a separate config subclass under the speculator utilities; open to suggestions.
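For illustration, a minimal sketch of what that subclass option could look like, assuming the dataclass-style train_config in fms_fsdp/config/training.py; the field names below are hypothetical, not the ones this PR adds:

```python
# Hypothetical sketch only: speculator fields pulled into their own config
# subclass instead of living on the shared train_config. Field names here
# are illustrative, not the actual fields added by this PR.
from dataclasses import dataclass

from fms_fsdp.config import train_config


@dataclass
class speculator_train_config(train_config):
    n_speculator_heads: int = 3    # lookahead tokens the speculator predicts
    speculator_width: int = 4096   # hidden width of the speculator MLPs
    base_model_path: str = ""      # frozen base model to distill from
```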

Uses the speculator architecture from fms-extras.

Uses an altered Llama-7b and generate() function from base fms, allowing the speculator to access embedding vectors rather than just logits/token predictions. Do not merge this until that dependency on altered base-fms code is resolved.

@daviswer daviswer requested review from nairbv and lchu6 March 1, 2024 20:54
@daviswer daviswer marked this pull request as ready for review March 20, 2024 16:25
@daviswer (Collaborator, Author) commented:

The plan is to move the include_embeds=True versions of Llama/GPTBigCode/generate() into fms-extras. Once that is done, I'll update the relevant imports here and we can push this in.

@lchu6 lchu6 changed the title Speculator training [speculator training] Speculator training Mar 28, 2024
@daviswer daviswer requested a review from JRosenkranz March 29, 2024 18:16
@daviswer (Collaborator, Author) commented:

I've pulled all the include_embeds stuff out of fms and into this PR. We now have EmbedLLaMA and EmbedGPTBigCode subclasses that override the corresponding forward functions, plus an altered version of generate() for use only with this script. The subclassed models are registered for use with get_model in the training script.
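Schematically, the override works as in the sketch below; the fms-internal attribute names (base_model, head) are assumptions, not the real implementation:

```python
# Schematic sketch of the forward override described above. The real
# EmbedLLaMA lives in speculator/train_speculator_utils.py; the attribute
# names (base_model, head) are assumptions about fms internals.
from fms.models.llama import LLaMA


class EmbedLLaMA(LLaMA):
    def forward(self, x, include_embeds=False, **kwargs):
        embeds = self.base_model(x, **kwargs)  # final hidden states
        logits = self.head(embeds)             # project to vocab as usual
        if include_embeds:
            # expose the embedding vectors the speculator trains against
            return logits, embeds
        return logits
```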

@JRosenkranz (Collaborator) left a review comment:

some initial comments on this PR

Review threads (resolved):
- fms_fsdp/config/training.py (outdated)
- fms_fsdp/utils/dataloader_utils.py (outdated)
- fms_fsdp/utils/dataloader_utils.py
Quoted code under review:

    # Split line into input and target for the CLM task.
    data = Preprocess_Dataset(data, causal_lm)
    # Apply desired postprocessing steps in sequence
    data = Preprocess_Dataset(data, torch.IntTensor)
Collaborator:

is this just wrapping with IntTensor?

@daviswer (Collaborator, Author) replied:

Yes, this turns list outputs into torch tensors before applying any user-specified preprocess functions.
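For context, a minimal sketch of the wrapper pattern at work in the snippet above; the actual Preprocess_Dataset in fms_fsdp/utils may differ in detail:

```python
# Minimal sketch of the Preprocess_Dataset pattern: wrap an iterable dataset
# and apply a function to every example. The real implementation in
# fms_fsdp/utils may differ in detail.
from torch.utils.data import IterableDataset


class Preprocess_Dataset(IterableDataset):
    def __init__(self, dataset, fn):
        self.dataset = dataset
        self.fn = fn

    def __iter__(self):
        # fn=torch.IntTensor turns list outputs into tensors;
        # fn=causal_lm splits each line into input/target for the CLM task
        for example in self.dataset:
            yield self.fn(example)
```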

Review threads (resolved):
- speculator/train_speculator.py (2 threads)
- speculator/train_speculator_utils.py (3 threads)
@AlpinDale commented:

Hi! What's the status on this PR? I'd like to train a few speculator models, but I'm not sure how to get started, due to a lack of documentation...

@JRosenkranz (Collaborator) commented:

> Hi! What's the status on this PR? I'd like to train a few speculator models, but I'm not sure how to get started, due to a lack of documentation...

Hi @AlpinDale, we're working on getting the documentation and code ready for this. We plan to have it out sometime in the next 3 weeks, and will keep you updated if we get it done sooner.

@AlpinDale commented:

Thanks for the reply, @JRosenkranz

I'd love to wait, but I have access to a large cluster of H100s for a limited time, so I want to make the most of it by training as many MLPSpeculator models as possible on various popular models. If it's doable, I'd love some basic instructions on how to get this PR running and start training runs; I can figure out the rest. Different story if the PR itself isn't ready, however 😅

@vdabravolski commented:

Hi, adding a +1 to @AlpinDale. We are interested in experimenting with the MLP speculator, specifically on the latest Llama3.1 models.

Excellent work overall @JRosenkranz !

@sahilsuneja1 (Collaborator) commented:

Hi @AlpinDale @vdabravolski,

PR35 is outdated. We expect to release a stable code version in about 3 weeks.

We understand @AlpinDale's urgency and are trying to get this PR into shape so that you can use it in the interim. We hit issues running it against the main branches of foundation-model-stack and fms-extras, and are working on resolving them. If that doesn't work out, we can point you to the specific branches of the foundation-model-stack, fms-extras, and fms-fsdp repos in the meantime, so that you can train custom speculators while we polish these changes and merge them into their respective mains.

There are already a bunch of speculators available here and here, in case there is any overlap with your requirements. For example, the llama3-70b speculator works for llama3.1-70b as well, as mentioned here (and so the llama3-8b speculator might also work for llama3.1-8b).

@sahilsuneja1 (Collaborator) commented Aug 14, 2024:

@AlpinDale @vdabravolski
PR35 has been updated; it should now work with the foundation-model-stack and fms-extras main branches.
Added a sample training script containing example arguments to pass to the speculator training routine.
Most arg names should be straightforward. For more details, please refer to https://arxiv.org/pdf/2404.19124, https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/, and https://github.com/foundation-model-stack/fms-fsdp/blob/main/docs/configurations.md
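For orientation, a hypothetical sketch of the training step the script performs, following the paper linked above; all names, shapes, and the speculator call signature are illustrative, not the script's actual API:

```python
# Hypothetical sketch of the core speculator training step, following the
# description in https://arxiv.org/pdf/2404.19124. All names, shapes, and
# call signatures here are illustrative, not the script's actual API.
import torch
import torch.nn.functional as F


def speculator_step(base_model, speculator, tokens, n_predict=3):
    # tokens: (batch, seq) token ids
    with torch.no_grad():
        # the EmbedLLaMA-style base model exposes hidden states with logits
        _, embeds = base_model(tokens, include_embeds=True)
    # assumed speculator output shape: (n_predict, batch, seq, vocab)
    preds = speculator(embeds, tokens)
    loss = 0.0
    seq = tokens.size(1)
    for i in range(n_predict):
        # the base model already predicts token t+1, so speculator head i
        # at position t learns to predict tokens[t + i + 2]
        logits = preds[i][:, : seq - i - 2]
        target = tokens[:, i + 2 :]
        loss = loss + F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), target.reshape(-1)
        )
    return loss
```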

@philschmid commented:

Is this expected to be merged soon?

@JRosenkranz (Collaborator) commented:

> Is this expected to be merged soon?

@philschmid We are expecting to have speculator training merged sometime in the next 2 weeks.

@JRosenkranz (Collaborator) commented Sep 10, 2024:

@philschmid This has been finished and merged in #114. The speculator training implementation is now available in main. Please let us know if you have any feedback or questions.

CC: @AlpinDale @vdabravolski

@JRosenkranz (Collaborator) commented:

Closing in favor of #114
