Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inclusion of InternVLChatModel In PP_SUPPORTED_MODELS(Pipeline Parallelism) #7860

Merged
merged 56 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
4e260ae
Merge pull request #1 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 21, 2024
dbf6ee2
changes for internvl pipeline parallelism
Manikandan-Thangaraj-ZS0321 Aug 21, 2024
e36dfd2
Merge pull request #2 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 21, 2024
08b8538
Inclusion of InternVLChatModel in PP_SUPPORTED_MODELS
Manikandan-Thangaraj-ZS0321 Aug 22, 2024
5129c87
Merge pull request #3 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 22, 2024
5efea42
Merge pull request #4 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 22, 2024
3aca806
Merge pull request #5 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 23, 2024
c7ef1aa
Merge pull request #6 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
d7f2d54
refactor
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
a20827e
Merge remote-tracking branch 'origin/main' into main
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
caaca9a
refactor
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
2137631
refactor
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
927e3f8
refactor
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
7e8ef5c
refactor
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
c891114
refactor
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
85168ed
Merge pull request #7 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 26, 2024
80fa9dd
Merge pull request #8 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 27, 2024
654248c
Added the InternVL2-8B for testing the pipeline parallelism in test_p…
Manikandan-Thangaraj-ZS0321 Aug 27, 2024
a2fd2b6
Merge remote-tracking branch 'origin/main' into main
Manikandan-Thangaraj-ZS0321 Aug 27, 2024
6923683
Merge pull request #9 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 27, 2024
cf155c2
updating branch
Manikandan-Thangaraj-ZS0321 Aug 28, 2024
8ca55fa
Merge pull request #10 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Aug 28, 2024
4ae7573
updation
Manikandan-Thangaraj-ZS0321 Aug 28, 2024
595bd82
Merge remote-tracking branch 'origin/main' into main
Manikandan-Thangaraj-ZS0321 Aug 28, 2024
7aefcbf
updation
Manikandan-Thangaraj-ZS0321 Aug 28, 2024
4f7b9a5
updating branch
Manikandan-Thangaraj-ZS0321 Aug 28, 2024
86fb726
Merge pull request #12 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Sep 2, 2024
af98e95
Merge pull request #13 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Sep 2, 2024
530c8a3
Updating Branch
Manikandan-Thangaraj-ZS0321 Sep 2, 2024
88186f4
Merge remote-tracking branch 'origin/main' into main
Manikandan-Thangaraj-ZS0321 Sep 2, 2024
0d6ac3a
Updating Branch
Manikandan-Thangaraj-ZS0321 Sep 2, 2024
caaffb9
Updating Branch
Manikandan-Thangaraj-ZS0321 Sep 2, 2024
831d447
Refactor
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
a3e9a98
Refactor
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
e218d07
Refactor
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
ca6e920
test case completion
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
7591225
test case completion
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
85504b1
fixing imports
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
d14d6b6
fixing imports in utils.py
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
63039c2
Merge pull request #16 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
41f83dd
optional settings for tokeniser for trust_remote_code
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
3a658ea
Merge remote-tracking branch 'origin/main' into main
Manikandan-Thangaraj-ZS0321 Sep 3, 2024
0034663
InternLM2ForCausalLM and internlm/internlm2_5-7b-chat inclusion
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
6a78c24
Merge pull request #17 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
0ebf029
line formatting
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
f689444
line formatting
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
0cd208c
sorting in utils.py
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
9c3ef5f
formatting
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
7e32e34
formatting
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
468c994
formatting
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
92cfe53
formatting _PP_SUPPORTED_MODELS list
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
9a31c28
removal of internlm/internlm2_5-7b-chat in test_pipeline_parallel.py
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
158bf06
Merge pull request #18 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Sep 4, 2024
91dde87
Merge pull request #19 from vllm-project/main
Manikandan-Thangaraj-ZS0321 Sep 5, 2024
cb3f602
increasing TP_SIZE in test_pipeline_parallel.py
Manikandan-Thangaraj-ZS0321 Sep 5, 2024
5d9bc77
Merge remote-tracking branch 'origin/main' into main
Manikandan-Thangaraj-ZS0321 Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,27 @@
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
"MODEL_NAME, DIST_BACKEND"),
[
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
])
@pytest.mark.parametrize(
("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
"MODEL_NAME, DIST_BACKEND"),
[
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-8B", "ray"),
(1, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
],
)
@fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND):
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
Expand Down Expand Up @@ -71,6 +75,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if EAGER_MODE:
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")
if TRUST_REMOTE_CODE:
pp_args.append("--trust-remote-code")
tp_args.append("--trust-remote-code")
pp_env = None
if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
and CHUNKED_PREFILL):
Expand Down
7 changes: 6 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,12 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""

tokenizer = AutoTokenizer.from_pretrained(model)
trust_remote_code = "--trust-remote-code"
if trust_remote_code in arg1 or trust_remote_code in arg2:
tokenizer = AutoTokenizer.from_pretrained(model,
trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(model)

prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"]
Expand Down
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
"InternVLChatModel",
"InternLM2ForCausalLM"
]
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
52 changes: 38 additions & 14 deletions vllm/model_executor/models/internlm2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
Expand All @@ -28,6 +28,9 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)


class InternLM2MLP(nn.Module):

Expand Down Expand Up @@ -234,6 +237,7 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
Expand All @@ -243,11 +247,15 @@ def __init__(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
Expand All @@ -260,21 +268,31 @@ def forward(
attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.tok_embeddings(input_ids)
residual = None
else:
hidden_states = self.tok_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

Expand All @@ -298,6 +316,8 @@ def __init__(
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def forward(
self,
Expand All @@ -308,7 +328,7 @@ def forward(
intermediate_tensors: IntermediateTensors,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
Expand Down Expand Up @@ -345,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
Expand All @@ -353,6 +375,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def __init__(self,
nn.Linear(llm_hidden_size, llm_hidden_size))

self.img_context_token_id = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
Expand Down Expand Up @@ -461,7 +463,7 @@ def forward(
positions,
kv_caches,
attn_metadata,
None,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

Expand Down
16 changes: 16 additions & 0 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available


Expand Down Expand Up @@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
if name.startswith(missing_layer_name):
return True
return False


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):

def make_empty_intermediate_tensors(
batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
key: torch.zeros((batch_size, hidden_size),
dtype=dtype,
device=device)
for key in keys
})

return make_empty_intermediate_tensors
Loading