Integrate distributed inference with chat/server (#1381)
* Integrate distributed inference without introducing abstraction

* Cleanup old distributed inference integration

* Read distribution from model_config

* Declare distribution_path if args.model is not given

* Address some nits from PR review

* Added comment on model size all reduce + type hint

* Apply suggestions from code review

Co-authored-by: Jack-Khuu <[email protected]>

* Make sure speculative decoding is disabled for pp > 1 and note this in the comments as well

* Refactor conditions in pp

* Rename and alter signature of setup_env to reflect that it also runs the target

* Rename setup_env in server + fix condition

* Update generate.py

* Add default value to add_generation_prompt to preserve bc

---------

Co-authored-by: Jack-Khuu <[email protected]>
mreso and Jack-Khuu authored Dec 19, 2024
1 parent b1b32f1 commit cc0ffce
Showing 8 changed files with 596 additions and 956 deletions.
111 changes: 109 additions & 2 deletions torchchat/cli/builder.py
@@ -14,10 +14,17 @@
import torch
import torch._dynamo.config
import torch._inductor.config
import torch.nn as nn
import torch.distributed as dist

from torchchat.model import Model, ModelArgs, ModelType
from torchchat.distributed.utils import (
    Color as color,
    CUDATrackTime,
    init_distributed,
    GPUMemoryMonitor,
)
from torchchat.distributed.logging_utils import SingletonLogger

from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
    device_sync,
@@ -28,6 +35,7 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model


from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -56,6 +64,7 @@ class BuilderArgs:
    pp: int = 1
    tp: int = 1
    chpt_from: str = "hf"
    distribution_path: Optional[str] = None
    is_chat_model: bool = False
    prefill_possible: bool = False
    dynamic_shapes: bool = False
@@ -107,6 +116,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":

        checkpoint_path = args.checkpoint_path
        params_table = args.params_table
        distribution_path = None
        if args.model:  # Using a named, well-known model
            model_config = resolve_model_config(args.model)
@@ -121,6 +131,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
                model_config.transformer_params_key or model_config.name.split("/")[-1]
            )

            distribution_path = model_config.distribution_path

        dso_path = getattr(args, "dso_path", None)
        pte_path = getattr(args, "pte_path", None)
        aoti_package_path = getattr(args, "aoti_package_path", None)
@@ -186,6 +198,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
            pp=pp,
            tp=tp,
            chpt_from=chpt_from,
            distribution_path=distribution_path,
            is_chat_model=is_chat_model,
            dynamic_shapes=getattr(args, "dynamic_shapes", False),
            max_seq_length=getattr(args, "max_seq_length", None),
@@ -601,6 +614,100 @@ def do_nothing(max_batch_size, max_seq_length):
            model = PTEModel(config, builder_args.pte_path)
        except Exception:
            raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
    elif builder_args.distributed:
        pp_degree = builder_args.pp
        tp_degree = builder_args.tp

        init_distributed()
        rank = dist.get_rank()
        torch.cuda.set_device(rank % torch.cuda.device_count())

        logger = SingletonLogger.get_logger()

        gpu_memory_monitor = GPUMemoryMonitor("cuda")
        logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

        # Model-level config
        if builder_args.params_table:
            model_config = ModelArgs.from_table(builder_args.params_table)
        else:
            raise NotImplementedError()
        # Transformer-level config
        config = TransformerArgs.from_params(model_config.transformer_args["text"])
        logger.info(f"Transformer Config: {config}")

        # TODO: Move into head of file after solving circular import
        from torchchat.distributed.checkpoint_utils import (
            load_model_weights,
        )

        # Validate pipeline degree
        assert config.n_layers % pp_degree == 0

        # Create device mesh
        device_mesh = dist.init_device_mesh(
            "cuda",
            (pp_degree, tp_degree),
            mesh_dim_names=("pp", "tp")
        )
        tp_mesh = device_mesh["tp"]
        pp_mesh = device_mesh["pp"]
        logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")

        pp_rank = pp_mesh.get_local_rank()
        logger.info(f"{pp_degree=}, {tp_degree=}")

        # Assuming same number of GPUs per node
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

        # Fill in PP configs
        config.stage_idx = pp_rank
        config.n_stages = pp_degree

        with torch.device("meta"):
            # TODO: we should create model instead of Transformer
            model = Transformer(config)

        # Distribute model on TP mesh
        # (Surprisingly, this works even though model is on meta device and mesh is of
        # cuda devices)
        model.distribute(tp_mesh)
        if rank == 0:
            logger.info(f"Model: {model}")

        # Load weights
        logger.info(f"Loading weights for {pp_rank=} on {device=}")
        with CUDATrackTime() as timer:
            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)

        logger.info(
            f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
        )

        # Setup KV caches (after model distribution)
        # The number of cache lanes is the same as the maximum number of
        # micro-batches that can be "in flight" in parallel -- imagine each
        # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
        # When decoding is done for certain micro-batches, we can reuse the KV cache
        # lanes.
        # TODO: bump up the lane count
        pipeline_lanes = 1
        seqlen_prefill = 1024
        with device:
            model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)

        # Info on stage size and params
        # stage_size = get_module_size(model)
        # stage_size_formatted = bytes_to_readable(stage_size)
        # stage_num_params = get_num_params(model)
        # logger.info(
        #     f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
        # )
        model.eval()

        model.text_transformer_args = None
        model.config.model_type = model_config.model_type
        model.device_mesh = device_mesh
    else:
        with measure_time("Time to load model: {time:.02f} seconds"):
            model = _load_model(builder_args)
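The new `elif builder_args.distributed:` branch above is the heart of the change. Below is a condensed sketch of its control flow for orientation only: the wrapper function `build_distributed_model` is hypothetical, a `torchrun` launch with `pp_degree * tp_degree` GPUs is assumed, and the individual calls (`init_distributed`, `ModelArgs.from_table`, `TransformerArgs.from_params`, `dist.init_device_mesh`, `Transformer`, `model.distribute`) are the ones used in the diff.

# Sketch only -- mirrors the distributed branch in builder.py above.
# Assumes torchrun with WORLD_SIZE == pp_degree * tp_degree GPUs; the wrapper
# function name is hypothetical and not part of the torchchat API.
import torch
import torch.distributed as dist

from torchchat.distributed.utils import init_distributed
from torchchat.model import ModelArgs, Transformer, TransformerArgs


def build_distributed_model(params_table: str, pp_degree: int, tp_degree: int):
    init_distributed()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Model- and transformer-level configs, as in the diff
    model_config = ModelArgs.from_table(params_table)
    config = TransformerArgs.from_params(model_config.transformer_args["text"])
    assert config.n_layers % pp_degree == 0  # each pipeline stage gets whole layers

    # 2-D device mesh: pipeline-parallel x tensor-parallel
    device_mesh = dist.init_device_mesh(
        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
    )
    config.stage_idx = device_mesh["pp"].get_local_rank()
    config.n_stages = pp_degree

    # Materialize parameters lazily on the meta device, then shard onto the TP sub-mesh
    with torch.device("meta"):
        model = Transformer(config)
    model.distribute(device_mesh["tp"])
    return model, device_mesh

Weight loading (`load_model_weights`) and KV-cache setup (`model.setup_caches` with one cache lane per in-flight micro-batch) then follow on each rank, as shown in the diff.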
32 changes: 32 additions & 0 deletions torchchat/distributed/checkpoint_utils.py
@@ -17,6 +17,7 @@
from torch.distributed._tensor import DTensor
from torchchat.distributed.dtensor_utils import convert_to_dtensor
from torchchat.cli.builder import BuilderArgs, _load_checkpoint
from torchchat.model import ModelArgs


_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -450,3 +451,34 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
    # Fill state dict into stage module
    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


def load_model_weights(
    stage_module: torch.nn.Module,
    distribution: str,
    device: torch.device,
    model_config: ModelArgs,
    chpt_from: str,
):
    """Load the weights from the safetensor file(s) into the model stage.
    Model config is needed b/c we permute wq and wk weights based on attn heads.
    Args:
        stage_module (torch.nn.Module): The model stage to load the weights into.
        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
        device (torch.device): The device to load the weights onto.
        model_config (ModelArgs): The model config.
        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
    """
    if chpt_from == "hf":
        # This format stands for: index file + multiple binary files
        load_weights_from_hf_format(stage_module, distribution, device, model_config)
    elif chpt_from == "torchchat":
        # This format stands for:
        #   single binary file, OR
        #   multiple binary files without index files.
        load_weights_from_torchchat_format(
            stage_module, distribution, device, model_config
        )
    else:
        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
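For reference, a minimal call-site sketch for this helper (the variables `model`, `config`, and `device` are assumed to exist as in the builder.py branch above; the distribution name is the docstring's own example; keyword arguments are used here only for readability):

# Usage sketch only -- mirrors the call made from builder.py above.
from torchchat.distributed.checkpoint_utils import load_model_weights

load_model_weights(
    stage_module=model,                                  # meta-initialized, TP-sharded stage
    distribution="meta-llama/Meta-Llama-3-8B-Instruct",  # per the docstring example
    device=device,                                       # e.g. torch.device("cuda:0")
    model_config=config,                                 # used to permute wq/wk by attention heads
    chpt_from="hf",                                      # or "torchchat"; anything else raises ValueError
)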
(Diffs for the remaining 6 changed files not shown.)
