Does torchtune support multi-node training? #2018
Comments
Hi @tginart, if you just want to get a very, very basic multi-node setup running, it actually shouldn't be too hard. We wrap around torchrun, which supports multi-node. Just curious - what's your use case here? Large models, faster training, more data? Do you already have access to a multi-node setup running w/ SLURM? Or are you considering one?
Hi @joecummings Current use-case is just faster fine-tuning over larger datasets. I do have access to multi-node SLURM already, and have trained using other frameworks. For various reasons I've used torchtune recently for some small models on a single node and was just wondering if it has multi-node support. What file should I take a look at?
Ah in that case, I'd recommend two things:
```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


def shard_model(...):
    ...
    # 2 nodes x 8 GPUs: data parallel across nodes, tensor parallel within a node
    mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

    # Parallelize input and output first
    parallelize_module(
        model,
        mesh_2d["tp"],
        {
            "token_embeddings": PARALLEL_STRATEGY,
            "output": PARALLEL_STRATEGY,
        },
    )

    # Iterate over the transformer layers and parallelize each one
    for layer in model.layers:
        layer_plan = {
            "attention.wq": PARALLEL_STRATEGY,
            "attention.wk": PARALLEL_STRATEGY,
            ...
        }
        parallelize_module(layer, mesh_2d["tp"], layer_plan)

    # Shard the model normally, but using the mesh_2d["dp"] object
```
Torchtitan has a great example of this FSDP + TP work here.
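For the "shard the model normally" step, a minimal sketch of what applying FSDP2's fully_shard over the mesh_2d["dp"] sub-mesh could look like (the fully_shard import path and the model.layers attribute are assumptions here, not torchtune's actual recipe code):

```python
from torch.distributed.fsdp import fully_shard  # lives under torch.distributed._composable.fsdp in older releases


def shard_model_dp(model, mesh_2d):
    # Apply FSDP only over the "dp" sub-mesh so it composes with the TP plan above.
    # Sharding each block first keeps all-gathers scoped to one layer at a time.
    for layer in model.layers:  # assumes the transformer blocks live in model.layers
        fully_shard(layer, mesh=mesh_2d["dp"])
    fully_shard(model, mesh=mesh_2d["dp"])
    return model
```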
I am using LSF to launch a multi-node job.
Can you take a look at how TorchTitan launches multi-node training? I think the key part missing from the above pseudo-code is that you need to specify a rendezvous backend. You can read more about that here.
Is this something that you guys are thinking of fixing/implementing?
I think this is what you suggested as a first try, but still:
modified shard_model
In the end I managed to get it to run with mpirun, but I get the following error, which makes sense:
@fabiogeraci Ahh! I see in FSDP2, they remap HYBRID_SHARD to a 2D device mesh passed to fully_shard.
Yep - there's a great guide here. HYBRID_SHARD is the same as sharding within a node and replicating across nodes. Both torchtune and torchtitan use FSDP2.
How would I switch from a 1D mesh to a 2D mesh?
Something like the sketch below will create a 2D mesh with 2 nodes and 8 GPUs per node.
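This is a sketch assuming FSDP2's fully_shard; the mesh dimension names, the import path, and the model.layers attribute are illustrative, not torchtune's shipped code:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # lives under torch.distributed._composable.fsdp in older releases


def shard_model_hybrid(model, num_nodes: int = 2, gpus_per_node: int = 8):
    # Replicate across nodes (dim 0) and shard within a node (dim 1).
    mesh_2d = init_device_mesh(
        "cuda",
        mesh_shape=(num_nodes, gpus_per_node),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )
    # Passing a 2D mesh to fully_shard gives HYBRID_SHARD-style behavior:
    # parameters are all-gathered and gradients reduce-scattered within a node,
    # then gradients are all-reduced across nodes.
    for layer in model.layers:  # assumes the transformer blocks live in model.layers
        fully_shard(layer, mesh=mesh_2d)
    fully_shard(model, mesh=mesh_2d)
    return model
```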
Hi Joe @joecummings, I was checking the torchtitan configs, so I'm trying to understand the bottleneck of FSDP under multi-node tuning. Does the all-gather happen at every layer across all workers, so that it may be bounded by communication? Or is it a mistake that the torchtitan config uses only FSDP? Thanks!
I implemented your suggestion - massive improvement in speed on a multi-node, multi-GPU setup via OpenMPI. I had to tweak shard_model. May I ask why torchtune does not support multi-node, multi-GPU out of the box?
Great questions! I'd really recommend reading through this issue on PyTorch where @awgu discusses why multi-node FSDP is usually slower than single node. The TL;DR is that communication can take longer between different nodes and since FSDP needs all-gather for parameters and reduce-scatter for gradient reduction, this can be a bottleneck in training. It's probably not a mistake that torchtitan uses only FSDP for that specific config b/c it's the smallest model and my guess is that they're just trying to show that it's possible. If you look at the torchtitan configs for their larger models, you'll see that they use TP. If you have good interconnect speed between nodes, FSDP will work faster. Regardless, TP + a "hybrid shard" will likely be faster than FSDP in a multi-node setup b/c it's not as communication bound. Hope this explanation helps!
That's awesome! Would love to take a peek at your code if you want to post a gist so other users can see how you did it. We actually do plan to support multi-node OOTB sometime soon. The biggest reason we haven't so far is just due to our own bandwidth constraints. We're a small-ish team and there's lots of new models and techniques coming out every day! We wanted to be sure that we provided a great single-node experience before tackling multi-node, but like I mentioned, we'll probably have a canonical example in torchtune soon.
openmpi script, launch CLI
full_finetune_distributed.py
_distributed.py
Would I be able to make a PR with this code? ;)
Does torchtune support multi-node training? For example, in a SLURM environment?
If so, would it be possible to get an example config?