enable Context Parallel #592

Merged
merged 25 commits on Oct 23, 2024

Changes from 23 commits

Commits (25)
3bf7333
enable Context Parallel
XilunWu Sep 30, 2024
afb1051
Update on "enable Context Parallel"
XilunWu Sep 30, 2024
f99a6f5
Update on "enable Context Parallel"
XilunWu Oct 3, 2024
4ad6881
Update on "enable Context Parallel"
XilunWu Oct 3, 2024
038b5ce
Update on "enable Context Parallel"
XilunWu Oct 4, 2024
4758df2
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
a6758dd
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
f570fa8
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
c102f73
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
83230fd
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
2863907
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
534ce58
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
0c355e6
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
172717d
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
b89e59b
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
a5e453f
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
e319ab9
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
99fe0bc
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
9bec02c
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
47c0078
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
15c00d5
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
bba36b4
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
346d721
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
a5d1fdf
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
8045cad
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
5 changes: 3 additions & 2 deletions estimation.py
@@ -14,14 +14,14 @@
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import get_train_context


def estimate_memory(job_config: JobConfig):
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
@@ -94,7 +95,7 @@ def estimate_memory(job_config: JobConfig):
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

train_context = get_train_context(
train_context = utils.get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
)
35 changes: 35 additions & 0 deletions test_runner.py
@@ -306,6 +306,41 @@ def build_test_list():
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"FSDP+CP",
"fsdp+cp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"HSDP+CP",
"hsdp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.tensor_parallel_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"FSDP+TP+CP",
"fsdp+tp+cp",
ngpu=8,
),
Contributor
Question: Looks like FSDP/HSDP + TP + CP is working. How about PP? We can also mention progress in the .md doc later.

Contributor Author
Right, the next step is to test 4D/5D (w/ PP and HSDP)

OverrideDefinitions(
[
[
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -325,6 +325,12 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.context_parallel_degree",
type=int,
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
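
For reference, a minimal sketch (not part of the PR) of how the new flag can be exercised through JobConfig; the parse_args call is an assumption based on the argparse-style setup above, while the attribute path matches the usage in estimation.py.

from torchtitan.config_manager import JobConfig

# Hypothetical usage: 2-way FSDP + 2-way Context Parallel on 4 GPUs.
# The new flag defaults to 1, i.e. Context Parallel disabled.
job_config = JobConfig()
job_config.parse_args(
    [
        "--training.data_parallel_shard_degree=2",
        "--experimental.context_parallel_degree=2",
    ]
)
assert job_config.experimental.context_parallel_degree == 2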
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -415,8 +415,9 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
# TODO: explain in docs/composability.md why we removed the 2x
# relaxing in our CP enablement PR
self.model_args.max_seq_len,
Contributor
cc @tianyu-l, want to understand: is this okay?

For a more general use case, we can also extend CP to support a stride-like feature.

Contributor
Could you please elaborate a bit on why this change was needed by CP?

Contributor
@tianyu-l CP parallelizes over the sequence dimension, so anything tied to the sequence dimension needs to be sharded. freqs_cis is the positional embedding and must be sharded according to the sequence length, so it is easier to support CP if everything has the same sequence length.

Contributor
Sounds reasonable to me. @awgu to confirm this is OK.

Also we need to add a note in docs/composability.md to clarify why this (model change) is needed. It can be addressed in a separate PR; in that case please create issue / leave TODO.

self.model_args.rope_theta,
)

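
To make the reviewer's point concrete, here is a small illustrative sketch (not from the PR; the tensors are stand-ins) of why every sequence-dimension buffer, freqs_cis included, must have the same length once CP shards that dimension:

import torch

# Stand-ins for the real buffers; CP shards all of them on the sequence dim.
seq_len, cp_degree = 8, 2
tokens = torch.arange(seq_len)        # input ids, sharded on dim 0 here
freqs_cis = torch.arange(seq_len)     # rope table, must align position-for-position

token_shards = tokens.chunk(cp_degree)      # what each CP rank sees
freqs_shards = freqs_cis.chunk(cp_degree)   # must line up with the token shards

# If freqs_cis were precomputed to 2 * seq_len (the old behavior), its shards
# would no longer correspond to the token shards held by each CP rank.
assert all(t.numel() == f.numel() for t, f in zip(token_shards, freqs_shards))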
32 changes: 23 additions & 9 deletions torchtitan/parallelisms/parallel_dims.py
@@ -15,6 +15,7 @@
class ParallelDims:
dp_replicate: int
dp_shard: int
cp: int
tp: int
pp: int
world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
for d in (dp_replicate, cp, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
dp = self.world_size // (cp * tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
if d > 1:
dims.append(d)
Expand All @@ -71,6 +74,13 @@ def build_mesh(self, device_type):
# initialized
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")

if self.cp > 1:
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
else:
mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")

return mesh

@property
Expand All @@ -85,6 +95,10 @@ def dp_replicate_enabled(self):
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def cp_enabled(self):
return self.cp > 1

@property
def tp_enabled(self):
return self.tp > 1
@@ -98,5 +112,5 @@ def loss_parallel_enabled(self):
return self.tp > 1 and self.enable_loss_parallel

@cached_property
def model_parallel_size(self):
return self.tp * self.pp
def non_data_parallel_size(self):
return self.cp * self.tp * self.pp
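
A minimal usage sketch (assuming torch.distributed is already initialized; the degrees below are illustrative) showing how the new cp field participates in validation and mesh construction:

from torchtitan.parallelisms import ParallelDims

# Hypothetical layout: 8 GPUs as 2-way FSDP x 2-way CP x 2-way TP.
# Validation requires dp_replicate * dp_shard * cp * tp * pp == world_size.
parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=2,
    cp=2,
    tp=2,
    pp=1,
    world_size=8,
    enable_loss_parallel=True,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")
# With cp > 1, the dp and cp dims are additionally flattened into a "dp_cp"
# mesh dimension, which FSDP then shards over (see parallelize_llama.py below).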
58 changes: 29 additions & 29 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -11,6 +11,7 @@

import torch
import torch.nn as nn

from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
@@ -72,36 +73,35 @@ def parallelize_llama(
)
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_shard_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
else:
dp_mesh = world_mesh["dp"]

apply_fsdp(
model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[
job_config.training.mixed_precision_reduce
],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")
if (
parallel_dims.dp_shard_enabled
): # apply FSDP or HSDP, potentially with Context Parallel
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
apply_fsdp(
model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)

if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
apply_ddp(
model,
world_mesh,
enable_compile=job_config.training.compile,
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
)
logger.info("Applied FSDP to the model")

if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")
elif parallel_dims.dp_replicate_enabled:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
apply_ddp(
model,
world_mesh,
enable_compile=job_config.training.compile,
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
)


def apply_tp(
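
A back-of-the-envelope sketch (illustrative numbers, not from the PR) of what sharding over the flattened dp_cp mesh implies: FSDP divides parameters across all dp and cp ranks, while CP separately splits activations along the sequence dimension.

# Illustrative numbers only.
dp_shard, cp = 2, 2
dp_cp_ranks = dp_shard * cp              # 4 ranks in the flattened dp_cp mesh

param_numel = 4096 * 4096                # e.g. one attention projection weight
per_rank_param_numel = param_numel // dp_cp_ranks   # FSDP shard per rank

seq_len = 8192
per_rank_seq_len = seq_len // cp         # sequence shard held by each CP rank

print(per_rank_param_numel, per_rank_seq_len)       # 4194304 4096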
54 changes: 53 additions & 1 deletion torchtitan/utils.py
@@ -4,12 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gc
import os
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional, Union
from typing import Generator, List, Optional, Set, Union

import torch
import torch.distributed._functional_collectives as funcol
@@ -101,6 +102,57 @@ def run(self, step_count):
SKIP_CLEANUP = "3"


def create_context_parallel_ctx(
cp_mesh: DeviceMesh,
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
):
try:
from torch.distributed.tensor.experimental import context_parallel
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)

return context_parallel(
cp_mesh,
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
)


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context(cp_context: Optional[Generator[None, None, None]] = None):
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

if cp_context is not None:
from torch.nn.attention import sdpa_kernel, SDPBackend

# currently we only support these two SDP backends.
# TODO (xilunwu): support cuDNN backend
Just curious if you recall what the blocker for CUDNN_ATTENTION is

Contributor Author
Hi @tmm1, it's simply that cuDNN attention has a different op signature. I'm adding support now and should have the PR draft out by next week.

stack.enter_context(
sdpa_kernel(
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
)
)
stack.enter_context(cp_context)

yield

return context


def init_distributed(job_config):
# FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
# to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
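
The train.py wiring is not part of this excerpt, so the following is only a sketch, under stated assumptions, of how create_context_parallel_ctx and get_train_context might be combined in a training step; the buffer list, sequence dims, and surrounding names are illustrative.

from torchtitan import utils


def train_step(model, input_ids, labels, loss_fn, world_mesh, parallel_dims, job_config):
    # Sketch only: the buffers passed to CP and their seq dims are assumptions.
    train_context = utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    cp_ctx = None
    if parallel_dims.cp_enabled:
        cp_ctx = utils.create_context_parallel_ctx(
            cp_mesh=world_mesh["cp"],
            # Every sequence-dimension buffer must be sharded consistently.
            cp_buffers=[input_ids, labels, model.freqs_cis],
            cp_seq_dims=[1, 1, 0],
            cp_no_restore_buffers={input_ids, labels},
        )

    with train_context(cp_ctx):
        loss = loss_fn(model(input_ids), labels)
        loss.backward()
    return loss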