Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add 2D Parallelism (FSDP + Tensor Parallel) LoRA #2204

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthew-hippocratic
Copy link

Context

What is the purpose of this PR?

  • add a new feature - support for 2D parallelism using FSDP+TP with LoRA
  • other (Request help from the community to productionize this minimal working example. This code works with single and multinode setups and I have provided the run commands with configs below. However, it is hacky is isn't intended to merge at this point. I thought this feature would be useful and I'm running out of time to work on it so I wanted to see if others would like to help bring it to the finish line! If interested, see Limitations below)

Changelog

What are the changes made in this PR?

  • Add support for 2D parallelism (FSDP+TP) of LoRA Llama models in lora_finetune_distributed.py
    • Add 2D device_mesh that shards along "dp" (FSDP) and "tp" dimensions
    • Utilize DTensor like in torchtitan for easy-to-understand device_mesh sharding
    • Modify the DistributedSampler to shard data along "dp" dimension
  • Modify LoRALinear to remap naming so that the Tensor Parallel layer_plan can target the main model weights
    • This involves renaming when training, then renaming back to original names for saving the checkpoint(s)

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • update docstrings for any new or updated methods or classes
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands

Successful run commands:
Llama-3.1-8B-Instruct on a single 8xH100 Node (FSDP=1, TP=8)

tune run --rdzv-endpoint <node_ip> --nnodes 1 --nproc-per-node 8 --rdzv-id <slurm_job_id> --rdzv-backend=c10d \
    lora_finetune_distributed --config recipes/configs/llama3_1/8B_lora.yaml \
    tokenizer.path=<path_to_/meta-llama/Llama-3.1-8B-Instruct/original/tokenizer.model> \
    checkpointer.checkpoint_dir=<path_to_/meta-llama/Llama-3.1-8B-Instruct> \
    compile=False dataset.packed=True batch_size=2 tokenizer.max_seq_len=8192 gradient_accumulation_steps=8 \
    enable_activation_checkpointing=False enable_activation_offloading=False optimizer.lr=0.0001 model.lora_rank=8 model.lora_alpha=16 \
    max_steps_per_epoch=10

Llama-3.3-70B-Instruct on two 8xH100 Nodes (FSDP=2, TP=8)

tune run --rdzv-endpoint <master_node_ip> --nnodes 2 --nproc-per-node 8 --rdzv-id <slurm_job_id> --rdzv-backend=c10d \
    lora_finetune_distributed --config recipes/configs/llama3_3/70B_lora.yaml \
    tokenizer.path=<path_to_/meta-llama/Llama-3.3-70B-Instruct/original/tokenizer.model> \
    checkpointer.checkpoint_dir=<path_to_/meta-llama/Llama-3.3-70B-Instruct> \
    compile=False dataset.packed=True batch_size=1 tokenizer.max_seq_len=256 gradient_accumulation_steps=8 \
    enable_activation_checkpointing=False enable_activation_offloading=False optimizer.lr=0.0001 model.lora_rank=16 model.lora_alpha=32 \
    max_steps_per_epoch=10

Limitations

Although the above run commands work (loss goes down at reasonable rate for low batch size and/or gradient accumulation), I found that many features did not quite work with this 2D parallelism yet. And there are other concerns with the code as it stands:

  • Activation Checkpointing/Offloading - Biggest limitation since this enables much larger batches
    • The current code wraps the model with activation checkpointing before sharding and this fails since the checkpoints are fully-sized, but the model will shard them when computing
    • I tried moving the activation checkpointing after sharding and it appeared to work, but I noticed GPU Memory kept increasing over epochs leading to late OOMs. And for short training runs that didn't OOM I was getting weird errors afterwards:
      [W1223 03:00:54.225759032 Functional.cpp:46] Warning: At the time of process termination, there are still 2528 unwaited c10d_functional collective calls. Please review your program to ensure c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective ops before they are used. (function ~WorkRegistry)
  • Compile
    • As mentioned here, DTensor has issues with torch compile so set compile=False. Ideally this will get fixed since compile is a pure win.
  • Sequence Length
    • The Sequence Length must be dividible by TP since we are sharding with Sequence Parallel
    • To ensure this, use dataset.packed=True combined with tokenizer.max_seq_len
  • LoRA-related (can be removed if we choose not to shard the LoRA weights with TP since they are relatively small)
    • The LoRA Rank must be divisible by TP dimension and at least as large as sharding level. For example, with FSDP=2, TP=8 the LoRA Rank must be >= 16(8x2)
    • The TP dimension cannot exceed the number of KV Heads (8 for Llama models)

UX

  • I did not change any public API

Acknowledgments

Thanks to @akashc1 for helping make this work!!

Copy link

pytorch-bot bot commented Dec 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2204

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link

Hi @matthew-hippocratic!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot
Copy link

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 24, 2024
@joecummings
Copy link
Contributor

This is really cool and awesome addition to our library!! As discussed in a few of our issues (#2018, etc), this is something we are very eager to add soon.

I'll take a look over your implementation and the outstanding concerns and lyk any high-level comments.

Obviously, most people on the team are off for the holidays so we'll look to get this land-able, but will likely be sometime early Jan. Last point, do you mind if I help work on this, too? Your contributions will 100% not be lost in anything that gets landed :)

@matthew-hippocratic
Copy link
Author

This is really cool and awesome addition to our library!! As discussed in a few of our issues (#2018, etc), this is something we are very eager to add soon.

I'll take a look over your implementation and the outstanding concerns and lyk any high-level comments.

Obviously, most people on the team are off for the holidays so we'll look to get this land-able, but will likely be sometime early Jan. Last point, do you mind if I help work on this, too? Your contributions will 100% not be lost in anything that gets landed :)

Awesome, sounds great :)
And yes I would love your help!

@matthew-hippocratic
Copy link
Author

Hey @joecummings any updates from y'all's side?

@RdoubleA
Copy link
Contributor

RdoubleA commented Jan 17, 2025

Hi @matthew-hippocratic, we have another PR from a maintainer that will land some core TP utilities that should make it easy it enable TP for any torchtune model when training: #2245. We're hoping to land this soon, would you be able to rebase this PR once that lands and use the utilities?

cc @joecummings @acisseJZhong

@joecummings
Copy link
Contributor

Hey @joecummings any updates from y'all's side?

Hey @matthew-hippocratic thanks for following up! Here's the current gameplan:

@matthew-hippocratic
Copy link
Author

Awesome, sounds great! I'm very excited and will be happy to help more so let me know if you need any!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants