Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added recipes for registry
Browse files Browse the repository at this point in the history
jessicazhongeee committed Jan 17, 2025
1 parent 5ad117b commit 68aee31
Showing 5 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion recipes/configs/llama3/70B_generation_distributed.yaml
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ output_dir: ./
model:
_component_: torchtune.models.llama3.llama3_70b

tensor_parallel_plan:
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/70B_generation_distributed.yaml
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ output_dir: ./
model:
_component_: torchtune.models.llama3_1.llama3_1_70b

tensor_parallel_plan:
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
2 changes: 1 addition & 1 deletion recipes/configs/llama3_3/70B_generation_distributed.yaml
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ output_dir: ./
model:
_component_: torchtune.models.llama3_3.llama3_3_70b

tensor_parallel_plan:
parallelize_plan:
_component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
2 changes: 1 addition & 1 deletion recipes/dev/generate_v2_distributed.py
Original file line number Diff line number Diff line change
@@ -109,7 +109,7 @@ def setup(self, cfg: DictConfig) -> None:
parallelize_module(
model,
tp_device_mesh,
parallelize_plan=config.instantiate(cfg.tensor_parallel_plan),
parallelize_plan=config.instantiate(cfg.parallelize_plan),
)

with training.set_default_dtype(self._dtype), self._device:
8 changes: 8 additions & 0 deletions torchtune/_recipe_registry.py
Original file line number Diff line number Diff line change
@@ -441,6 +441,14 @@ class Recipe:
name="llama3/70B_generation_distributed",
file_path="llama3/70B_generation_distributed.yaml",
),
Config(
name="llama3_1/70B_generation_distributed",
file_path="llama3_1/70B_generation_distributed.yaml",
),
Config(
name="llama3_3/70B_generation_distributed",
file_path="llama3_3/70B_generation_distributed.yaml",
),
],
supports_distributed=True,
),

0 comments on commit 68aee31

Please sign in to comment.