From 7b4d5c580f70c8cea16dc12aa976b833f09b73c4 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Fri, 15 Nov 2024 17:59:58 -0800 Subject: [PATCH 01/13] Integrate distributed inference without introducing abstraction --- torchchat/cli/builder.py | 123 ++++++- torchchat/distributed/checkpoint_utils.py | 32 ++ torchchat/distributed/utils.py | 14 +- torchchat/generate.py | 403 +++++++++++++++++++--- torchchat/usages/openai_api.py | 18 +- torchchat/usages/server.py | 90 ++++- 6 files changed, 622 insertions(+), 58 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f67cb9d0a..6ea5ff25f 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -14,10 +14,17 @@ import torch import torch._dynamo.config import torch._inductor.config -import torch.nn as nn +import torch.distributed as dist -from torchchat.model import Model, ModelArgs, ModelType +from torchchat.distributed.utils import( + Color as color, + CUDATrackTime, + init_distributed, + GPUMemoryMonitor, +) +from torchchat.distributed.logging_utils import SingletonLogger +from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( device_sync, @@ -28,6 +35,7 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model + from torchtune.models.convert_weights import meta_to_tune from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE @@ -598,6 +606,117 @@ def do_nothing(max_batch_size, max_seq_length): model = PTEModel(config, builder_args.pte_path) except Exception: raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") + elif builder_args.distributed: + # Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B". + #TODO This is a hacky way to please the distributed loading api and needs to be replaced + NAME_TO_DISTRIBUTION = { + "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct", + "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "Meta-Llama-3-70B": "meta-llama/Meta-Llama-3-70B-Instruct", + "Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3.1-70B-Instruct", + + } + # TODO: Use information in builder_args directly to build model and load weights + assert builder_args.params_table + try: + distribution = NAME_TO_DISTRIBUTION[builder_args.params_table] + except KeyError as e: + print(f"Unknown params_table: {builder_args.params_table}. 
Suported model names are: llama3.1, llama3, llama2-7b-chat") + raise e + + pp_degree = builder_args.pp + tp_degree = builder_args.tp + + init_distributed() + rank = dist.get_rank() + torch.cuda.set_device(rank % torch.cuda.device_count()) + + logger = SingletonLogger.get_logger() + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") + + # Model-level config + if builder_args.params_table: + model_config = ModelArgs.from_table(builder_args.params_table) + else: + raise NotImplementedError() + # Transformer-level config + config = TransformerArgs.from_params(model_config.transformer_args["text"]) + logger.info(f"Transformer Config: {config}") + + #TODO: Move into head of file after solving circular import + from torchchat.distributed.checkpoint_utils import ( + load_model_weights, + ) + + # Validate pipeline degree + assert config.n_layers % pp_degree == 0 + + # Create device mesh + device_mesh = dist.init_device_mesh( + "cuda", + (pp_degree, tp_degree), + mesh_dim_names=("pp", "tp") + ) + tp_mesh = device_mesh["tp"] + pp_mesh = device_mesh["pp"] + logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") + + pp_rank = pp_mesh.get_local_rank() + logger.info(f"{pp_degree=}, {tp_degree=}") + + # Assuming same number of GPUs per node + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + + # Fill in PP configs + config.stage_idx = pp_rank + config.n_stages = pp_degree + + with torch.device("meta"): + # TODO: we should create model instead of Transformer + model = Transformer(config) + + # Distribute model on TP mesh + # (Surprisingly, this works even though model is on meta device and mesh is of + # cuda devices) + model.distribute(tp_mesh) + if rank == 0: + logger.info(f"Model: {model}") + + # Load weights + logger.info(f"Loading weights for {pp_rank=} on {device=}") + with CUDATrackTime() as timer: + load_model_weights(model, distribution, device, config, builder_args.chpt_from) + + logger.info( + f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Setup KV caches (after model distribution) + # The number of cache lanes is the same as the maximum number of + # micro-batches that can be "in flight" in parallel -- imagine each + # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces. + # When decoding is done for certain micro-batches, we can reuse the KV cache + # lanes. 
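+        # Note: with pipeline_lanes = 1 (set below), only one micro-batch/request can be
+        # "in flight" through the pipeline at a time.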
+ # TODO: bump up the lane count + pipeline_lanes = 1 + seqlen_prefill=1024 + with device: + model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes) + + # info on stage size and params + # stage_size = get_module_size(model) + # stage_size_formatted = bytes_to_readable(stage_size) + # stage_num_params = get_num_params(model) + # logger.info( + # f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}" + # ) + model.eval() + + model.text_transformer_args = None + model.config.model_type = model_config.model_type + model.device_mesh = device_mesh else: with measure_time("Time to load model: {time:.02f} seconds"): model = _load_model(builder_args) diff --git a/torchchat/distributed/checkpoint_utils.py b/torchchat/distributed/checkpoint_utils.py index cf3206e4e..806855c4b 100644 --- a/torchchat/distributed/checkpoint_utils.py +++ b/torchchat/distributed/checkpoint_utils.py @@ -17,6 +17,7 @@ from torch.distributed._tensor import DTensor from torchchat.distributed.dtensor_utils import convert_to_dtensor from torchchat.cli.builder import BuilderArgs, _load_checkpoint +from torchchat.model import ModelArgs _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json" @@ -450,3 +451,34 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model # Fill state dict into stage module stage_module.load_state_dict(stage_state_dict, strict=False, assign=True) logger.info(f"Successfully loaded {len(updated_states)} weights into stage module") + + +def load_model_weights( + stage_module: torch.nn.Module, + distribution: str, + device: torch.device, + model_config: ModelArgs, + chpt_from: str, +): + """Load the weights from the safetensor file(s) into the model stage. + Model config is needed b/c we permute wq and wk weights based on attn heads. + + Args: + stage_module (torch.nn.Module): The model stage to load the weights into. + distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct". + device (torch.device): The device to load the weights onto. + model_config (ModelArgs): The model config. + chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf". + """ + if chpt_from == "hf": + # This format stands for: index file + multiple binary files + load_weights_from_hf_format(stage_module, distribution, device, model_config) + elif chpt_from == "torchchat": + # This format stands for: + # single binary file, OR + # multiple binary files without index files. 
+ load_weights_from_torchchat_format( + stage_module, distribution, device, model_config + ) + else: + raise ValueError(f"Unknown checkpoint format: {chpt_from}") diff --git a/torchchat/distributed/utils.py b/torchchat/distributed/utils.py index 46ea5d9a1..e935226b0 100644 --- a/torchchat/distributed/utils.py +++ b/torchchat/distributed/utils.py @@ -6,15 +6,15 @@ import itertools import os +import time from dataclasses import dataclass from datetime import timedelta -import time +from os import environ from typing import Optional import torch - from torchchat.distributed.logging_utils import SingletonLogger logger = SingletonLogger.get_logger() @@ -257,3 +257,13 @@ def get_device_info( f"with {self.device_capacity_gib:.2f}GiB memory" ) return device_info + +def setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "29500" + environ["RDZV_BACKEND"] = "c10d" + environ["WORLD_SIZE"] = str(world_size) + environ["RANK"] = str(rank) + environ["LOCALRANK"] = str(rank) + + return target(*args, **kwargs) diff --git a/torchchat/generate.py b/torchchat/generate.py index 66f26ff9f..e99709950 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -3,13 +3,15 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import argparse import base64 +import contextlib import itertools import logging import os import textwrap import time +from concurrent import futures +from functools import partial from abc import ABC, abstractmethod from dataclasses import dataclass @@ -21,6 +23,9 @@ import torch import torch._dynamo.config import torch._inductor.config +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed.pipelining import PipelineStage, ScheduleGPipe from PIL import Image @@ -28,7 +33,6 @@ from torchtune.data import Message, padded_collate_tiled_images_and_mask from torchtune.generation import sample as tune_sample -from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform from torchtune.training import set_default_dtype @@ -39,11 +43,16 @@ BuilderArgs, TokenizerArgs, ) -from torchchat.distributed.generate import DistributedGenerator +from torchchat.distributed.utils import ( + Color as color, + setup_env, +) from torchchat.model import Model, ModelType from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +logger = logging.getLogger(__name__) + class _ChatFormatter(ABC): def __init__(self, tokenizer): @@ -214,7 +223,7 @@ def from_args(cls, args): ) -class Generator: +class LocalGenerator: """ Generates text samples based on a pre-trained Transformer model and tokenizer. Args: @@ -251,6 +260,7 @@ def __init__( self.draft_quantize = draft_quantize self.is_torchtune_model = generator_args.is_torchtune_model self.dtype = builder_args.precision + self.get_user_input = input self.rank: Optional[int] = None @@ -283,7 +293,7 @@ def __init__( if self.is_llama3_model: self.chat_formatter = Llama3ChatFormatter(self.tokenizer) if generator_args.chat_mode: - logging.debug( + logger.debug( "Llama3 model detected in chat mode. 
Using updated sentence schemas" ) else: @@ -358,7 +368,6 @@ def prefill( sequential_prefill=True, **sampling_kwargs, ) -> torch.Tensor: - # logging.debug(f"x: {x}, input_pos: {input_pos}") width = x.size(1) assert input_pos.size(0) == width @@ -394,14 +403,11 @@ def prefill( elif sequential_prefill: for i in range(width): x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) - # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}") logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])da else: # input_pos: [B, S] logits = model(x, input_pos) - # print(f"logits {logits.shape}") - # print(f"x: {x},\n input_pos: {input_pos}\n") return self.sample(logits, need_probs=False, **sampling_kwargs)[0] def decode_one_token( @@ -425,7 +431,6 @@ def decode_one_token( )[:, -1:] else: logits = model(x, input_pos) - # print(f"x: {x},\n input_pos: {input_pos}\n") return self.sample(logits, need_probs=need_probs, **sampling_kwargs) """ @@ -727,7 +732,7 @@ def encode_tokens(self, string, bos=True, device="cpu"): tokens = self.tokenizer.encode(string) if bos: tokens = [self.tokenizer.bos_id()] + tokens - logging.debug(f"Size after encode_tokens: {len(tokens)}") + logger.debug(f"Size after encode_tokens: {len(tokens)}") return torch.tensor(tokens, dtype=torch.int, device=device) def _callback(self, x, *, buffer, done_generating): @@ -747,7 +752,6 @@ def _callback(self, x, *, buffer, done_generating): if len(buffer) == 4 or done_generating: print("".join(buffer), end="", flush=True) buffer.clear() - # print(, end='', flush=True) def _gen_model_input( self, @@ -785,7 +789,7 @@ def _gen_model_input( tokens, dtype=torch.int, device=self.builder_args.device ) - logging.debug(encoded) + logger.debug(encoded) return encoded, None # Llama 3.2 11B @@ -900,7 +904,7 @@ def _gen_model_input( value=0, ) - logging.debug(encoded) + logger.debug(encoded) return encoded, batch def chat( @@ -916,6 +920,11 @@ def chat( for p in itertools.chain(self.model.parameters(), self.model.buffers()) ] ) + if self.builder_args.distributed: + model_size = torch.tensor(model_size, dtype=torch.int64, device=self.device) + dist.all_reduce(model_size) + model_size = model_size.item() + if generator_args.compile: if self.builder_args.device == "cpu": if generator_args.max_autotune: @@ -974,11 +983,11 @@ def chat( print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) - get_system_prompt = input( + get_system_prompt = self.get_user_input( "Do you want to enter a system prompt? Enter y for yes and anything else for no. \n" ) if get_system_prompt == "y" or get_system_prompt == "Y": - self.system_prompt = input("What is your system prompt? \n") + self.system_prompt = self.get_user_input("What is your system prompt? \n") # `is_torchtune_model` is a misnomer since it doesn't capture all # torchtune models (i.e. 
Flamingo) @@ -1017,7 +1026,7 @@ def chat( device_sync(device=self.builder_args.device) is_first_sample: bool = i == 0 if generator_args.chat_mode: - prompt = input("User: ") + prompt = self.get_user_input("User: ") if prompt == "/bye": print("Exiting Chat.\n") break @@ -1088,8 +1097,6 @@ def callback(x, *, done_generating=False): torch._inductor.config.profiler_mark_wrapper_call = True torch._inductor.config.cpp.enable_kernel_profile = True if i != generator_args.num_samples - 1 or not self.profile: - import contextlib - prof = contextlib.nullcontext() else: torch.profiler._utils._init_for_cuda_graphs() @@ -1153,7 +1160,7 @@ def callback(x, *, done_generating=False): aggregate_metrics["first_token_per_sec"].append(first_token_sec) aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) - logging.info( + logger.info( f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ \nGenerated {num_tokens_generated} tokens \ \nTime for inference {i + 1}: {t:.04f} sec total \ @@ -1164,11 +1171,11 @@ def callback(x, *, done_generating=False): \n Next token throughput: {next_tokens_sec:.04f} tokens/sec, {1 / next_tokens_sec:.04f} s/token \ " ) - logging.info( + logger.info( f"\nBandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" ) if i == 0: - logging.info( + logger.info( f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***" ) print("\n========================================\n") @@ -1214,21 +1221,325 @@ def callback(x, *, done_generating=False): print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -def _launch_distributed_inference( - builder_args: BuilderArgs, -): - from torch.distributed import launcher - from torch.distributed.elastic.utils.distributed import get_free_port +class DistributedGenerator(LocalGenerator): + def __init__( + self, + builder_args: BuilderArgs, + speculative_builder_args: BuilderArgs, + tokenizer_args: TokenizerArgs, + generator_args: GeneratorArgs, + profile: Optional[Path], + quantize: bool, + draft_quantize: bool, + ): + super().__init__( + builder_args, + speculative_builder_args, + tokenizer_args, + generator_args, + profile, + quantize, + draft_quantize, + ) + self.rank = dist.get_rank() + # Assuming same number of GPUs per node + self.device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") - print("Launching distributed inference within generator") + def distributed_input(prompt: str) -> str: + if dist.get_rank() == 0: + text = [input(prompt)] + else: + text = [None] + + dist.broadcast_object_list(text) + return text[0] + self.get_user_input = distributed_input -def main(args): + if builder_args.pp > 1: + self.seqlen_prefill = 1024 # sequence length for prefill stage + + logger.warn(f"{color.yellow}Pipeline parallelism is still experimental and might be slow{color.reset}") + pp_mesh = self.model.device_mesh["pp"] + + self.pp_rank = pp_mesh.get_local_rank() + self.pp_group = pp_mesh.get_group() + + self.pp_degree = pp_mesh.size() + + # Convenience variables + self.first_pp_rank = 0 + self.last_pp_rank = self.pp_degree - 1 + + + self.first_pp_rank_global_id = dist.get_global_rank(self.pp_group, self.first_pp_rank) + self.last_pp_rank_global_id = dist.get_global_rank(self.pp_group, self.last_pp_rank) + + self.prefiller = self.create_prefill_stage() + self.decoder = self.create_decode_stage() + + def __del__(self): + dist.destroy_process_group() + + # Helper function to get example inputs and outputs for the stages. 
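+    # The example tensors below let PipelineStage infer the shapes/dtypes it sends and
+    # receives: the first pp rank consumes token ids, later ranks consume bf16 activations,
+    # and only the last pp rank produces logits.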
+ def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function generates example inputs and outputs for the prefill and decode stages. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the example inputs and outputs. + """ + model_dtype = torch.bfloat16 + mb_ids = torch.randint( + 0, self.model.config.vocab_size, (batch_size, seqlen), device=self.device + ) + activation = torch.rand( + batch_size, seqlen, self.model.config.dim, device=self.device, dtype=model_dtype + ) + logits = torch.rand( + batch_size, seqlen, self.model.config.vocab_size, device=self.device, dtype=model_dtype + ) + example_inputs = (mb_ids if self.pp_rank == self.first_pp_rank else activation,) + example_outputs = (logits if self.pp_rank == self.last_pp_rank else activation,) + return example_inputs, example_outputs + + def create_prefill_stage( + self, + ): + """ + Creates a pipeline stage for prefilling. + + Returns: + PipelineStage: The created pipeline stage. + """ + batch_size = 1 + + # Create prefill stage + logger.debug(f"Creating pipeline stage for prefill {self.pp_rank=}, {self.pp_degree=}") + example_inputs, example_outputs = self.get_example_ins_outs(batch_size, self.seqlen_prefill) + prefill_stage = PipelineStage( + self.model, + self.pp_rank, + self.pp_degree, + self.device, + input_args=example_inputs, + output_args=example_outputs, + group=self.pp_group, + ) + + # Create schedule + # Number of micro-batches for the schedule is 1, because each step() call we + # only push 1 micro-batch into the pipeline. But we can continuously push + # new micro-batches into the pipeline as they arrive, achieving same + # pipelining effect. + prefiller = ScheduleGPipe(prefill_stage, 1) + return prefiller + + def create_decode_stage( + self, + ): + """ + Creates a decode stage for the pipeline parallelism. + + Returns: + ScheduleGPipe: The decode stage. + """ + # seqlen = 1 now + seqlen_decode = 1 + batch_size = 1 + + # Create decode stage + # logger.info(f"Creating pipeline stage for decode {self.pp_rank=}, {self.pp_degree=}") + example_inputs, example_outputs = self.get_example_ins_outs(batch_size, seqlen_decode) + decode_stage = PipelineStage( + self.model, + self.pp_rank, + self.pp_degree, + self.device, + input_args=example_inputs, + output_args=example_outputs, + group=self.pp_group, + ) + # create schedule + decoder = ScheduleGPipe(decode_stage, 1) + + return decoder + + def prefill( + self, + model: Model, + x: torch.Tensor, + input_pos: torch.Tensor, + batch: Optional[Dict[str, Any]] = None, # Inputs for multimodal models + *, + sequential_prefill=True, + **sampling_kwargs, + ) -> torch.Tensor: + """ + This function is used to prefill the model with a given prompt. For pipeline parallelism we need to pad the input. + + Returns: + torch.Tensor: The prefilled tensor. + """ + if self.builder_args.pp == 1: + return super().prefill( + model, + x, + input_pos, + batch, + sequential_prefill=sequential_prefill, + **sampling_kwargs, + ) + + pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.eos_id + prompt_length = x.size(1) + + padded_seq = torch.full( + (1, self.seqlen_prefill), pad_token_id, dtype=torch.int64, device=self.device + ) + padded_seq[:,:prompt_length] = x + input_pos = torch.arange( + self.seqlen_prefill, + device=self.device, + dtype=torch.int, + ) + + # Prefill phase + # Run context input through pipeline + # TODO: we need to pass `input_pos` and `cache_lane` to each stage. 
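+        # Only the first pp rank feeds the padded prompt into the schedule and only the
+        # last pp rank gets logits back; middle ranks just run their local stage.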
+ lane = 0 + kwargs = {"input_pos": input_pos, "cache_lane": lane} + + if self.pp_rank == self.first_pp_rank: + logits = self.prefiller.step(padded_seq, **kwargs) + elif self.pp_rank == self.last_pp_rank: + logits = self.prefiller.step(**kwargs) + else: # middle pp ranks + self.prefiller.step(**kwargs) + + new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64) + + if self.pp_rank == self.last_pp_rank: + new_token = self.sample(logits[:,:prompt_length], need_probs=False, **sampling_kwargs)[0] + + + if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank: + dist.send( + new_token, + dst=self.first_pp_rank_global_id, + group=self.pp_group, + ) + elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank: + dist.recv( + new_token, + src=self.last_pp_rank_global_id, + group=self.pp_group, + ) + + return new_token + + def decode_one_token( + self, + model: Model, + x: torch.Tensor, + input_pos: torch.Tensor, + need_probs: bool, + batch: Optional[Dict[str, Any]] = None, # Inputs for multimodal models + **sampling_kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Decodes a single token. + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing the decoded token and its probability. + """ + if self.builder_args.pp == 1: + return super().decode_one_token( + model, + x, + input_pos, + need_probs, + batch=batch, + **sampling_kwargs, + ) + + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + + new_token = x.view(1, -1) + + lane = 0 + kwargs = {"input_pos": input_pos, "cache_lane": lane} + # Run data through pipeline + if self.pp_rank == self.first_pp_rank: + logits = self.decoder.step(new_token, **kwargs) + elif self.pp_rank == self.last_pp_rank: + logits = self.decoder.step(**kwargs) + else: # middle pp ranks + self.decoder.step(**kwargs) + + # Decode the output + if self.pp_rank == self.last_pp_rank: + new_token, next_prob = self.sample(logits, need_probs=need_probs, **sampling_kwargs) + else: + new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64) + + if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank: + dist.send( + new_token, + dst=self.first_pp_rank_global_id, + group=self.pp_group, + ) + elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank: + dist.recv( + new_token, + src=self.last_pp_rank_global_id, + group=self.pp_group, + ) + #TODO: Why do we get 2d tensor here? 
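+        # Drop the leading dimension before returning (see TODO above).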
+ new_token=new_token[0] + return new_token, None + + def sample( + self, + logits, + need_probs: bool, + temperature: float = 0, + top_k: Optional[int] = None, + ): + if temperature == 0 and not need_probs: + _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1) + return (idx_next, None) + probs = self.logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = self.multinomial_sample_one_no_sync(probs) + if self.builder_args.pp == 1: + dist.broadcast(idx_next, src=0) + dist.broadcast(probs, src=0) + return idx_next, probs + + +def run_generator( + args, + rank: Optional[int] =None + ): + """ + This function creates and executes a generator + """ builder_args = BuilderArgs.from_args(args) speculative_builder_args = BuilderArgs.from_speculative_args(args) tokenizer_args = TokenizerArgs.from_args(args) generator_args = GeneratorArgs.from_args(args) - if not builder_args.distributed: + + #Setup rank 1 and up to suppress log messages and print messages + if builder_args.distributed and rank != 0: + logger.setLevel(logging.CRITICAL) + context = contextlib.redirect_stdout(None) + else: + context = contextlib.nullcontext() + + with context: + Generator = DistributedGenerator if builder_args.distributed else LocalGenerator + gen = Generator( builder_args, speculative_builder_args, @@ -1243,20 +1554,20 @@ def main(args): for _ in gen.chat(generator_args): pass - else: - dist_gen = DistributedGenerator( - args.model, - builder_args, - tokenizer_args, - generator_args, - args.profile, - args.quantize, - args.draft_quantize, - ) - response = "" - for output in dist_gen.generate(generator_args.prompt): - response += output.text - - print(f"Model output: {response}") - dist_gen.shutdown() +def main(args): + builder_args = BuilderArgs.from_args(args) + + if builder_args.distributed: + world_size = builder_args.tp * builder_args.pp + + ctx = mp.get_context('spawn') + with futures.ProcessPoolExecutor(max_workers=world_size-1, mp_context=ctx) as executor: + for i in range(1,world_size): + fn = partial(run_generator, args, i) + executor.submit(setup_env, world_size, i, fn) + #Starting rank 0 + fn = partial(run_generator, args, 0) + setup_env(world_size, 0, fn) + else: + run_generator(args) diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 99fd82fe8..b1ad151f1 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -24,7 +24,7 @@ from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform from torchchat.cli.download import is_model_downloaded, load_model_configs -from torchchat.generate import Generator, GeneratorArgs +from torchchat.generate import LocalGenerator, DistributedGenerator, GeneratorArgs from torchchat.model import FlamingoModel from torchchat.utils.build_utils import device_sync @@ -267,7 +267,7 @@ class CompletionResponseChunk: usage: Optional[UsageStats] = None -class OpenAiApiGenerator(Generator): +class OpenAiApiGeneratorMixin: """A wrapper over the Generator class to interface with the OpenAI API. 
Implements endpoints for completion requests, both chunked and non-chunked using the dataclasses @@ -486,6 +486,20 @@ def _callback(self, x, *, buffer, done_generating): pass +def create_openai_api_generator(distributed): + """ + Factory method to create an OpenAiApiGenerator + """ + + if distributed: + # Base class order matters to make sure OpenAiApiGeneratorMixin overrides methods in DistributedGenerator and Generator + return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, DistributedGenerator), {}) + else: + return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, LocalGenerator), {}) + + + + """ Helper functions for the OpenAI API Models endpoint. diff --git a/torchchat/usages/server.py b/torchchat/usages/server.py index 1fb76953b..afbfeebd7 100644 --- a/torchchat/usages/server.py +++ b/torchchat/usages/server.py @@ -4,38 +4,93 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import atexit import json import logging logger = logging.getLogger(__name__) +from contextlib import nullcontext from dataclasses import asdict +from functools import partial +from os import environ from typing import Dict, List, Union import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from concurrent import futures from flask import Flask, request, Response from torchchat.cli.builder import BuilderArgs, TokenizerArgs +from torchchat.distributed.utils import setup_env from torchchat.generate import GeneratorArgs from torchchat.usages.openai_api import ( CompletionRequest, get_model_info_list, - OpenAiApiGenerator, + create_openai_api_generator, retrieve_model_info, ) OPENAI_API_VERSION = "v1" +def run_worker( + args, + rank, + queue, + ): + """ + This function creates and executes a generator + """ + gen = initialize_generator(args) + + while True: + try: + req = queue.get() + except KeyboardInterrupt: + break + + if req == "stop": + break + + for _ in gen.chunked_completion(req): + pass + def create_app(args): # noqa: C901 """ Creates a flask app that can be used to serve the model as a chat API. """ app = Flask(__name__) - gen: OpenAiApiGenerator = initialize_generator(args) + builder_args = BuilderArgs.from_args(args) + procs = [] + if builder_args.distributed: + world_size = builder_args.tp * builder_args.pp + mp_context = mp.get_context('spawn') + queue = mp_context.Queue() + else: + world_size = 1 + queue = None + + + if builder_args.distributed: + for i in range(1, world_size): + fn = partial(run_worker, args, i, queue) + mp_context = mp.get_context('spawn') + procs.append(mp_context.Process(target=setup_env, args=(world_size, i, fn))) + procs[-1].start() + + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "29500" + environ["RDZV_BACKEND"] = "c10d" + environ["WORLD_SIZE"] = str(world_size) + environ["RANK"] = str(0) + environ["LOCALRANK"] = str(0) + + gen = initialize_generator(args) def _del_none(d: Union[Dict, List]) -> Union[Dict, List]: """Recursively delete None values from a dictionary.""" @@ -69,6 +124,10 @@ def chat_endpoint(): if req.stream: + if builder_args.distributed: + for _ in range(world_size-1): + queue.put(req) + def chunk_processor(chunked_completion_generator): """Inline function for postprocessing CompletionResponseChunk objects. 
@@ -86,8 +145,11 @@ def chunk_processor(chunked_completion_generator): ) return resp else: + if builder_args.distributed: + for _ in range(world_size-1): + queue.put(req) + response = gen.sync_completion(req) - print(response.choices[0].message.content) return json.dumps(_del_none(asdict(response))) @@ -102,16 +164,18 @@ def models_retrieve_endpoint(model_id): else: return "Model not found", 404 - return app + return app, (procs, queue) -def initialize_generator(args) -> OpenAiApiGenerator: +def initialize_generator(args) -> type: builder_args = BuilderArgs.from_args(args) speculative_builder_args = BuilderArgs.from_speculative_args(args) tokenizer_args = TokenizerArgs.from_args(args) generator_args = GeneratorArgs.from_args(args) generator_args.chat_mode = False + OpenAiApiGenerator = create_openai_api_generator(builder_args.distributed) + return OpenAiApiGenerator( builder_args=builder_args, speculative_builder_args=speculative_builder_args, @@ -124,5 +188,19 @@ def initialize_generator(args) -> OpenAiApiGenerator: def main(args): - app = create_app(args) + app, (procs, queue) = create_app(args) + + def shutdown_worker(): + while not queue.empty(): + queue.get(block=False) + for p in procs: + queue.put("stop") + for p in procs: + p.join(timeout=0.5) + for p in procs: + if p.is_alive(): + p.kill() + + atexit.register(shutdown_worker) + app.run() From e7670c3972d1c7fb96115927df70a2ec74f994f3 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Fri, 15 Nov 2024 18:03:14 -0800 Subject: [PATCH 02/13] Cleanup old distributed inference integration --- torchchat/distributed/dist_run.py | 629 ------------------------------ torchchat/distributed/generate.py | 271 ------------- 2 files changed, 900 deletions(-) delete mode 100644 torchchat/distributed/dist_run.py delete mode 100644 torchchat/distributed/generate.py diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py deleted file mode 100644 index 389ae41c1..000000000 --- a/torchchat/distributed/dist_run.py +++ /dev/null @@ -1,629 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -# Example run command: -# torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2 -# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2 - -import argparse -import os -from enum import auto, Enum -from pathlib import Path -from types import MethodType, SimpleNamespace -from typing import Any, Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist -from torch.distributed.pipelining import PipelineStage, ScheduleGPipe -from torchchat.cli.builder import TokenizerArgs - -# TODO - these are not distributed specific, consider moving to new package -from torchchat.distributed.checkpoint_utils import ( - get_hf_config_file, - load_weights_from_hf_format, - load_weights_from_torchchat_format, -) - -from torchchat.distributed.logging_utils import SingletonLogger -from torchchat.distributed.utils import ( - bytes_to_readable, - Color as color, - CUDATrackTime, - get_module_size, - get_num_params, - GPUMemoryMonitor, -) -from torchchat.model import ModelArgs, Transformer, TransformerArgs -from torchchat.utils.build_utils import set_precision - -try: - from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer -except ImportError: - TiktokenTokenizer = None -try: - from sentencepiece import SentencePieceProcessor -except ImportError: - SentencePieceProcessor = None - - -logger = SingletonLogger.get_logger() - -# Using model name to identify the model to load, for example "llama2-7b-chat". -# You can change it to other values listed below. -# For details on the name-to-distribution mapping, see README.md or models.json. -NAME_TO_DISTRIBUTION_AND_DTYPE = { - "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16), - "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16), - "llama3.1": ("meta-llama/Meta-Llama-3.1-8B-Instruct", torch.bfloat16), -} - - -def _init_distributed(): - dist.init_process_group("nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - # Assuming same number of GPUs per node - torch.cuda.set_device(rank % torch.cuda.device_count()) - return rank, world_size - - -def _create_device_mesh(pp_degree, tp_degree): - return dist.init_device_mesh( - "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") - ) - - -def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace: - return SimpleNamespace(**dictionary) - - -def _patch_tokenizer(tokenizer): - """Patch the tokenizer to support decoding of token ids.""" - if isinstance(tokenizer, TiktokenTokenizer): - # Patch tiktokenizer to allow a list of sequences. 
- # TODO: Upstream to tokenizer modules - old_decode = tokenizer.decode - - def decode( - self, token_ids: List[int | List[int]], *args, **kwargs - ) -> str | List[str]: - if len(token_ids) < 1: - return "" - if isinstance(token_ids[0], list): - return [old_decode(t, *args, **kwargs) for t in token_ids] - else: - return old_decode(token_ids, *args, **kwargs) - - tokenizer.decode = MethodType(decode, tokenizer) - return tokenizer - - -def _build_chat_tokenizer( - tokenizer_args: TokenizerArgs, -) -> SentencePieceProcessor | TiktokenTokenizer: - """Builds a tokenizer for the given model name""" - - tokenizer_args = TokenizerArgs.from_args(tokenizer_args) - tokenizer = tokenizer_args.t - assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}" - logger.info( - f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}" - ) - - tokenizer = _patch_tokenizer(tokenizer) - - return tokenizer - - -def _load_model_weights( - stage_module: torch.nn.Module, - distribution: str, - device: torch.device, - model_config: ModelArgs, - chpt_from: str, -): - """Load the weights from the safetensor file(s) into the model stage. - Model config is needed b/c we permute wq and wk weights based on attn heads. - - Args: - stage_module (torch.nn.Module): The model stage to load the weights into. - distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct". - device (torch.device): The device to load the weights onto. - model_config (ModelArgs): The model config. - chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf". - """ - if chpt_from == "hf": - # This format stands for: index file + multiple binary files - load_weights_from_hf_format(stage_module, distribution, device, model_config) - elif chpt_from == "torchchat": - # This format stands for: - # single binary file, OR - # multiple binary files without index files. - load_weights_from_torchchat_format( - stage_module, distribution, device, model_config - ) - else: - raise ValueError(f"Unknown checkpoint format: {chpt_from}") - - -def _encode_strings( - strings: List[str], - tokenizer, - bos: bool, - device: torch.device, - dtype=torch.int64, -) -> List[torch.Tensor]: - """Encode a list of prompt strings into a list of tensor token ids.""" - encoded_list = [] - for string in strings: - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device)) - return encoded_list - - -def _create_padded_prompts( - input_ids_list: List[torch.Tensor], - tokenizer, - seqlen: int, - start_pos: int, - device: torch.device, - pad_token_id: Optional[int] = None, -) -> Tuple[torch.Tensor, List[int]]: - """ - Create a padded tensor for multiple encoded input prompts. - - Returns: - Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths. 
- """ - pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id() - - # Find the maximum prompt length - max_prompt_len = max(ids.size(0) for ids in input_ids_list) - - # Calculate the buffer size - max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len)) - token_buffer_size = max_prompt_len + max_new_tokens - - # Create the padded batch tensor - batch_size = len(input_ids_list) - batch_seq = torch.full( - (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device - ) - - prompt_lengths = [] - for i, input_ids in enumerate(input_ids_list): - prompt_len = input_ids.size(0) - batch_seq[i, :prompt_len] = input_ids - prompt_lengths.append(prompt_len) - - return batch_seq, prompt_lengths - - -def _batch_decode_next_tokens( - output: torch.Tensor, - pos: List[int] = None, - temperature: float = 1.0, - topk: int = 10, -) -> torch.Tensor: - """ - Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding. - - Args: - output (torch.Tensor): The output tensor to decode. - pos (List[int]): The positions of the `output` to decode in the sequence length dimension. - step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token. - temperature (float): Sampling temperature for non-deterministic decoding. - - Returns: - torch.Tensor: Decoded token ids. - """ - batch_size, seq_len, vocab_size = output.shape - - if pos is None: - # `pos` is not provided, so we can use the first token - next_token_logits = output[:, 0, :] - else: - # get the logits for each prompt at the specified positions - next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1] - - if temperature != 1.0: - next_token_logits = next_token_logits / temperature - - # Uses top-k sampling if temperature is not 1.0, otherwise use argmax - if temperature != 1.0: - top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size - top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1) - probs = torch.softmax(top_k_logits, dim=-1) - next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1) - next_tokens = top_k_indices.gather( - -1, next_token_indices.unsqueeze(-1) - ).squeeze(-1) - else: - # Argmax (deterministic) - next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) - - # Token ids in int tensor form - return next_tokens - - -def _update_padded_sequence( - padded_sequence: torch.Tensor, - new_token: torch.Tensor, - prompt_lengths: List[int], -) -> None: - for i in range(len(prompt_lengths)): - padded_sequence[i, prompt_lengths[i]] = new_token[i, 0] - # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") - - -# Decode token id into string and print it -def _decode_in_flight(token, tokenizer, tp_rank): - """decode token ids for all prompts in the batch and log them""" - # `token` is a tensor of shape (batch_size, 1). - # For TiktokenTokenizer, we need to squeeze it to 1D. - # For SentencePieceProcessor, we don't. 
- token_str = tokenizer.decode(token.tolist()) - # print the token string on tp rank 0 - if tp_rank == 0: - logger.info( - f"{color.green} responses ====>>>> " - f"{color.blue} {token_str} {color.reset}" - ) - return token_str - - -def _cleanup(): - dist.barrier() - dist.destroy_process_group() - - -prompts = [ - "What is Snow?", - # "Can you explain what is the purpose of back propagation in neural networks?", - "Who is Santa Claus?", - "Where does Santa live?", - "Who is Abraham Lincoln?", - # "How are models trained?", -] - - -def main( - model_name, - builder_args, - tokenizer_args, - pipe, -): - pp_degree = builder_args.pp - - rank, world_size = _init_distributed() - logger.info(f"Worker started: {rank=}, {world_size=}") - - gpu_memory_monitor = GPUMemoryMonitor("cuda") - logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") - - distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name] - logger.info(f"Using model weights from {distribution} and dtype {model_dtype}") - - # Model-level config - model_config = ModelArgs.from_name(distribution) - # Transformer-level config - config = TransformerArgs.from_params(model_config.transformer_args["text"]) - logger.info(f"Transformer Config: {config}") - - tokenizer = _build_chat_tokenizer(tokenizer_args) - - set_precision(model_dtype) - logger.info(f"Using cache precision {model_dtype}") - - hf_config = get_hf_config_file(distribution) - if hf_config is None: - raise ValueError(f"Config file not found for model id {distribution}") - - # Validate pipeline degree - assert world_size % pp_degree == 0 - assert config.n_layers % pp_degree == 0 - - # Tensor parallel is enabled in this program - tp_degree = world_size // pp_degree - - # Create device mesh - device_mesh = _create_device_mesh(pp_degree, tp_degree) - tp_mesh = device_mesh["tp"] - pp_mesh = device_mesh["pp"] - logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") - - tp_rank = tp_mesh.get_local_rank() - pp_rank = pp_mesh.get_local_rank() - tp_group = tp_mesh.get_group() - pp_group = pp_mesh.get_group() - logger.info(f"{pp_degree=}, {tp_degree=}") - - # Convenience variables - first_pp_rank = 0 - last_pp_rank = pp_degree - 1 - - # Assuming same number of GPUs per node - device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") - - # Fill in PP configs - config.stage_idx = pp_rank - config.n_stages = pp_degree - - with torch.device("meta"): - # TODO: we should create model instead of Transformer - model = Transformer(config) - - # Distribute model on TP mesh - # (Surprisingly, this works even though model is on meta device and mesh is of - # cuda devices) - model.distribute(tp_mesh) - if rank == 0: - logger.info(f"Model: {model}") - - # Load weights - logger.info(f"Loading weights for {pp_rank=} on {device=}") - with CUDATrackTime() as timer: - _load_model_weights(model, distribution, device, config, builder_args.chpt_from) - - logger.info( - f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" - ) - - # Batch size. Since we push batches dynamically through the pipeline rather - # than chunking them, this is effectively micro-batch size in pipeline - # sense. Thus it is interchangeable with micro-batch size below. 
- batch_size = 1 # len(prompt) - seqlen_prefill = 1024 # sequence length - dim = 4096 # embedding dimension - - # Setup KV caches (after model distribution) - # The number of cache lanes is the same as the maximum number of - # micro-batches that can be "in flight" in parallel -- imagine each - # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces. - # When decoding is done for certain micro-batches, we can reuse the KV cache - # lanes. - # TODO: bump up the lane count - pipeline_lanes = 1 - with device: - model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes) - - # info on stage size and params - stage_size = get_module_size(model) - stage_size_formatted = bytes_to_readable(stage_size) - stage_num_params = get_num_params(model) - logger.info( - f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}" - ) - model.eval() - - # Helper function to get example inputs and outputs for the stages. - def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: - mb_ids = torch.randint( - 0, config.vocab_size, (batch_size, seqlen), device=device - ) - activation = torch.rand( - batch_size, seqlen, dim, device=device, dtype=model_dtype - ) - logits = torch.rand( - batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype - ) - example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,) - example_outputs = (logits if pp_rank == last_pp_rank else activation,) - return example_inputs, example_outputs - - # Create prefill stage - logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}") - example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill) - prefill_stage = PipelineStage( - model, - pp_rank, - pp_degree, - device, - input_args=example_inputs, - output_args=example_outputs, - group=pp_group, - ) - - # Create schedule - # Number of micro-batches for the schedule is 1, because each step() call we - # only push 1 micro-batch into the pipeline. But we can continuously push - # new micro-batches into the pipeline as they arrive, achieving same - # pipelining effect. 
- prefiller = ScheduleGPipe(prefill_stage, 1) - - # Need these global ids due to the API definition of dist.send and recv - first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) - last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) - - pipe.send("ready") - - while True: - command = pipe.recv() - assert isinstance(command, (str, list)) - if isinstance(command, str): - if command == "stop": - break - else: - raise ValueError(f"Unknown command: {command}") - else: - prompt = command - assert ( - len(prompt) == batch_size - ), f"Expecting {batch_size=} prompts but got {len(prompt)=}" - logger.info(f"{color.green}Prompt: {prompt}{color.reset}") - - start_pos = 0 - # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen - input_pos = torch.arange(seqlen_prefill, device=device) - - # encode the prompt - input_ids = _encode_strings( - prompt, tokenizer, bos=True, device=device, dtype=torch.int64 - ) - - # create a padded tensor for the input prompt - padded_sequence, prompt_lengths = _create_padded_prompts( - input_ids, tokenizer, seqlen_prefill, start_pos, device - ) - - # New token generated each iteration - # need a row dimension for each prompt in the batch - new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64) - # Store the generated tokens - res = [] - - # Prefill phase - # Run context input through pipeline - # TODO: we need to pass `input_pos` and `cache_lane` to each stage. - lane = 0 - kwargs = {"input_pos": input_pos, "cache_lane": lane} - with torch.no_grad(), CUDATrackTime() as timer: - if pp_rank == first_pp_rank: - output = prefiller.step(padded_sequence, **kwargs) - elif pp_rank == last_pp_rank: - output = prefiller.step(**kwargs) - else: # middle pp ranks - prefiller.step(**kwargs) - - logger.info( - f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" - ) - - # Decode the output -- first generated token - if pp_rank == last_pp_rank: - logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}") - new_token = _batch_decode_next_tokens(output, prompt_lengths) - res.append(new_token) - # TODO: Move to a separate decoding thread - resp = _decode_in_flight(new_token, tokenizer, tp_rank) - pipe.send((resp, new_token.tolist())) - else: - pipe.send(None) - - # seqlen = 1 now - seqlen_decode = 1 - input_pos = torch.tensor([prompt_lengths[0]], device=device) - - # Create decode stage - logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}") - example_inputs, example_outputs = get_example_ins_outs(seqlen_decode) - decode_stage = PipelineStage( - model, - pp_rank, - pp_degree, - device, - input_args=example_inputs, - output_args=example_outputs, - group=pp_group, - ) - # create schedule - decoder = ScheduleGPipe(decode_stage, 1) - - # Decoding - with torch.no_grad(), CUDATrackTime() as timer: - while True: - command = pipe.recv() - assert isinstance(command, str) - if command == "stop": - break - elif command == "step": - pass - else: - raise ValueError(f"Unknown command: {command}") - - kwargs = {"input_pos": input_pos, "cache_lane": lane} - # sendrecv between last and first ranks, only if: - # first_pp_rank != last_pp_rank. 
- if pp_rank == last_pp_rank and pp_rank != first_pp_rank: - dist.send( - new_token, - dst=first_pp_rank_global_id, - group=pp_group, - ) - elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: - dist.recv( - new_token, - src=last_pp_rank_global_id, - group=pp_group, - ) - - # Run data through pipeline - if pp_rank == first_pp_rank: - output = decoder.step(new_token, **kwargs) - elif pp_rank == last_pp_rank: - output = decoder.step(**kwargs) - else: # middle pp ranks - decoder.step(**kwargs) - - # Decode the output - if pp_rank == last_pp_rank: - new_token = _batch_decode_next_tokens(output) - res.append(new_token) - # TODO: Move to a separate decoding thread - resp = _decode_in_flight(new_token, tokenizer, tp_rank) - pipe.send((resp, new_token)) - else: - pipe.send(None) - - # Increment input position - input_pos += 1 - - logger.info( - f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" - ) - - # Display the decoding results - - # output formatted response via last pp group and tp rank 0 - if pp_rank == last_pp_rank and tp_rank == 0: - # `res` is a list of tensors, each being a batch of generated token ids. - # We need to concatenate them to get the full sequence of generated - # token ids. Thus cat'ing along dim 1. - res = torch.cat(res, dim=1) - res_list = res.tolist() - - responses = tokenizer.decode(res_list) - - # Show prompts and responses - for prompt_text, response_text in zip(prompt, responses): - logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}") - logger.info(f"Response: {color.red}{response_text} {color.reset}") - - # Cleanup - _cleanup() - logger.info( - f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}" - ) - -# TODO: remove or make it work again -# if __name__ == "__main__": -# parser = argparse.ArgumentParser() -# parser.add_argument( -# "model_name", -# type=str, -# default="llama3", -# help="Name of the model to load", -# choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), -# ) -# parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") -# parser.add_argument( -# "--ntokens", -# type=int, -# default=40, -# help="Number of tokens to generate", -# ) -# parser.add_argument( -# "--chpt-from", -# type=str, -# default="hf", # TODO: change to torchchat once we support it well -# help="Checkpoint format to load from", -# choices=["hf", "torchchat"], -# ) -# args = parser.parse_args() - -# main() diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py deleted file mode 100644 index 51c472e4a..000000000 --- a/torchchat/distributed/generate.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-import asyncio -import atexit -import importlib.util -import subprocess -import threading -from abc import abstractmethod -from collections import deque -from dataclasses import dataclass -from functools import partial -from os import environ -from pathlib import Path -from typing import List, Optional -from uuid import uuid4 - -import torch.multiprocessing as mp -from torchchat.cli.builder import BuilderArgs, TokenizerArgs -from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE -from torchchat.distributed.logging_utils import SingletonLogger - -logger = SingletonLogger.get_logger() - - -def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): - environ["MASTER_ADDR"] = "localhost" - environ["MASTER_PORT"] = "29500" - environ["RDZV_BACKEND"] = "c10d" - environ["WORLD_SIZE"] = str(world_size) - environ["RANK"] = str(rank) - environ["LOCALRANK"] = str(rank) - - return target(*args, **kwargs) - - -def _launch_distributed_inference( - model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs -) -> tuple[List]: - # launch distributed inference worker, each worker gets a pipe to communicate with the main process - logger.info("Launching distributed inference ...") - - num_processes_per_node = builder_args.pp * builder_args.tp - - from torchchat.distributed.dist_run import main - - mp.set_start_method("spawn") - - pipes = [] - procs = [] - try: - for rank in range(num_processes_per_node): - server_pipe, client_pipe = mp.Pipe(duplex=True) - pipes.append(server_pipe) - procs.append( - mp.Process( - target=partial(_setup_env, num_processes_per_node, rank, main), - args=(model_name, builder_args, tokenizer_args, client_pipe), - ) - ) - procs[-1].start() - - for pipe in pipes: - assert pipe.recv() == "ready", "Starting the worker failed" - except Exception as e: - logger.error(f"Error during distributed inference: {str(e)}") - for p in procs: - p.kill() - raise e - - logger.info( - f"Done launching distributed inference on {num_processes_per_node} GPUs." 
- ) - return procs, pipes - - -@dataclass -class Output: - is_finished: bool = False - text: Optional[str] = None - token: Optional[list] = None - - -@dataclass -class Request: - request_id: int - prompt: str - - @classmethod - def new_request(cls, prompt): - return cls(request_id=uuid4().int, prompt=prompt) - - -class Scheduler(object): - def __init__( - self, - builder_args, - generator_args, - pipes, - loop, - ): - self.builder_args = builder_args - self.generator_args = generator_args - self.requests = {} - self.in_flight_requests = {} - self.in_flight_batch_order = [] - self.pipes = pipes - self.req_to_states = {} - self.req_to_results = {} - self.request_queue = mp.Queue() - self.loop = loop - - def schedule_request(self, req: Request): - # add request to queue and create deque and async event for response - self.req_to_states[req.request_id] = asyncio.Event() - self.req_to_results[req.request_id] = deque() - self.request_queue.put(req) - - def process_requests_loop(self): - # Continuously process requests (one at a time for now), results are routed into the requests deque - while True: - req = self.request_queue.get() - if req == "stop": - break - self.requests = {req.request_id: req.prompt} - - responses = {} - running = True - while running: - outputs = self.step() - self.req_to_results[req.request_id].append(outputs[0]) - - self.loop.call_soon_threadsafe(self.req_to_states[req.request_id].set) - - running &= not outputs[0].is_finished - - async def wait_for_request(self, req: Request) -> Output: - # Wait for request to deliver result, uses event to trigger and reads from left side of deque - is_finished = False - while not is_finished: - await self.req_to_states[req.request_id].wait() - while len(self.req_to_results[req.request_id]): - output = self.req_to_results[req.request_id].popleft() - is_finished |= output.is_finished - yield output - del self.req_to_states[req.request_id] - del self.req_to_results[req.request_id] - - def step(self) -> List[Output]: - # Make a prefill or decoding step and receive results - responses = [] - # TODO: Implement a scheduler to handle the requests - if len(self.in_flight_requests) > 0: - # Receive decoded token - for p in self.pipes: - p.send("step") - for p in self.pipes: - responses.append(p.recv()) - - else: - # Send requests to backend - self.in_flight_batch_order = list(self.requests.keys()) - prompts = [self.requests[k] for k in self.in_flight_batch_order] - for p in self.pipes: - p.send(prompts) - self.in_flight_requests = self.requests - self.requests = {} - self.current_step = 0 - # Receive first token - for p in self.pipes: - responses.append(p.recv()) - # Filter out None responses from in-between stages - responses = [r for r in responses if r is not None][0] - outputs = [] - for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])): - text, token_ids = v - outputs.append( - Output( - # TODO: Look for tokenizer.eos_id as well - is_finished=self.current_step >= self.generator_args.max_new_tokens, - text=text, - token=token_ids, - ) - ) - if self.current_step >= self.generator_args.max_new_tokens: - for p in self.pipes: - p.send("stop") - self.in_flight_requests = [] - - self.current_step += 1 - - return outputs - - -class DistributedGenerator(object): - def __init__( - self, - # TODO: switch this to torchchat method - model_name: str, - builder_args: BuilderArgs, - tokenizer_args: TokenizerArgs, - # TODO: move GeneratorArgs into a different module - generator_args, - profile: Optional[Path], - quantize: bool, - 
draft_quantize: bool, - ): - self.model_name = model_name - self.builder_args = builder_args - self.generate_args = generator_args - - self.check_args() - - self.procs, self.pipes = _launch_distributed_inference( - model_name, builder_args, tokenizer_args - ) - - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop) - - # TODO: Mode into process and use pipe or queue for comm - self.scheduler_thread = threading.Thread( - target=self.scheduler.process_requests_loop - ) - self.scheduler_thread.start() - - atexit.register(self.shutdown) - - def shutdown(self): - # Stop all processes and threads - self.scheduler.request_queue.put("stop") - self.scheduler_thread.join() - - for p in self.pipes: - p.send("stop") - for p in self.procs: - p.kill() - - def generate(self, text): - # Function to generate text from prompt - req = Request.new_request(text) - self.scheduler.schedule_request(req) - - generator = self.scheduler.wait_for_request(req) - - running = True - while running: - output = self.loop.run_until_complete(generator.__anext__()) - running &= not output.is_finished - - yield output - - def check_args(self): - if self.generate_args.chat_mode: - raise NotImplementedError( - "Currently we only support generate with --distributed" - ) - elif self.builder_args.tp < 2: - raise ValueError("TP degree must be at least 2 for distributed inference") - elif self.model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE.keys(): - raise ValueError( - f"Distributed inference currently only supports then following models: {list(NAME_TO_DISTRIBUTION_AND_DTYPE.keys())}" - ) - elif self.builder_args.chpt_from == "torchchat": - raise ValueError( - f"Distributed inference currently only supports HF checkpoints" - ) From d5bca9b4beb5838a50099acd1b567beb192ba22d Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:48:34 -0800 Subject: [PATCH 03/13] Read distribution from model_config --- torchchat/cli/builder.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 6ea5ff25f..867bde34b 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -64,6 +64,7 @@ class BuilderArgs: pp: int = 1 tp: int = 1 chpt_from: str = "hf" + distribution_path: Optional[str] = None is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False @@ -129,6 +130,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": model_config.transformer_params_key or model_config.name.split("/")[-1] ) + distribution_path = model_config.distribution_path + dso_path = getattr(args, "dso_path", None) pte_path = getattr(args, "pte_path", None) aoti_package_path = getattr(args, "aoti_package_path", None) @@ -194,6 +197,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": pp=pp, tp=tp, chpt_from=chpt_from, + distribution_path=distribution_path, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), @@ -607,23 +611,6 @@ def do_nothing(max_batch_size, max_seq_length): except Exception: raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") elif builder_args.distributed: - # Using params_table to identify the model to load, for example "Meta-Llama-3.1-8B". 
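The lookup table removed in this hunk is superseded by the `distribution_path` field added to `BuilderArgs` above, so model identification lives in one place (the model config) instead of being duplicated next to the distributed loading code. Roughly, the pattern looks like the following sketch; `BuilderArgsSketch` and `with_distribution_path` are illustrative names, not torchchat APIs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BuilderArgsSketch:  # illustrative stand-in for torchchat's BuilderArgs
    pp: int = 1
    tp: int = 1
    chpt_from: str = "hf"
    distribution_path: Optional[str] = None


def with_distribution_path(model_config) -> BuilderArgsSketch:
    # A named, well-known model config is assumed to expose `distribution_path`,
    # e.g. "meta-llama/Meta-Llama-3.1-8B-Instruct"; default to None so
    # non-distributed builds are unaffected.
    return BuilderArgsSketch(
        distribution_path=getattr(model_config, "distribution_path", None)
    )
```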
- #TODO This is a hacky way to please the distributed loading api and needs to be replaced - NAME_TO_DISTRIBUTION = { - "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct", - "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "Meta-Llama-3-70B": "meta-llama/Meta-Llama-3-70B-Instruct", - "Meta-Llama-3.1-70B": "meta-llama/Meta-Llama-3.1-70B-Instruct", - - } - # TODO: Use information in builder_args directly to build model and load weights - assert builder_args.params_table - try: - distribution = NAME_TO_DISTRIBUTION[builder_args.params_table] - except KeyError as e: - print(f"Unknown params_table: {builder_args.params_table}. Suported model names are: llama3.1, llama3, llama2-7b-chat") - raise e - pp_degree = builder_args.pp tp_degree = builder_args.tp @@ -687,7 +674,7 @@ def do_nothing(max_batch_size, max_seq_length): # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") with CUDATrackTime() as timer: - load_model_weights(model, distribution, device, config, builder_args.chpt_from) + load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from) logger.info( f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" From 76895cc870e5997e9cc05a3dadb5e03c9ba07f52 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:58:14 -0800 Subject: [PATCH 04/13] Declare distribution_path if args.model is not given --- torchchat/cli/builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 867bde34b..90995e656 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -116,6 +116,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": checkpoint_path = args.checkpoint_path params_table = args.params_table + distribution_path = None if args.model: # Using a named, well-known model model_config = resolve_model_config(args.model) From 3ef1296a917c7dd8c16b038a1a0d84eb21f73f4c Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Tue, 19 Nov 2024 14:50:22 -0800 Subject: [PATCH 05/13] Address some nits from PR review --- torchchat/usages/openai_api.py | 13 ++++--------- torchchat/usages/server.py | 5 +---- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index b1ad151f1..b67cd0eac 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from io import BytesIO from pwd import getpwuid -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Type import torch @@ -486,18 +486,13 @@ def _callback(self, x, *, buffer, done_generating): pass -def create_openai_api_generator(distributed): +def create_openai_api_generator(distributed: bool) -> Type: """ Factory method to create an OpenAiApiGenerator """ - if distributed: - # Base class order matters to make sure OpenAiApiGeneratorMixin overrides methods in DistributedGenerator and Generator - return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, DistributedGenerator), {}) - else: - return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, LocalGenerator), {}) - - + # Base class order matters to make sure OpenAiApiGeneratorMixin overrides methods in DistributedGenerator and Generator + return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, DistributedGenerator 
if distributed else LocalGenerator), {}) """ diff --git a/torchchat/usages/server.py b/torchchat/usages/server.py index afbfeebd7..7ffe1371a 100644 --- a/torchchat/usages/server.py +++ b/torchchat/usages/server.py @@ -67,14 +67,11 @@ def create_app(args): # noqa: C901 builder_args = BuilderArgs.from_args(args) procs = [] + queue = None if builder_args.distributed: world_size = builder_args.tp * builder_args.pp mp_context = mp.get_context('spawn') queue = mp_context.Queue() - else: - world_size = 1 - queue = None - if builder_args.distributed: for i in range(1, world_size): From 7cb98c9aba5fdd5c391825c2139ee090a81afd06 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 16 Dec 2024 11:02:05 -0800 Subject: [PATCH 06/13] Added comment on model size all reduce + type hint --- torchchat/generate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 1a137c6e9..42904eeb5 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -260,7 +260,7 @@ def __init__( self.draft_quantize = draft_quantize self.is_torchtune_model = generator_args.is_torchtune_model self.dtype = builder_args.precision - self.get_user_input = input + self.get_user_input : Callable = input self.rank: Optional[int] = None @@ -921,6 +921,8 @@ def chat( ] ) if self.builder_args.distributed: + # During distributed inference the model gets sharded among the ranks + # So we need to all reduce the model size to get the total model size model_size = torch.tensor(model_size, dtype=torch.int64, device=self.device) dist.all_reduce(model_size) model_size = model_size.item() @@ -1257,7 +1259,7 @@ def distributed_input(prompt: str) -> str: dist.broadcast_object_list(text) return text[0] - self.get_user_input = distributed_input + self.get_user_input: Callable = distributed_input if builder_args.pp > 1: self.seqlen_prefill = 1024 # sequence length for prefill stage From 10fb55a29d60479df4f9e5c292a88dbaec0e69d1 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 16 Dec 2024 11:07:20 -0800 Subject: [PATCH 07/13] Apply suggestions from code review Co-authored-by: Jack-Khuu --- torchchat/generate.py | 42 ++++++++++++++++++------------------------ 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 42904eeb5..ea4c98da5 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1287,7 +1287,7 @@ def __del__(self): dist.destroy_process_group() # Helper function to get example inputs and outputs for the stages. - def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: + def get_example_ins_outs(self, batch_size: int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: """ This function generates example inputs and outputs for the prefill and decode stages. @@ -1308,9 +1308,7 @@ def get_example_ins_outs(self, batch_size:int , seqlen: int) -> Tuple[torch.Tens example_outputs = (logits if self.pp_rank == self.last_pp_rank else activation,) return example_inputs, example_outputs - def create_prefill_stage( - self, - ): + def create_prefill_stage(self): """ Creates a pipeline stage for prefilling. @@ -1340,9 +1338,7 @@ def create_prefill_stage( prefiller = ScheduleGPipe(prefill_stage, 1) return prefiller - def create_decode_stage( - self, - ): + def create_decode_stage(self): """ Creates a decode stage for the pipeline parallelism. 
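The prefill schedule above wraps a single `PipelineStage` in a `ScheduleGPipe` with one micro-batch. A minimal sketch of that construction follows, assuming `torch.distributed.pipelining` (PyTorch 2.4 or newer), an already-initialized process group, and placeholder names (`stage_module`, `pp_rank`, `pp_degree`, `pp_group`).

```python
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe


def build_prefill_schedule(stage_module, pp_rank, pp_degree, device, pp_group):
    # Wrap this rank's model shard in a pipeline stage...
    stage = PipelineStage(
        stage_module,   # the shard of the model owned by this rank
        pp_rank,        # index of this stage in the pipeline
        pp_degree,      # total number of stages
        device,
        group=pp_group,
    )
    # ...and drive it with a GPipe schedule. A single micro-batch is enough
    # for prefill, since the whole prompt is processed as one chunk.
    return ScheduleGPipe(stage, 1)
```

Because the surrounding code builds the model on the meta device, it also derives example inputs and outputs (`get_example_ins_outs`) so the schedule has concrete shapes to work with; the sketch omits that step.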
@@ -1422,24 +1418,22 @@ def prefill( else: # middle pp ranks self.prefiller.step(**kwargs) - new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64) - if self.pp_rank == self.last_pp_rank: new_token = self.sample(logits[:,:prompt_length], need_probs=False, **sampling_kwargs)[0] - - - if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank: - dist.send( - new_token, - dst=self.first_pp_rank_global_id, - group=self.pp_group, - ) - elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank: - dist.recv( - new_token, - src=self.last_pp_rank_global_id, - group=self.pp_group, - ) + if self.pp_rank != self.first_pp_rank: + dist.send( + new_token, + dst=self.first_pp_rank_global_id, + group=self.pp_group, + ) + else: + new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64) + if self.pp_rank == self.first_pp_rank: + dist.recv( + new_token, + src=self.last_pp_rank_global_id, + group=self.pp_group, + ) return new_token @@ -1485,7 +1479,7 @@ def decode_one_token( # Decode the output if self.pp_rank == self.last_pp_rank: - new_token, next_prob = self.sample(logits, need_probs=need_probs, **sampling_kwargs) + new_token, _ = self.sample(logits, need_probs=need_probs, **sampling_kwargs) else: new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64) From 28d7836577db028b0e040810e58edee6b30f60c6 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 16 Dec 2024 20:18:43 -0800 Subject: [PATCH 08/13] Make sure speculative decoding is disable for pp >1 and remark this in the comments as well --- torchchat/generate.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index ea4c98da5..d1561714e 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1237,6 +1237,9 @@ def __init__( quantize: bool, draft_quantize: bool, ): + + is_speculative = speculative_builder_args.checkpoint_path is not None + assert is_speculative == False, "Distributed inference with pp > 1 does not support speculative inference yet." super().__init__( builder_args, speculative_builder_args, @@ -1449,8 +1452,9 @@ def decode_one_token( """ Decodes a single token. + # TODO: implement speculative decoding with pp>1 Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing the decoded token and its probability. + Tuple[torch.Tensor, None]: A tuple containing the decoded token and None. 
""" if self.builder_args.pp == 1: return super().decode_one_token( @@ -1511,9 +1515,7 @@ def sample( return (idx_next, None) probs = self.logits_to_probs(logits[0, -1], temperature, top_k) idx_next = self.multinomial_sample_one_no_sync(probs) - if self.builder_args.pp == 1: - dist.broadcast(idx_next, src=0) - dist.broadcast(probs, src=0) + return idx_next, probs From 68eec0bd99640ed9db18e4411242986e2fca4d5c Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 16 Dec 2024 20:23:09 -0800 Subject: [PATCH 09/13] Refactor conditions in pp --- torchchat/generate.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index d1561714e..484be641f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1484,23 +1484,22 @@ def decode_one_token( # Decode the output if self.pp_rank == self.last_pp_rank: new_token, _ = self.sample(logits, need_probs=need_probs, **sampling_kwargs) + if self.pp_rank != self.first_pp_rank: + dist.send( + new_token, + dst=self.first_pp_rank_global_id, + group=self.pp_group, + ) else: new_token = torch.zeros(1, 1, device=self.device, dtype=torch.int64) - - if self.pp_rank == self.last_pp_rank and self.pp_rank != self.first_pp_rank: - dist.send( - new_token, - dst=self.first_pp_rank_global_id, - group=self.pp_group, - ) - elif self.pp_rank == self.first_pp_rank and self.pp_rank != self.last_pp_rank: - dist.recv( - new_token, - src=self.last_pp_rank_global_id, - group=self.pp_group, - ) - #TODO: Why do we get 2d tensor here? - new_token=new_token[0] + if self.pp_rank == self.first_pp_rank: + dist.recv( + new_token, + src=self.last_pp_rank_global_id, + group=self.pp_group, + ) + #TODO: Why do we get 2d tensor here? 
+ new_token=new_token[0] return new_token, None def sample( From 3ad31e8db3705a51fd38fcc7400fe08b4c63d72b Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 16 Dec 2024 21:06:15 -0800 Subject: [PATCH 10/13] Rename and alter signature of setup_env to reflect that it also runs the target --- torchchat/distributed/utils.py | 4 ++-- torchchat/generate.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchchat/distributed/utils.py b/torchchat/distributed/utils.py index e935226b0..85bfe04fc 100644 --- a/torchchat/distributed/utils.py +++ b/torchchat/distributed/utils.py @@ -258,7 +258,7 @@ def get_device_info( ) return device_info -def setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): +def run_in_dist_env(world_size: int, rank: int, target: callable): environ["MASTER_ADDR"] = "localhost" environ["MASTER_PORT"] = "29500" environ["RDZV_BACKEND"] = "c10d" @@ -266,4 +266,4 @@ def setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): environ["RANK"] = str(rank) environ["LOCALRANK"] = str(rank) - return target(*args, **kwargs) + return target() diff --git a/torchchat/generate.py b/torchchat/generate.py index 484be641f..65c77d31b 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -45,7 +45,7 @@ ) from torchchat.distributed.utils import ( Color as color, - setup_env, + run_in_dist_env, ) from torchchat.model import Model, ModelType from torchchat.utils.build_utils import device_sync, set_precision @@ -1565,9 +1565,9 @@ def main(args): with futures.ProcessPoolExecutor(max_workers=world_size-1, mp_context=ctx) as executor: for i in range(1,world_size): fn = partial(run_generator, args, i) - executor.submit(setup_env, world_size, i, fn) + executor.submit(run_in_dist_env, world_size, i, fn) #Starting rank 0 fn = partial(run_generator, args, 0) - setup_env(world_size, 0, fn) + run_in_dist_env(world_size, 0, fn) else: run_generator(args) From e07b03d6837b01a9124ae229e7b3b5a693f3ad28 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 16 Dec 2024 21:14:09 -0800 Subject: [PATCH 11/13] Rename setup_env in server + fix condition --- torchchat/usages/server.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchchat/usages/server.py b/torchchat/usages/server.py index 7ffe1371a..550539a88 100644 --- a/torchchat/usages/server.py +++ b/torchchat/usages/server.py @@ -24,7 +24,7 @@ from flask import Flask, request, Response from torchchat.cli.builder import BuilderArgs, TokenizerArgs -from torchchat.distributed.utils import setup_env +from torchchat.distributed.utils import run_in_dist_env from torchchat.generate import GeneratorArgs from torchchat.usages.openai_api import ( @@ -73,11 +73,10 @@ def create_app(args): # noqa: C901 mp_context = mp.get_context('spawn') queue = mp_context.Queue() - if builder_args.distributed: for i in range(1, world_size): fn = partial(run_worker, args, i, queue) mp_context = mp.get_context('spawn') - procs.append(mp_context.Process(target=setup_env, args=(world_size, i, fn))) + procs.append(mp_context.Process(target=run_in_dist_env, args=(world_size, i, fn))) procs[-1].start() environ["MASTER_ADDR"] = "localhost" From 7ac16f9ecd8131f8a25eea44756742031d7571a0 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 19 Dec 2024 02:44:42 -0800 Subject: [PATCH 12/13] Update generate.py --- torchchat/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/torchchat/generate.py b/torchchat/generate.py index 4d9221161..354585547 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1229,7 +1229,7 @@ def callback(x, *, done_generating=False): aggregate_metrics["first_token_per_sec"].append(first_token_sec) aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) - logger.info( + logging.info( f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ \nGenerated {num_tokens_generated} tokens \ \nTime for inference {i + 1}: {t:.04f} sec total \ @@ -1240,11 +1240,11 @@ def callback(x, *, done_generating=False): \n Next token throughput: {next_tokens_sec:.04f} tokens/sec, {1 / next_tokens_sec:.04f} s/token \ " ) - logger.info( + logging.info( f"\nBandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" ) if i == 0: - logger.info( + logging.info( f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***" ) print("\n========================================\n") From 765015357dd90fccff572a877db15e7fe4a1b9ed Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:11:44 -0800 Subject: [PATCH 13/13] Add default value to add_generation_prompt to preserve bc --- torchchat/generate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 354585547..f29b96615 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -85,7 +85,7 @@ def __init__(self, tokenizer): def encode_dialog_prompt( self, dialog: DIALOG_TYPE, - add_generation_prompt: bool, + add_generation_prompt: bool = True, ) -> List[int]: """Encode a sequence of messages into a sequence of token IDs, including the chat template @@ -136,7 +136,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]: def encode_dialog_prompt( self, dialog: _ChatFormatter.DIALOG_TYPE, - add_generation_prompt: bool, + add_generation_prompt: bool = True, ) -> List[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) @@ -166,7 +166,7 @@ def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str: def encode_dialog_prompt( self, dialog: _ChatFormatter.DIALOG_TYPE, - add_generation_prompt: bool, # UNUSED + add_generation_prompt: bool = True, # UNUSED ) -> List[int]: new_turn = True tokens = [] @@ -197,7 +197,7 @@ class HFTokenizerChatFormatter(_ChatFormatter): def encode_dialog_prompt( self, dialog: _ChatFormatter.DIALOG_TYPE, - add_generation_prompt: bool, + add_generation_prompt: bool = True, ) -> List[int]: rendered = self.tokenizer.apply_chat_template( dialog, add_generation_prompt=add_generation_prompt