
Does torchtune support multi-node training? #2018

Open
tginart opened this issue Nov 18, 2024 · 19 comments

tginart commented Nov 18, 2024

Does torchtune support multi-node training? For example, in a SLURM environment?

If so, would it be possible to get an example config?

joecummings (Contributor) commented Nov 18, 2024

Hi @tginart, if you just want to get a very, very basic multi-node setup running, it actually shouldn't be too hard. We wrap around torchrun which supports multi-node.

Just curious - what's your use case here? Large models, faster training, more data? Do you already have access to a multi-node setup running w/ SLURM? Or are you considering one?

tginart (Author) commented Nov 18, 2024

Hi @joecummings

Current use-case is just faster fine-tuning over larger datasets.

I do have access to multi-node SLURM already, and have trained using other frameworks. For various reasons I've used torchtune recently for some small models on a single node, and was just wondering if it supports multi-node.

What file should I take a look at?

joecummings (Contributor) commented Nov 19, 2024

Ah in that case, I'd recommend two things:

  1. Instead of utilizing tune run, which we artificially constrain to 1 node for now, tune cp the recipe and config you want and then launch directly with torchrun, e.g. torchrun --nnodes 2 --nproc-per-node 8 full_finetune_distributed.py --config llama3_2/3B_full.yaml. This alone should enable you to run fine-tuning over larger datasets w/ FSDP. (Just make sure you modify the sharding strategy to be HYBRID. You can do that by modifying the fsdp_kwargs here to include an item for "sharding_strategy": ShardingStrategy.HYBRID_SHARD.)
  2. Add a very basic tensor parallel configuration. Right now, we just use FSDP for distributed training, which will likely be very slow on multi-node b/c it will all-gather everything needed for backprop. Tensor parallel should actually achieve the speedup you need. TP is not quite as basic as step number 1. For simplicity's sake, I'd recommend modifying our shard_model code to do something like the following pseudocode:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
# PARALLEL_STRATEGY stands in for a real plan, e.g. ColwiseParallel()/RowwiseParallel()
# from torch.distributed.tensor.parallel; module names below are illustrative.

def shard_model(...):
    ...
    # 2 nodes x 8 GPUs -> 2D mesh with a data-parallel and a tensor-parallel dim
    mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

    # Parallelize input and output first
    parallelize_module(
        model,
        mesh_2d["tp"],
        {
            "token_embeddings": PARALLEL_STRATEGY,
            "output": PARALLEL_STRATEGY,
        },
    )

    # Iterate over the transformer layers and parallelize each one
    for layer in model.layers:
        layer_plan = {
            "attention.wq": PARALLEL_STRATEGY,
            "attention.wk": PARALLEL_STRATEGY,
            # ...
        }
        parallelize_module(layer, mesh_2d["tp"], layer_plan)

    # Shard the model normally (FSDP) but using the mesh_2d["dp"] object

Torchtitan has a great example of this FSDP + TP work here.

fabiogeraci commented:

I am using LSF to launch torchrun --nnodes 2 --nproc-per-node 8 full_finetune_distributed.py. Full error trace attached:
full_trace.txt
8B_full_distributed.txt

joecummings (Contributor) commented:

> I am using LSF to launch torchrun --nnodes 2 --nproc-per-node 8 full_finetune_distributed.py. Full error trace attached: full_trace.txt 8B_full_distributed.txt

Can you take a look at how TorchTitan launches multi-node training? The key part that I think was missing in your launch above is that you need to specify a rendezvous backend. You can read more about that here

fabiogeraci commented:

Is this something that you guys are thinking of fixing/implementing?

fabiogeraci commented:

I think this is what you suggested as a first try, but it still fails. Full trace attached:
full_trace.txt

torchrun  \
    --nproc_per_node \$GPU_PER_HOST \
    --nnodes \$NUM_HOSTS \
    --rdzv-backend c10d \
    --rdzv_endpoint \$MASTER_ADDR:\$MASTER_PORT \
    src/full_finetune_distributed.py --config \
    config_files/8B_full_distributed.yaml \
    optimizer_in_bwd=False

training.shard_model(
                model=model,
                shard_conditions=fsdp_shard_conditions,
                cpu_offload=fsdp_cpu_offload,
                reshard_after_forward=reshard_after_forward,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )

modified shard_model

def shard_model(
    model: TransformerDecoder,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
) -> None:
    fsdp_kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "sharding_strategy": sharding_strategy,
    }
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

    num_layers_sharded = 0
    for n, m in reversed(list(model.named_modules())):
        if any([shard_condition(n, m) for shard_condition in shard_conditions]):
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1

    if num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )

    fully_shard(model, **fsdp_kwargs)

fabiogeraci commented Dec 5, 2024

In the end I managed to get it to run with mpirun, but I get the following error:

[rank3]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-31tjdLhK-py3.11/lib/python3.11/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank3]:     updated = func(inp_module, *args, **kwargs)
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: TypeError: fully_shard() got an unexpected keyword argument 'sharding_strategy'

which makes sense, given the FSDP2 fully_shard signature:

def fully_shard(
    module: Union[nn.Module, List[nn.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
):

joecummings (Contributor) commented:

@fabiogeraci Ahh! I see in FSDP2, they remap HYBRID_SHARD to reshard_after_forward=True. Give that a go :)
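
To spell that out, here is a minimal sketch of the adjustment (the helper name is made up, and the import path assumes the same torch.distributed._composable.fsdp module that shows up in the traceback above): FSDP2's fully_shard takes no sharding_strategy kwarg, so the kwargs dict only carries reshard_after_forward and, optionally, an offload policy.

from torch import nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

# Hypothetical helper: shard one module with FSDP2, with no sharding_strategy kwarg.
def shard_module_fsdp2(
    m: nn.Module, *, reshard_after_forward: bool = True, cpu_offload: bool = False
) -> None:
    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
    fully_shard(m, **fsdp_kwargs)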

fabiogeraci commented Dec 5, 2024

Would you mind explaining, please?

I can see torchtitan/docs/fsdp.md, but how do I know which FSDP version is used?

joecummings (Contributor) commented:

Yep - there's a great guide here.

HYBRID_SHARD is the same as reshard_after_forward=True, which "determines whether parameters are resharded (freed) after forward. If True, then they are re-all-gathered in backward. This trades off saving memory at the cost of extra communication."

All of torchtune and torchtitan use FSDP2.

fabiogeraci commented Dec 5, 2024

How would I switch from a 1D mesh to a 2D mesh?

joecummings (Contributor) commented:

mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

This will create a 2D mesh with 2 nodes and 8 GPUs per node.
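
For reference, a rough sketch (not torchtune's actual code; the helper and the module names in the TP plan are illustrative and assume a model with a layers ModuleList, like torchtune's TransformerDecoder) of how the two sub-meshes get used together: tensor parallelism is applied per layer over the intra-node "tp" dim, and FSDP then shards over the cross-node "dp" dim only.

from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def shard_model_2d(model: nn.Module, num_nodes: int = 2, gpus_per_node: int = 8) -> None:
    # "dp" spans the nodes, "tp" spans the GPUs inside each node.
    mesh_2d = init_device_mesh(
        "cuda", mesh_shape=(num_nodes, gpus_per_node), mesh_dim_names=("dp", "tp")
    )

    # 1) Tensor-parallelize each transformer layer over the "tp" sub-mesh.
    #    Module names are illustrative; match them to your model's attention block.
    for layer in model.layers:
        parallelize_module(
            layer,
            mesh_2d["tp"],
            {
                "attn.q_proj": ColwiseParallel(),
                "attn.k_proj": ColwiseParallel(),
                "attn.v_proj": ColwiseParallel(),
                "attn.output_proj": RowwiseParallel(),
            },
        )

    # 2) Shard with FSDP2 over the "dp" sub-mesh only: each layer, then the root.
    for layer in model.layers:
        fully_shard(layer, mesh=mesh_2d["dp"])
    fully_shard(model, mesh=mesh_2d["dp"])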

liyucheng09 commented:

> 2. Add a very basic tensor parallel configuration. Right now, we just use FSDP for distributed training, which will likely be very slow on multi-node b/c it will all-gather everything needed for backprop. Tensor parallel should actually achieve the speedup you need. TP is not quite as basic as step number 1. For simplicity's sake, I'd recommend modifying our shard_model code to do something like the following pseudocode:

Hi Joe @joecummings, I was checking torchtitan for multi-node practice as you suggested. I found they seem to use a pure FSDP approach in a 64-GPU setting: here.

So I'm trying to understand the bottleneck of FSDP in multi-node tuning. Does the all-gather happen at every layer across all workers, so that training becomes communication-bound? Is it a mistake that torchtitan uses pure FSDP with 64 GPUs? I also wonder whether TP would help in this case for small models like Llama 7B.

Thanks!

fabiogeraci commented:

> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
>
> This will create a 2D mesh with 2 nodes and 8 GPUs per node.

I implemented your suggestion and got a massive improvement in speed on a multi-node, multi-GPU setup via OpenMPI. I had to tweak shard_model.

May I ask why torchtune does not support multi-node, multi-GPU out of the box?

joecummings (Contributor) commented:

> Hi Joe @joecummings, I was checking torchtitan for multi-node practice as you suggested. I found they seem to use a pure FSDP approach in a 64-GPU setting: here.
>
> So I'm trying to understand the bottleneck of FSDP in multi-node tuning. Does the all-gather happen at every layer across all workers, so that training becomes communication-bound? Is it a mistake that torchtitan uses pure FSDP with 64 GPUs? I also wonder whether TP would help in this case for small models like Llama 7B.
>
> Thanks!

Great questions! I'd really recommend reading through this issue on PyTorch where @awgu discusses why multi-node FSDP is usually slower than single node. The TL;DR is that communication can take longer between different nodes and since FSDP needs all-gather for parameters and reduce-scatter for gradient reduction, this can be a bottleneck in training.

It's probably not a mistake that torchtitan uses only FSDP for that specific config b/c it's the smallest model and my guess is that they're just trying to show that it's possible. If you look at the torchtitan configs for their larger models, you'll see that they use TP.

If you have good interconnect speed between nodes, FSDP will work faster. Regardless, TP + a "hybrid shard" will likely be faster than FSDP in a multi-node setup b/c it's not as communication bound.

Hope this explanation helps!
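
In case it helps to see the "hybrid shard" part concretely, here is a minimal sketch (not torchtune code; the helper is hypothetical and assumes a model with a layers ModuleList, like torchtune's TransformerDecoder): with FSDP2 you pass fully_shard a 2D mesh, which shards parameters within a node (dim 1) and replicates them across nodes (dim 0), so the per-layer all-gathers stay on the faster intra-node links.

from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

def hybrid_shard(model: nn.Module, num_nodes: int, gpus_per_node: int) -> None:
    # dim 0 ("replicate"): across nodes; dim 1 ("shard"): within a node.
    hsdp_mesh = init_device_mesh(
        "cuda", mesh_shape=(num_nodes, gpus_per_node), mesh_dim_names=("replicate", "shard")
    )
    # Shard each transformer block first, then the root module.
    for layer in model.layers:
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=True)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=True)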

joecummings (Contributor) commented:

> I implemented your suggestion and got a massive improvement in speed on a multi-node, multi-GPU setup via OpenMPI. I had to tweak shard_model.
>
> May I ask why torchtune does not support multi-node, multi-GPU out of the box?

That's awesome! Would love to take a peek at your code if you want to post a gist so other users can see how you did it.

We actually do plan to support multi-node OOTB sometime soon. The biggest reason we haven't so far is just our own bandwidth constraints. We're a small-ish team and there are lots of new models and techniques coming out every day! We wanted to be sure that we provided a great single-node experience before tackling multi-node, but like I mentioned, we'll probably have a canonical example in torchtune soon.

fabiogeraci commented Dec 9, 2024

OpenMPI script, launch CLI:

mpirun \
    -np $TOTAL_NUM_GPUS \
    -H \$MPI_HOST_STRING \
    -x PATH \
    -bind-to none \
    -map-by slot \
    --mca pml ob1 --mca btl ^openib \
    --display-allocation \
    --display-map \
    python3 src/full_finetune_distributed.py \
    --config config_files/8B_full_distributed.yaml \
    optimizer_in_bwd=False

full_finetune_distributed.py

if int(os.environ.get("NUM_NODES")) > 1:
    from torch.distributed._tensor import init_device_mesh
    # WORLD_SIZE // 2 == GPUs per node here because this job uses 2 nodes;
    # more generally this would be WORLD_SIZE // NUM_NODES.
    mesh_2d = init_device_mesh(
        "cuda",
        mesh_shape=(int(os.environ.get("NUM_NODES")), int(os.environ["WORLD_SIZE"]) // 2),
        mesh_dim_names=("dp", "tp"),
    )
else:
    mesh_2d = None

training.shard_model(
    model=model,
    shard_conditions=fsdp_shard_conditions,
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
    mesh=mesh_2d,
)

_distributed.py

def shard_model(
    model: TransformerDecoder,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mesh: Optional[DeviceMesh] = None,  # <-- Add this line
) -> None:
    fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
    if mesh is not None:  # <-- Add this line
        fsdp_kwargs["mesh"] = mesh  # <-- Add this line
    # ... rest of the function unchanged

fabiogeraci commented:

Would I be able to make a PR with this code? ;)

@joecummings joecummings added discussion Start a discussion and removed question labels Dec 10, 2024
@joecummings joecummings self-assigned this Dec 10, 2024
@joecummings joecummings added the distributed Anything related to distributed env (multi-GPU, multi-node) label Dec 10, 2024