From 68aee3126ff82d436229e70b89d8ebc18b567b4e Mon Sep 17 00:00:00 2001 From: JessicaZhong Date: Fri, 17 Jan 2025 14:02:45 -0800 Subject: [PATCH] added recipes for registry --- recipes/configs/llama3/70B_generation_distributed.yaml | 2 +- recipes/configs/llama3_1/70B_generation_distributed.yaml | 2 +- recipes/configs/llama3_3/70B_generation_distributed.yaml | 2 +- recipes/dev/generate_v2_distributed.py | 2 +- torchtune/_recipe_registry.py | 8 ++++++++ 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/recipes/configs/llama3/70B_generation_distributed.yaml b/recipes/configs/llama3/70B_generation_distributed.yaml index 6316f77abe..78c77ba263 100644 --- a/recipes/configs/llama3/70B_generation_distributed.yaml +++ b/recipes/configs/llama3/70B_generation_distributed.yaml @@ -13,7 +13,7 @@ output_dir: ./ model: _component_: torchtune.models.llama3.llama3_70b -tensor_parallel_plan: +parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan # Transform arguments diff --git a/recipes/configs/llama3_1/70B_generation_distributed.yaml b/recipes/configs/llama3_1/70B_generation_distributed.yaml index 1967b90423..d71a94f8de 100644 --- a/recipes/configs/llama3_1/70B_generation_distributed.yaml +++ b/recipes/configs/llama3_1/70B_generation_distributed.yaml @@ -13,7 +13,7 @@ output_dir: ./ model: _component_: torchtune.models.llama3_1.llama3_1_70b -tensor_parallel_plan: +parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan # Transform arguments diff --git a/recipes/configs/llama3_3/70B_generation_distributed.yaml b/recipes/configs/llama3_3/70B_generation_distributed.yaml index 31b17f55f2..d39acf45ad 100644 --- a/recipes/configs/llama3_3/70B_generation_distributed.yaml +++ b/recipes/configs/llama3_3/70B_generation_distributed.yaml @@ -13,7 +13,7 @@ output_dir: ./ model: _component_: torchtune.models.llama3_3.llama3_3_70b -tensor_parallel_plan: +parallelize_plan: _component_: torchtune.models.llama3.base_llama_tp_plan # Transform arguments diff --git a/recipes/dev/generate_v2_distributed.py b/recipes/dev/generate_v2_distributed.py index 21fabb1289..b9fc5b4ddb 100644 --- a/recipes/dev/generate_v2_distributed.py +++ b/recipes/dev/generate_v2_distributed.py @@ -109,7 +109,7 @@ def setup(self, cfg: DictConfig) -> None: parallelize_module( model, tp_device_mesh, - parallelize_plan=config.instantiate(cfg.tensor_parallel_plan), + parallelize_plan=config.instantiate(cfg.parallelize_plan), ) with training.set_default_dtype(self._dtype), self._device: diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index acf23a0949..1c41519712 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -441,6 +441,14 @@ class Recipe: name="llama3/70B_generation_distributed", file_path="llama3/70B_generation_distributed.yaml", ), + Config( + name="llama3_1/70B_generation_distributed", + file_path="llama3_1/70B_generation_distributed.yaml", + ), + Config( + name="llama3_3/70B_generation_distributed", + file_path="llama3_3/70B_generation_distributed.yaml", + ), ], supports_distributed=True, ),