Integrate distributed inference into torchchat cli #1327

Merged

lessw2020 merged 25 commits into pytorch:main on Oct 25, 2024

Conversation

@mreso (Contributor) commented on Oct 24, 2024:

This PR is the first step toward integrating distributed inference into the torchchat CLI. Currently, only `torchchat.py generate` is supported.
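
For orientation, here is a minimal sketch of how a `--distributed` flag on the generate path can hand off to a multi-process launcher. All names below are hypothetical illustrations of the pattern, not the PR's actual wiring:

```python
# Hypothetical sketch: dispatching `generate --distributed` to a
# torch.multiprocessing launcher. Names are illustrative only.
import argparse

import torch.multiprocessing as mp


def main_worker(rank: int, world_size: int, args: argparse.Namespace) -> None:
    # Each spawned process initializes torch.distributed and runs its
    # pipeline-/tensor-parallel shard of the model.
    ...


def generate(args: argparse.Namespace) -> None:
    if args.distributed:
        world_size = 2  # e.g. one process per GPU
        print("Launching distributed inference ...")
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
        print(f"Done launching distributed inference on {world_size} GPUs.")
    else:
        ...  # existing single-device generate path
```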

Test:
python torchchat.py generate llama3.1 --distributed --max-new-tokens 40 --prompt "What is Snow?"
Output:

Launching distributed inference ...
10-24 12:48:32.449 - dist_run:304 - Worker started: rank=0, world_size=2
10-24 12:48:32.449 - dist_run:307 -  GPU capacity: NVIDIA PG509-210 (None) with 79.15GiB memory
10-24 12:48:32.449 - dist_run:310 - Using model weights from meta-llama/Meta-Llama-3.1-8B-Instruct and dtype torch.bfloat16
known configs: ['13B', '30B', '34B', '70B', '7B', 'CodeLlama-7b-Python-hf', 'Llama-3.2-11B-Vision', 'Llama-Guard-3-1B-INT4', 'Llama-Guard-3-1B', 'Meta-Llama-3-70B', 'Meta-Llama-3-8B', 'Meta-Llama-3.1-70B-Tune', 'Meta-Llama-3.1-70B', 'Meta-Llama-3.1-8B-Tune', 'Meta-Llama-3.1-8B', 'Meta-Llama-3.2-1B', 'Meta-Llama-3.2-3B', 'Mistral-7B', 'llava-1.5', 'stories110M', 'stories15M', 'stories42M']
10-24 12:48:32.450 - dist_run:316 - Transformer Config: TransformerArgs(block_size=131072, vocab_size=128256, n_layers=32, n_heads=32, dim=4096, hidden_dim=14336, n_local_heads=8, head_dim=128, rope_base=500000.0, norm_eps=1e-05, multiple_of=1024, ffn_dim_multiplier=1.3, use_tiktoken=True, max_seq_length=8192, rope_scaling={'factor': 8.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_position_embeddings': 8192}, n_stages=1, stage_idx=0, attention_bias=False, feed_forward_bias=False, tie_word_embeddings=False)
10-24 12:48:32.693 - dist_run:112 - using tokenizer = tokenizer.tiktoken.Tokenizer
10-24 12:48:32.693 - dist_run:321 - Using cache precision torch.bfloat16
10-24 12:48:32.879 - dist_run:338 - Created device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('pp', 'tp'))
tp_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)), pp_mesh=DeviceMesh('cuda', [0], mesh_dim_names=('pp',))
10-24 12:48:32.879 - dist_run:344 - pp_degree=1, tp_degree=2
10-24 12:48:32.905 - dist_run:304 - Worker started: rank=1, world_size=2
10-24 12:48:32.905 - dist_run:307 -  GPU capacity: NVIDIA PG509-210 (None) with 79.15GiB memory
10-24 12:48:32.905 - dist_run:310 - Using model weights from meta-llama/Meta-Llama-3.1-8B-Instruct and dtype torch.bfloat16
known configs: ['13B', '30B', '34B', '70B', '7B', 'CodeLlama-7b-Python-hf', 'Llama-3.2-11B-Vision', 'Llama-Guard-3-1B-INT4', 'Llama-Guard-3-1B', 'Meta-Llama-3-70B', 'Meta-Llama-3-8B', 'Meta-Llama-3.1-70B-Tune', 'Meta-Llama-3.1-70B', 'Meta-Llama-3.1-8B-Tune', 'Meta-Llama-3.1-8B', 'Meta-Llama-3.2-1B', 'Meta-Llama-3.2-3B', 'Mistral-7B', 'llava-1.5', 'stories110M', 'stories15M', 'stories42M']
10-24 12:48:32.906 - dist_run:316 - Transformer Config: TransformerArgs(block_size=131072, vocab_size=128256, n_layers=32, n_heads=32, dim=4096, hidden_dim=14336, n_local_heads=8, head_dim=128, rope_base=500000.0, norm_eps=1e-05, multiple_of=1024, ffn_dim_multiplier=1.3, use_tiktoken=True, max_seq_length=8192, rope_scaling={'factor': 8.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_position_embeddings': 8192}, n_stages=1, stage_idx=0, attention_bias=False, feed_forward_bias=False, tie_word_embeddings=False)
10-24 12:48:33.027 - dist_run:366 - Model: Transformer(
....
10-24 12:48:39.142 - checkpoint_utils:252 - Loaded 104 tensors from /home/mreso/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/model-00002-of-00004.safetensors
10-24 12:48:43.856 - checkpoint_utils:206 - Fully updated state dict.
10-24 12:48:43.856 - checkpoint_utils:208 - Loading 291 weights into stage dict
10-24 12:48:43.857 - checkpoint_utils:206 - Fully updated state dict.
10-24 12:48:43.857 - checkpoint_utils:208 - Loading 291 weights into stage dict
10-24 12:48:43.900 - checkpoint_utils:215 - Successfully loaded 291 weights into stage module
10-24 12:48:43.901 - checkpoint_utils:215 - Successfully loaded 291 weights into stage module
10-24 12:48:44.059 - checkpoint_utils:392 - Success - Loaded 291 weights, 0 missing weights
10-24 12:48:44.060 - dist_run:373 - Total weight loading time: 10.6224 seconds for rank 0
10-24 12:48:44.272 - checkpoint_utils:392 - Success - Loaded 291 weights, 0 missing weights
10-24 12:48:44.272 - dist_run:373 - Total weight loading time: 10.7729 seconds for rank 1
10-24 12:48:44.300 - dist_run:399 - Stage 0 has 8.03B params, Size: 15.08 GiB
10-24 12:48:44.301 - dist_run:420 - Creating pipeline stage for prefill pp_rank=0, pp_degree=1
10-24 12:48:44.621 - dist_run:399 - Stage 1 has 8.03B params, Size: 15.08 GiB
10-24 12:48:44.623 - dist_run:420 - Creating pipeline stage for prefill pp_rank=0, pp_degree=1
Done launching distributed inference on 2 GPUs.
10-24 12:48:44.632 - dist_run:458 - Prompt: ['What is Snow?']
10-24 12:48:44.632 - dist_run:458 - Prompt: ['What is Snow?']
NCCL version 2.21.5+cuda12.4
10-24 12:48:47.252 - dist_run:493 - Prefilling time: 2.6194 seconds for rank 1
10-24 12:48:47.252 - dist_run:493 - Prefilling time: 2.619 seconds for rank 0
10-24 12:48:47.252 - dist_run:499 - Decoding...prompt_lengths=[5]
10-24 12:48:47.252 - dist_run:499 - Decoding...prompt_lengths=[5]
10-24 12:48:47.255 - dist_run:513 - Creating pipeline stage for decode pp_rank=0, pp_degree=1
10-24 12:48:47.257 - dist_run:273 -  responses ====>>>>  [' Snow']
10-24 12:48:47.258 - dist_run:513 - Creating pipeline stage for decode pp_rank=0, pp_degree=1
10-24 12:48:48.212 - dist_run:273 -  responses ====>>>>  [' is']
10-24 12:48:52.411 - dist_run:273 -  responses ====>>>>  [' a']
10-24 12:48:52.665 - dist_run:273 -  responses ====>>>>  [' natural']
10-24 12:48:52.904 - dist_run:273 -  responses ====>>>>  [' weather']
10-24 12:48:53.154 - dist_run:273 -  responses ====>>>>  [' phenomenon']
10-24 12:48:53.395 - dist_run:273 -  responses ====>>>>  [' that']
10-24 12:48:53.639 - dist_run:273 -  responses ====>>>>  [' occurs']
10-24 12:48:53.885 - dist_run:273 -  responses ====>>>>  [' when']
10-24 12:48:54.133 - dist_run:273 -  responses ====>>>>  [' water']
10-24 12:48:54.377 - dist_run:273 -  responses ====>>>>  [' vapor']
10-24 12:48:54.624 - dist_run:273 -  responses ====>>>>  [' in']
10-24 12:48:54.876 - dist_run:273 -  responses ====>>>>  [' the']
10-24 12:48:55.124 - dist_run:273 -  responses ====>>>>  [' atmosphere']
10-24 12:48:55.372 - dist_run:273 -  responses ====>>>>  [' freezes']
10-24 12:48:55.617 - dist_run:273 -  responses ====>>>>  [' into']
10-24 12:48:55.862 - dist_run:273 -  responses ====>>>>  [' ice']
10-24 12:48:56.110 - dist_run:273 -  responses ====>>>>  [' crystals']
10-24 12:48:56.349 - dist_run:273 -  responses ====>>>>  ['.']
10-24 12:48:56.589 - dist_run:273 -  responses ====>>>>  [' This']
10-24 12:48:56.833 - dist_run:273 -  responses ====>>>>  [' process']
10-24 12:48:57.096 - dist_run:273 -  responses ====>>>>  [' typically']
10-24 12:48:57.335 - dist_run:273 -  responses ====>>>>  [' happens']
10-24 12:48:57.584 - dist_run:273 -  responses ====>>>>  [' when']
10-24 12:48:57.833 - dist_run:273 -  responses ====>>>>  [' the']
10-24 12:48:58.085 - dist_run:273 -  responses ====>>>>  [' air']
10-24 12:48:58.339 - dist_run:273 -  responses ====>>>>  [' temperature']
10-24 12:48:58.595 - dist_run:273 -  responses ====>>>>  [' co']
10-24 12:48:58.841 - dist_run:273 -  responses ====>>>>  ['ols']
10-24 12:48:59.104 - dist_run:273 -  responses ====>>>>  [' to']
10-24 12:48:59.354 - dist_run:273 -  responses ====>>>>  [' a']
10-24 12:48:59.604 - dist_run:273 -  responses ====>>>>  [' point']
10-24 12:48:59.851 - dist_run:273 -  responses ====>>>>  [' where']
10-24 12:49:00.105 - dist_run:273 -  responses ====>>>>  [' the']
10-24 12:49:00.350 - dist_run:273 -  responses ====>>>>  [' water']
10-24 12:49:00.597 - dist_run:273 -  responses ====>>>>  [' vapor']
10-24 12:49:00.850 - dist_run:273 -  responses ====>>>>  [' can']
10-24 12:49:01.106 - dist_run:273 -  responses ====>>>>  [' no']
10-24 12:49:01.353 - dist_run:273 -  responses ====>>>>  [' longer']
10-24 12:49:01.612 - dist_run:273 -  responses ====>>>>  [' remain']
10-24 12:49:01.860 - dist_run:273 -  responses ====>>>>  [' in']
10-24 12:49:01.908 - dist_run:576 - Decoding time: 14.6502 seconds for rank 0
10-24 12:49:01.909 - dist_run:594 - Prompt: What is Snow?
10-24 12:49:01.909 - dist_run:595 - Response:  Snow is a natural weather phenomenon that occurs when water vapor in the atmosphere freezes into ice crystals. This process typically happens when the air temperature cools to a point where the water vapor can no longer remain in
10-24 12:49:01.913 - dist_run:576 - Decoding time: 14.658 seconds for rank 1
Model output:  Snow is a natural weather phenomenon that occurs when water vapor in the atmosphere freezes into ice crystals. This process typically happens when the air temperature cools to a point where the water vapor can no longer remain in
[rank1]:[W1024 12:49:01.434958686 ProcessGroupNCCL.cpp:4121] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank0]:[W1024 12:49:01.434959047 ProcessGroupNCCL.cpp:4121] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
10-24 12:49:03.187 - dist_run:599 - Success - Rank 1 has completed.
10-24 12:49:03.189 - dist_run:599 - Success - Rank 0 has completed.
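
The `Created device mesh` lines in the output come from PyTorch's 2-D DeviceMesh API. Here is a minimal sketch that reproduces the pp=1 x tp=2 mesh seen in the log (assuming a 2-GPU `torchrun` launch; this is the standard API, not this PR's exact code):

```python
# Sketch of the pp=1 x tp=2 mesh shown in the log above. Assumes a
# 2-GPU run under torchrun; init_device_mesh sets up the default
# process group if none exists. Not this PR's exact code.
from torch.distributed.device_mesh import init_device_mesh

device_mesh = init_device_mesh("cuda", (1, 2), mesh_dim_names=("pp", "tp"))
tp_mesh = device_mesh["tp"]  # DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',))
pp_mesh = device_mesh["pp"]  # DeviceMesh('cuda', [0], mesh_dim_names=('pp',))
pp_degree, tp_degree = device_mesh.size(0), device_mesh.size(1)  # 1, 2
```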
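The two `ProcessGroupNCCL` warnings near the end are harmless here; as the message itself suggests, they go away if each rank tells the process group which device it owns. A sketch using the standard `torch.distributed` API (`device_id` requires a recent PyTorch; this is not code from this PR):

```python
# Sketch: bind each rank to its GPU so barrier() doesn't have to guess
# a device. Assumes a torchrun launch where RANK is set in the env.
import os

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
```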

pytorch-bot bot commented Oct 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1327

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2d37d27 with merge base 7fe2c86:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 24, 2024
@mreso requested a review from lessw2020 on October 24, 2024 21:30
@lessw2020 requested a review from kwen2501 on October 24, 2024 22:02
dist_run.py: two outdated review comments (resolved)
@@ -476,18 +490,19 @@ def _maybe_parallelize_model(


 def _load_model(builder_args: BuilderArgs) -> Model:
-    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
+    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
Contributor commented:

This code is now effectively dead and we should just remove it in a later PR.

@lessw2020 (Contributor) left a comment:

Looks great! Thanks for adding this - great job adding modularity while keeping it light and concise.

@lessw2020 merged commit 9af34c1 into pytorch:main on Oct 25, 2024
52 checks passed
@mreso deleted the refactor/dist_run branch on October 25, 2024 18:24
Labels: CLA Signed
6 participants