[speculator training] Speculator training #35
Conversation
Plan is to move the …
I've pulled all the …
some initial comments on this PR
fms_fsdp/utils/dataloader_utils.py
Outdated
# Split line into input and target for the CLM task.
data = Preprocess_Dataset(data, causal_lm)
# Apply desired postprocessing steps in sequence
data = Preprocess_Dataset(data, torch.IntTensor)
is this just wrapping with IntTensor?
Yes - turn list outputs into torch tensors before applying any user-specified preprocess functions
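To illustrate the pattern being discussed, here is a hypothetical minimal sketch of a Preprocess_Dataset-style wrapper (the name and behavior are inferred from the snippet above, not taken from the repo's actual implementation): it lazily applies a function to every item of an underlying iterable dataset.

```python
# Hypothetical minimal sketch of a Preprocess_Dataset-style wrapper: it lazily
# applies a function to every item of an underlying iterable dataset.
class PreprocessDataset:
    def __init__(self, dataset, fn):
        self.dataset = dataset
        self.fn = fn

    def __iter__(self):
        # Apply the transform on the fly, one item at a time
        for item in self.dataset:
            yield self.fn(item)

# In the PR the second wrap uses torch.IntTensor; tuple stands in here so the
# sketch has no torch dependency.
data = [[1, 2, 3], [4, 5]]
data = PreprocessDataset(data, tuple)
print(list(data))  # [(1, 2, 3), (4, 5)]
```

Wrapping with torch.IntTensor in the same way converts each list of token ids into a tensor before any user-specified preprocess functions run, which is exactly the ordering described in the comment above.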
Hi! What's the status on this PR? I'd like to train a few speculator models, but I'm not sure how to get started, due to a lack of documentation...
Hi @AlpinDale Working on getting the documentation and code ready for this. Planning to have something ready in the next 3 weeks. Will keep you updated if we get this done sooner.
Thanks for the reply, @JRosenkranz I'd love to wait, but I have access to a large cluster of H100s for a limited time, so I wanted to make the most of it by training as many MLPSpeculator models as possible on various popular models. If it's doable, I'd love some basic instructions on how to get this PR running and start train runs; I can figure out the rest. Different story if the PR itself isn't ready, however 😅
Hi, adding a +1 to @AlpinDale. We are interested in experimenting with the MLP speculator, specifically on the latest Llama3.1 models. Excellent work overall @JRosenkranz!
PR35 is outdated. We expect to release a stable code version in about 3 weeks. We understand @AlpinDale's urgency and are trying to put this PR in shape so that you can use it in the interim. We hit issues running it against the main branches of foundation-model-stack and fms-extras, and are working on resolving them. If that doesn't work out, we can point you to the specific branches of the foundation-model-stack, fms-extras, and fms-fsdp repos in the meanwhile so that you can train custom speculators while we polish them and merge them into their respective mains. There are already a bunch of speculators available here and here, in case there is any overlap with your requirements. For example, the llama3-70b speculator works for llama3.1-70b as well, as mentioned here (and so llama3-8b might also work for llama3.1-8b).
@AlpinDale @vdabravolski
Is this expected to be merged soon?
@philschmid We are expecting to have speculator training merged sometime in the next 2 weeks.
@philschmid This has been finished and merged in #114. The speculator training implementation is now available in main. Please let us know if you have any feedback or questions.
Closing in favor of #114 |
Add support for speculator training, piggybacking off the existing training utilities.
Training script and speculator-specific utilities live inside the new speculator subfolder. Uses distributed setup, checkpointing, and dataloaders from this repo. Adds speculator-specific fields to the training config file (to be ignored during non-speculator training). It might make more sense to pull these new fields out into a separate config subclass under the speculator utilities - open to suggestions.
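The "separate config subclass" idea mentioned above could look roughly like the following hedged sketch. The field names (n_speculator_heads, speculator_width) are purely illustrative placeholders, not the PR's actual config keys:

```python
# Hypothetical sketch of layering speculator-specific fields on top of an
# existing training config via a subclass; field names are illustrative only.
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Base fields shared by all training runs
    learning_rate: float = 3e-4
    batch_size: int = 8

@dataclass
class SpeculatorTrainConfig(TrainConfig):
    # Speculator-only fields; ignored entirely during non-speculator training
    n_speculator_heads: int = 3
    speculator_width: int = 4096

cfg = SpeculatorTrainConfig()
print(cfg.learning_rate, cfg.n_speculator_heads)
```

A subclass like this keeps the base config untouched for ordinary training while still letting the speculator script consume a single config object.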
Uses speculator architecture from fms-extras.
Uses an altered Llama-7b and generate() function from base fms, allowing the speculator to access embedding vectors, not just logits/token predictions. Do not merge this until that issue can be resolved.
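To show why generate() needed altering, here is a hypothetical toy sketch (not the fms implementation): a speculator conditions on the base model's last hidden state as well as the sampled token, so generation must optionally return both. The toy_model below is a made-up stand-in for illustration.

```python
# Hypothetical sketch of a generate() that exposes hidden states (embeddings)
# alongside sampled tokens, since a speculator consumes both.
def generate(model, prompt, steps=3, include_embeds=False):
    tokens = list(prompt)
    embeds = []
    for _ in range(steps):
        hidden, logits = model(tokens)          # hidden state + next-token logits
        next_tok = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
        tokens.append(next_tok)
        embeds.append(hidden)                   # kept for the speculator's input
    return (tokens, embeds) if include_embeds else tokens

# Toy stand-in model: "hidden state" is the token sum, logits favour (sum % 4)
def toy_model(tokens):
    s = sum(tokens)
    logits = [1.0 if i == s % 4 else 0.0 for i in range(4)]
    return s, logits

toks, embeds = generate(toy_model, [1, 2], steps=2, include_embeds=True)
print(toks, embeds)  # [1, 2, 3, 2] [3, 6]
```

The point of the flag is that a plain generate() discards hidden states after sampling; the speculator-aware variant keeps them so the MLP speculator can predict several tokens ahead from richer inputs than logits alone.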