[Model] Support for fairseq2 Llama #11442

Merged · 8 commits · Jan 19, 2025
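Once this lands, fairseq2-converted Llama checkpoints go through the standard vLLM entry points. A minimal usage sketch (the model name is the dummy checkpoint registered by the tests below; sampling settings are illustrative):

from vllm import LLM, SamplingParams

# The dummy fairseq2 checkpoint registered by this PR. With tensor_parallel_size > 1
# the loader accepts either a full model.pt or a per-rank model.{rank}.pt (see
# allow_patterns_overrides in fairseq2_llama.py below).
llm = LLM(model="mgleize/fairseq2-dummy-Llama-3.2-1B")
outputs = llm.generate(["Hello world!"], SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)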
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -69,6 +69,7 @@ class _HfExamplesInfo:
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
3 changes: 2 additions & 1 deletion tests/weight_loading/models.txt
@@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
None, mgleize/fairseq2-dummy-Llama-3.2-1B, main
13 changes: 7 additions & 6 deletions tests/weight_loading/test_weight_loading.py
@@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:
with vllm_runner(
model_name=MODEL_NAME,
revision=REVISION,
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=None if QUANTIZATION == "None" else QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:

output = model.generate_greedy("Hello world!", max_tokens=20)
print(output)
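For reference, the quantization column in models.txt is plain text, so the new unquantized row arrives here as the literal string "None"; a small sketch of the mapping this hunk adds (the comma-splitting driver is an assumption and is not part of this diff):

# Hypothetical parse of the new models.txt row; only the "None" -> None mapping
# mirrors the change above, the splitting itself is an assumption.
line = "None, mgleize/fairseq2-dummy-Llama-3.2-1B, main"
quant_str, model_name, revision = [part.strip() for part in line.split(",")]

quantization = None if quant_str == "None" else quant_str
assert quantization is None  # unquantized entries must reach vllm_runner as real None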
34 changes: 22 additions & 12 deletions vllm/model_executor/layers/linear.py
@@ -344,11 +344,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

param_data = param.data
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -546,6 +548,11 @@ def weight_loader(self,

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
@@ -554,9 +561,7 @@
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
@@ -941,6 +946,11 @@ def weight_loader(self,

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
@@ -964,9 +974,7 @@
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size

# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

@@ -1070,6 +1078,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1085,9 +1097,7 @@
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
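The is_sharded_weight attribute generalizes the earlier bitsandbytes-only special case: when a checkpoint already stores just this rank's shard, the loader must copy it as-is instead of narrowing the full tensor. A toy sketch of that contract (plain tensors rather than the real Parameter subclasses):

import torch

def copy_shard(param_data: torch.Tensor,
               loaded_weight: torch.Tensor,
               tp_rank: int,
               shard_dim: int = 0,
               is_sharded_weight: bool = False) -> None:
    """Toy version of the narrow-or-copy decision made by the weight loaders."""
    if not is_sharded_weight:
        # Full checkpoint: slice out this rank's portion before copying.
        shard_size = param_data.shape[shard_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
    # Pre-sharded checkpoint (bitsandbytes 4-bit, or fairseq2's model.{rank}.pt):
    # the tensor already has the per-rank shape, so it is copied unchanged.
    param_data.copy_(loaded_weight)

# Full weight of 8 rows split across tp_size=2: rank 1 keeps rows 4..7.
full = torch.arange(8.0).reshape(8, 1)
shard = torch.empty(4, 1)
copy_shard(shard, full, tp_rank=1)
assert torch.equal(shard, full[4:])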
15 changes: 13 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -181,6 +181,9 @@ class Source:
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""

allow_patterns_overrides: Optional[list[str]] = None
"""If defined, weights will load exclusively using these patterns."""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
@@ -217,6 +220,7 @@ def _prepare_weights(
model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool,
allow_patterns_overrides: Optional[list[str]],
) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.

@@ -248,6 +252,9 @@
if fall_back_to_pt:
allow_patterns += ["*.pt"]

if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides

if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
@@ -297,7 +304,8 @@ def _get_weights_iterator(
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt)
source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
@@ -339,6 +347,8 @@ def _get_all_weights(
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
None),
)
yield from self._get_weights_iterator(primary_weights)

@@ -352,7 +362,8 @@
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model,
model_config.revision,
fall_back_to_pt=True)
fall_back_to_pt=True,
allow_patterns_overrides=None)

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
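Models opt into the new hook simply by exposing an allow_patterns_overrides attribute, which _get_all_weights picks up with getattr above and _prepare_weights then uses in place of the default glob list. A condensed sketch of that resolution (the default pattern list is abbreviated and the model class is a stand-in):

from typing import Optional

class Fairseq2Stub:
    # Mirrors the attribute set by Fairseq2LlamaForCausalLM below (rank 0 shown).
    allow_patterns_overrides = ["model.pt", "model.0.pt"]

def resolve_patterns(model: object, fall_back_to_pt: bool = True) -> list[str]:
    allow_patterns = ["*.safetensors", "*.bin"]  # abbreviated default order
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]
    overrides: Optional[list[str]] = getattr(model, "allow_patterns_overrides", None)
    if overrides is not None:
        # Overrides replace the defaults entirely, so only the listed files
        # are downloaded and loaded.
        allow_patterns = overrides
    return allow_patterns

print(resolve_patterns(Fairseq2Stub()))  # ['model.pt', 'model.0.pt']
print(resolve_patterns(object()))        # ['*.safetensors', '*.bin', '*.pt']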
151 changes: 151 additions & 0 deletions vllm/model_executor/models/fairseq2_llama.py
@@ -0,0 +1,151 @@
# Copyright 2024 The vLLM team.
# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Llama model for fairseq2 weights."""

from typing import Iterable, Set, Tuple

import torch
from torch.nn import Parameter

from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import set_weight_attrs
from vllm.model_executor.models.llama import LlamaForCausalLM

from .utils import AutoWeightsLoader, WeightsMapper


class Fairseq2LlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
# For the model loader to read only the relevant checkpoint files
self.allow_patterns_overrides = [
# either the full checkpoint
"model.pt",
# or the tp-sharded checkpoint of the current rank
f"model.{self.tp_rank}.pt",
]

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
# { "model_key": my_model_name, "my_model_name": state_dict }
# which we first need to unpack
weights_wrapped = dict(weights)
weights = weights_wrapped[
weights_wrapped["model_key"]].items() # type: ignore

# remap keys
fs2_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"decoder_frontend.embed.": "model.embed_tokens.",
"decoder.": "model.",
"final_proj.": "lm_head.",
},
orig_to_new_substr={
".self_attn_layer_norm.": ".input_layernorm.",
".ffn_layer_norm.": ".post_attention_layernorm.",
".self_attn.output_proj.": ".self_attn.o_proj.",
".ffn.gate_proj.": ".mlp.gate_proj.",
".ffn.inner_proj.": ".mlp.up_proj.",
".ffn.output_proj.": ".mlp.down_proj.",
".layer_norm.": ".norm.",
},
)
weights = fs2_to_vllm_mapper.apply(weights)

params = dict(self.named_parameters())

loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(
(self.reshape_fairseq2_weights(name, loaded_weight, params)
for name, loaded_weight in weights))

def flag_sharded_weights(self, params: dict[str, Parameter]):
"""Sets the `is_sharded_weight` flag to True for all sharded weights"""
for name, param in params.items():
modules = name.split(".")
if "norm" in name and len(param.size()) < 2:
# layer norms are not sharded
continue
elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
# for now we repeat embedding layers for compatibility
continue
else:
# all other layers are sharded
set_weight_attrs(param, {"is_sharded_weight": True})

def reshape_fairseq2_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params: dict[str, Parameter],
) -> Tuple[str, torch.Tensor]:
"""Reshape fairseq2's weights."""

def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
attn_in = self.config.head_dim * n_heads
# check for a sharded weight on dim 0
if attn_in // self.tp_size == w.size()[0]:
attn_in //= self.tp_size
n_heads //= self.tp_size
attn_out = self.config.hidden_size
return (w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1,
2).reshape(attn_in, attn_out))

modules = name.split(".")

# rotary embeds should be sliced
if "k_proj" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)

elif "q_proj" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)

# We make the loaded weights compatible with both
# full checkpoints and tp sharded checkpoints.
# Embeddings are repeated to fit the vocab size.
# Other weights are flagged for the weight_loader calls.
if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
# Embeddings are sharded on dim 0
dim = 0
# In fairseq2, vocab size has to be divisible by tp_size
# so we don't worry about padding
if self.tp_size > 1 and loaded_weight.shape[
dim] < self.config.vocab_size:
assert loaded_weight.shape[
dim] * self.tp_size == self.config.vocab_size, \
"vocab_size should be divisible by tp_size."
repeats = [1] * len(loaded_weight.size())
repeats[dim] = self.tp_size
# repeat to match vocab size and to be easily 'narrow'able
loaded_weight = loaded_weight.repeat(repeats)
set_weight_attrs(params[name], {"is_sharded_weight": False})
# if embeddings are sharded, the rest is too
if "embed_tokens" in modules:
self.flag_sharded_weights(params)

return name, loaded_weight
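Two pieces of context for the file above. First, the wrapper that load_weights unpacks: a fairseq2 .pt checkpoint is a dict whose "model_key" entry names the key holding the actual state_dict. A sketch with illustrative key names and tiny shapes:

import torch

checkpoint = {                # rough shape of torch.load("model.pt")
    "model_key": "llama",     # names the entry that holds the state_dict
    "llama": {
        "decoder_frontend.embed.weight": torch.zeros(8, 4),
        "decoder.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
        "final_proj.weight": torch.zeros(8, 4),
    },
}
state_dict = checkpoint[checkpoint["model_key"]]  # the unwrapping done above
# After the WeightsMapper, e.g. "decoder_frontend.embed.weight" becomes
# "model.embed_tokens.weight" and "final_proj.weight" becomes "lm_head.weight".

Second, a worked example of what permute does to the row order of q_proj/k_proj: each head's rows are reordered from interleaved rotary pairs to the half-split layout the vLLM Llama rotary embedding expects (toy sizes: one head, head_dim 4, hidden 4):

head_dim, n_heads, hidden = 4, 1, 4
w = torch.arange(head_dim * hidden, dtype=torch.float32).reshape(head_dim, hidden)
permuted = (w.view(n_heads, head_dim // 2, 2, hidden)
             .transpose(1, 2)
             .reshape(head_dim, hidden))
# rows [pair0_a, pair0_b, pair1_a, pair1_b] -> [pair0_a, pair1_a, pair0_b, pair1_b]
assert torch.equal(permuted, w[[0, 2, 1, 3]])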
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -47,6 +47,7 @@
"DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),