enable Context Parallel
ghstack-source-id: a0832f24bf6cfacb5e74dcdc3bca3fb58caca4aa
Pull Request resolved: #592
XilunWu committed Oct 21, 2024
1 parent 36fba84 commit 28aceb1
Showing 8 changed files with 146 additions and 43 deletions.
1 change: 1 addition & 0 deletions estimation.py
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
23 changes: 23 additions & 0 deletions test_runner.py
@@ -306,6 +306,29 @@ def build_test_list():
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"FSDP+CP",
"fsdp+cp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"HSDP+CP",
"hsdp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
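The GPU counts requested by the new tests follow from the product of the parallel degrees, with tp and pp left at 1. A minimal sketch of that arithmetic, using only the degrees from the override lists above:

# Why the new CP tests request ngpu=4 and ngpu=8 (assuming tp = pp = 1).
def required_world_size(dp_replicate=1, dp_shard=1, cp=1, tp=1, pp=1):
    return dp_replicate * dp_shard * cp * tp * pp

assert required_world_size(dp_shard=2, cp=2) == 4                   # FSDP+CP
assert required_world_size(dp_replicate=2, dp_shard=2, cp=2) == 8   # HSDP+CP
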
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -325,6 +325,12 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.context_parallel_degree",
type=int,
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
4 changes: 2 additions & 2 deletions torchtitan/models/llama/model.py
@@ -415,8 +415,8 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
# Note: removed the 2x relaxation as part of CP enablement
self.model_args.max_seq_len,
self.model_args.rope_theta,
)

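The 2x padding on the precomputed rotary frequencies is dropped here, presumably because freqs_cis is now registered as a Context Parallel buffer sharded along its sequence dimension (see train.py below), so its length has to match the training sequence length exactly. A small illustration of the per-rank shard arithmetic, with hypothetical sizes:

# Illustration only: per-rank shard of freqs_cis under CP (hypothetical numbers).
max_seq_len, cp_degree = 2048, 2
shard_len = max_seq_len // cp_degree          # each CP rank holds 1024 rows
assert shard_len * cp_degree == max_seq_len   # shards line up with the unpadded length
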
28 changes: 21 additions & 7 deletions torchtitan/parallelisms/parallel_dims.py
@@ -15,6 +15,7 @@
class ParallelDims:
dp_replicate: int
dp_shard: int
cp: int
tp: int
pp: int
world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
for d in (dp_replicate, cp, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
dp = self.world_size // (cp * tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
if d > 1:
dims.append(d)
@@ -71,6 +74,13 @@ def build_mesh(self, device_type):
# initialized
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")

if self.cp > 1:
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
else:
mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")

return mesh

@property
@@ -85,6 +95,10 @@ def dp_replicate_enabled(self):
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def cp_enabled(self):
return self.cp > 1

@property
def tp_enabled(self):
return self.tp > 1
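As a worked example of the validation and mesh-dimension logic above, consider the HSDP+CP test configuration on 8 GPUs with dp_shard left at -1. This is a standalone sketch of the arithmetic only, not the ParallelDims class itself:

# Degree resolution as in _validate() for an 8-GPU HSDP+CP run (sketch).
world_size, dp_replicate, dp_shard, cp, tp, pp = 8, 2, -1, 2, 1, 1

if dp_replicate * dp_shard < 0:           # dp_shard == -1 means "infer it"
    dp = world_size // (cp * tp * pp)     # 8 // (2 * 1 * 1) = 4
    dp_shard = dp // dp_replicate         # 4 // 2 = 2

assert dp_replicate * dp_shard * cp * tp * pp == world_size

# build_mesh() then creates one mesh axis per degree > 1, flattens
# dp_replicate+dp_shard into "dp" and dp_replicate+dp_shard+cp into "dp_cp".
print([("dp_replicate", dp_replicate), ("dp_shard", dp_shard), ("cp", cp)])
# [('dp_replicate', 2), ('dp_shard', 2), ('cp', 2)]
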
67 changes: 38 additions & 29 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -11,6 +11,7 @@

import torch
import torch.nn as nn

from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
@@ -72,36 +73,44 @@ def parallelize_llama(
)
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_shard_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
else:
dp_mesh = world_mesh["dp"]

apply_fsdp(
model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[
job_config.training.mixed_precision_reduce
],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")
if parallel_dims.dp_shard_enabled: # apply FSDP or HSDP, potentially with Context Parallel
dp_mesh_dim_names = (
("dp_replicate", "dp_shard")
if parallel_dims.dp_replicate_enabled
else ("dp",)
)
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
apply_fsdp(
model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[
job_config.training.mixed_precision_reduce
],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)

if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
if world_mesh.ndim > 1:
raise RuntimeError("DDP does not support > 1D parallelism")
apply_ddp(
model,
world_mesh,
enable_compile=job_config.training.compile,
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
)
logger.info("Applied FSDP to the model")

if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")
elif parallel_dims.dp_replicate_enabled:
if world_mesh.ndim > 1:
raise RuntimeError("DDP does not support > 1D parallelism")
apply_ddp(
model,
world_mesh,
enable_compile=job_config.training.compile,
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
)


def apply_tp(
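The key change above is that the mesh FSDP shards over now also folds in the cp dimension via _flatten, so parameter sharding and gradient reduction span the combined data-parallel and context-parallel ranks. A minimal standalone sketch of that wiring, assuming a 4-GPU torchrun launch with dp_shard=2 and cp=2; the toy module and mesh shape are illustrative, not torchtitan code:

# Sketch: flatten a (dp_shard, cp) mesh into "dp_cp" and shard a toy module over it,
# mirroring the mesh that apply_fsdp receives above. Assumes 4 GPUs via torchrun.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_shard", "cp"))
dp_cp_mesh = mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")

model = nn.Linear(16, 16, device="cuda")
fully_shard(model, mesh=dp_cp_mesh)  # parameters sharded across all 4 ranks
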
59 changes: 54 additions & 5 deletions train.py
@@ -10,7 +10,13 @@
from datetime import timedelta

import torch

from typing import List, Optional, Set
from functools import partial

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record
from torch.nn.attention import SDPBackend, sdpa_kernel

from torchtitan import utils
from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -28,17 +34,52 @@
)
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

try:
from torch.distributed.tensor.experimental import context_parallel
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)


def get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_mesh: Optional[DeviceMesh] = None,
):
if cp_mesh is not None:
context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)

def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
def context(
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
):
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

if cp_mesh is not None:
# currently we only support these two SDP backends.
# TODO (xilunwu): support cuDNN backend
stack.enter_context(
sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
)
stack.enter_context(
context_parallel_ctx(
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
)
)

yield

return context
@@ -70,6 +111,7 @@ def main(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
@@ -235,6 +277,7 @@ def loss_fn(pred, labels):
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
world_mesh["cp"] if parallel_dims.cp_enabled else None,
)

# variables used to keep info for metrics logging
@@ -268,18 +311,24 @@ def loss_fn(pred, labels):
data_load_start = time.perf_counter()
batch = next(data_iterator)
input_ids, labels = batch
ntokens_since_last_log += labels.numel()
ntokens_since_last_log += labels.numel() // parallel_dims.cp
data_loading_times.append(time.perf_counter() - data_load_start)

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()

training_context = train_context(
cp_buffers=[input_ids, labels, model.freqs_cis],
cp_seq_dims=[1, 1, 0],
cp_no_restore_buffers={input_ids, labels},
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
with training_context:
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
Expand All @@ -296,7 +345,7 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with train_context():
with training_context:
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
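For context, what train_context now wraps: the experimental context_parallel region restricts SDPA to the flash and efficient backends and shards the registered buffers along their sequence dimensions. A minimal standalone sketch using only the keyword names from the call in get_train_context above; the 2-rank CP mesh and buffer shapes are assumptions for illustration:

# Sketch: entering a CP region shards the listed buffers along their sequence dims.
# Assumes a 2-GPU torchrun launch; shapes are illustrative only.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel

cp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))

input_ids = torch.randint(0, 32000, (8, 2048), device="cuda")  # (bs, seq_len)
labels = torch.randint(0, 32000, (8, 2048), device="cuda")
freqs_cis = torch.randn(2048, 64, device="cuda")               # (seq_len, ...)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    with context_parallel(
        cp_mesh,
        buffers=[input_ids, labels, freqs_cis],
        buffer_seq_dims=[1, 1, 0],
        no_restore_buffers={input_ids, labels},
    ):
        # Each rank now sees a 1/cp shard along the sequence dimension,
        # e.g. input_ids.shape == (8, 1024) on a 2-rank CP group.
        pass
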
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -42,6 +42,7 @@ compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false

