add SmolVLM
ngxson committed Jan 23, 2025
1 parent 25a97ce commit c3a654c
Showing 9 changed files with 171 additions and 10 deletions.
37 changes: 28 additions & 9 deletions convert_hf_to_gguf.py
@@ -292,7 +292,10 @@ def set_gguf_parameters(self):
self.gguf_writer.add_vision_vit_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_vit_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_vit_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))
try:
self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))
except KeyError:
self.gguf_writer.add_vision_vit_select_layer(0)

self.gguf_writer.add_file_type(self.ftype)
logger.info(f"gguf: file type = {self.ftype}")
@@ -506,8 +509,9 @@ def load_hparams(dir_model: Path):
hparams = json.load(f)
if "text_config" in hparams:
text_config = hparams["text_config"]
model_id = text_config.get("_name_or_path", None)
# for example, llava-1.5-7b-hf misses the language model config, need to retrieve it via model ID
if "_name_or_path" in text_config:
if model_id is not None and model_id != "None" and model_id != "":
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
hparams = {**text_config, **hparams}
return hparams
@@ -1616,7 +1620,7 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed norms: {norms}")


@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM", "Idefics3ForConditionalGeneration")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

@@ -1640,6 +1644,11 @@ def __init__(self, *args, **kwargs):
self.preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict()
self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM

if "vision_config" in self.hparams and model_type == "idefics3":
self.vparams = self.hparams["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
self.vision_arch = gguf.MODEL_ARCH.VISION_IDEFICS3

if self.vparams is not None and self.vision_arch is not None:
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])

@@ -1694,14 +1703,20 @@ def set_gguf_parameters(self):

# For vision model
if self.vparams is not None:
max_pos_embd = -1
self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
# TODO: should not hardcode these, but they are currently missing from config.json
if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
if self.vision_arch == gguf.MODEL_ARCH.VISION_IDEFICS3:
self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP)
self.gguf_writer.add_vision_vit_scale_factor(self.hparams["scale_factor"])
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-05)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)

@staticmethod
@@ -1717,19 +1732,23 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name or "vision_model" in name

# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
if name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
else:
name = name.replace("model.vision_tower.", "")
if "post_layernorm" in name:
if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3:
return [] # skip post_layernorm

if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
if not is_vision_tensor:
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
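Note on the new scale_factor key: Idefics3/SmolVLM shrink the vision sequence with a pixel shuffle before the projector, folding each scale_factor x scale_factor block of patch embeddings into one wider token; that is also why the idefics3 branch above computes max_pos_embd without the usual +1 (SigLIP has no CLS token). A minimal sketch of the reshuffle, modeled on the Hugging Face Idefics3 connector (shapes and names here are illustrative, not part of this commit):

import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    # x: (batch, seq, embd), with seq laid out as a square patch grid
    bsz, seq, embd = x.size()
    side = int(seq**0.5)
    x = x.view(bsz, side, side, embd)
    # fold scale_factor columns, then scale_factor rows, into the channel dim
    x = x.view(bsz, side, side // scale_factor, embd * scale_factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(bsz, side // scale_factor, side // scale_factor, embd * scale_factor**2)
    x = x.permute(0, 2, 1, 3)
    # scale_factor**2 fewer tokens, scale_factor**2 wider channels
    return x.reshape(bsz, seq // scale_factor**2, embd * scale_factor**2)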
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
@@ -238,6 +238,7 @@ class Vit:
PATCH_MERGE_TYPE = "vision.vit.patch_merge_type"
HEAD_COUNT = "vision.vit.attention.head_count"
LAYERNORM_EPS = "vision.vit.attention.layer_norm_epsilon"
SCALE_FACTOR = "vision.vit.scale_factor" # only used by idefics3 for now

#
# recommended mapping of model tensor names for storage in gguf
@@ -311,6 +312,7 @@ class MODEL_ARCH(IntEnum):
VISION_LLAVA = auto()
VISION_MOBILEVLM = auto()
VISION_MINICPMV = auto()
VISION_IDEFICS3 = auto()


class MODEL_TENSOR(IntEnum):
@@ -441,6 +443,7 @@ class MODEL_TENSOR(IntEnum):
POSNET_ATTN_OUT = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
V_MMPROJ_MLP = auto()
V_MMPROJ_PEG = auto()
V_ENC_EMBD_CLS = auto()
@@ -535,6 +538,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.VISION_LLAVA: "llava",
MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm",
MODEL_ARCH.VISION_MINICPMV: "minicpmv",
MODEL_ARCH.VISION_IDEFICS3: "idefics3",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -664,6 +668,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
# vision
MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "v.mmproj.fc",
MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}",
MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}",
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls",
@@ -1695,6 +1700,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_TOK_EMBD_SLICE,
MODEL_TENSOR.V_TOK_EMBD_END_SLICE,
],
MODEL_ARCH.VISION_IDEFICS3: [
MODEL_TENSOR.V_MMPROJ_FC,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_ATTN_Q,
MODEL_TENSOR.V_ENC_ATTN_K,
MODEL_TENSOR.V_ENC_ATTN_V,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_OUTPUT,
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_POST_NORM,
],
# TODO
}

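For orientation, the new enum entries resolve to on-disk tensor names through TENSOR_NAMES; a quick check against this tree's gguf-py (illustrative):

from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

print(TENSOR_NAMES[MODEL_TENSOR.V_MMPROJ_FC])                 # -> "v.mmproj.fc"
print(TENSOR_NAMES[MODEL_TENSOR.V_MMPROJ_MLP].format(bid=2))  # -> "v.mmproj.mlp.2"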
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -928,6 +928,9 @@ def add_vision_vit_image_mean(self, value: Sequence[float]) -> None:
def add_vision_vit_image_std(self, value: Sequence[float]) -> None:
self.add_array(Keys.Vision.IMAGE_STD, value)

def add_vision_vit_scale_factor(self, value: int) -> None:
self.add_int32(Keys.Vision.Vit.SCALE_FACTOR, value)

def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if not isinstance(value, str):
template_default = None
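A converter calls the new setter alongside the other vision keys; a minimal sketch, assuming gguf-py from this branch (stock gguf-py has no add_vision_vit_scale_factor, and the value would really come from the model's config.json):

from gguf import GGUFWriter

w = GGUFWriter("smolvlm.gguf", "llama")
w.add_vision_vit_scale_factor(3)  # e.g. hparams["scale_factor"]; 3 is a placeholder
w.write_header_to_file()
w.write_kv_data_to_file()
w.close()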
15 changes: 15 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -794,6 +794,10 @@ class TensorNameMap:
"multi_modal_projector.linear_{bid}",
),

MODEL_TENSOR.V_MMPROJ_FC: (
"model.connector.modality_projection.proj", # SmolVLM
),

MODEL_TENSOR.V_MMPROJ_MLP: (
"model.mm_projector.mlp.mlp.{bid}",
),
@@ -809,51 +813,61 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
"vision_tower.vision_model.embeddings.patch_embedding",
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
),

MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
),

MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
),

MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
),

MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
),

MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
),

MODEL_TENSOR.V_ENC_OUTPUT: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
),

MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
),

MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM
),

MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM
),

MODEL_TENSOR.V_PRE_NORM: (
Expand All @@ -862,6 +876,7 @@ class TensorNameMap:

MODEL_TENSOR.V_POST_NORM: (
"vision_tower.vision_model.post_layernorm",
"model.vision_model.post_layernorm", # SmolVLM
),

MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
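These entries are what let the converter rename raw SmolVLM checkpoint tensors into the vision schema; a small check, again assuming this branch's gguf-py (the layer count 27 matches SmolVLM's SigLIP encoder but is only illustrative here):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.VISION_IDEFICS3, 27)
print(tmap.get_name("model.vision_model.encoder.layers.0.self_attn.q_proj.weight",
                    try_suffixes=(".weight", ".bias")))
# expected: "v.enc.blk.0.attn_q.weight"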
21 changes: 21 additions & 0 deletions src/llama-arch.cpp
@@ -66,6 +66,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_VISION_LLAVA, "llava" },
{ LLM_ARCH_VISION_MOBILEVLM, "mobilevlm" },
{ LLM_ARCH_VISION_MINICPMV, "minicpmv" },
{ LLM_ARCH_VISION_IDEFICS3, "idefics3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -214,6 +215,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, "vision.vit.patch_merge_type" },
{ LLM_KV_VISION_VIT_HEAD_COUNT, "vision.vit.attention.head_count" },
{ LLM_KV_VISION_VIT_LAYERNORM_EPS, "vision.vit.attention.layer_norm_epsilon" },
{ LLM_KV_VISION_VIT_SCALE_FACTOR, "vision.vit.scale_factor" },

// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
@@ -1388,6 +1390,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_V_TOK_EMBD_END_SLICE, "v.tok_embd.end_slice" },
}
},
{
LLM_ARCH_VISION_IDEFICS3,
{
{ LLM_TENSOR_V_MMPROJ_FC, "v.mmproj.fc" },
{ LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" },
{ LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" },
{ LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" },
{ LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
{ LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
{ LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
{ LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
{ LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" },
{ LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
{ LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
{ LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
{ LLM_TENSOR_V_PRE_NORM, "v.pre_norm" },
{ LLM_TENSOR_V_POST_NORM, "v.post_norm" },
}
},
{
LLM_ARCH_UNKNOWN,
{
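These strings are the loader-side twins of gguf-py's TENSOR_NAMES: the C++ table uses %d where the Python side uses {bid}, and the two must expand to identical bytes or the loader will not find the tensor. A throwaway cross-check from the Python side (assuming the Python template mirrors the C++ entry shown above):

from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

cpp_fmt = "v.enc.blk.%d.attn_q"                    # LLM_TENSOR_V_ENC_ATTN_Q above
py_fmt  = TENSOR_NAMES[MODEL_TENSOR.V_ENC_ATTN_Q]  # presumably "v.enc.blk.{bid}.attn_q"
assert cpp_fmt % 7 == py_fmt.format(bid=7)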
3 changes: 3 additions & 0 deletions src/llama-arch.h
@@ -70,6 +70,7 @@ enum llm_arch {
LLM_ARCH_VISION_LLAVA,
LLM_ARCH_VISION_MOBILEVLM,
LLM_ARCH_VISION_MINICPMV,
LLM_ARCH_VISION_IDEFICS3,
LLM_ARCH_UNKNOWN,
};

@@ -218,6 +219,7 @@ enum llm_kv {
LLM_KV_VISION_VIT_PATCH_MERGE_TYPE,
LLM_KV_VISION_VIT_HEAD_COUNT,
LLM_KV_VISION_VIT_LAYERNORM_EPS,
LLM_KV_VISION_VIT_SCALE_FACTOR,

// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
@@ -354,6 +356,7 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_OUT,
// vision
LLM_TENSOR_V_MMPROJ,
LLM_TENSOR_V_MMPROJ_FC,
LLM_TENSOR_V_MMPROJ_MLP,
LLM_TENSOR_V_MMPROJ_PEG,
LLM_TENSOR_V_ENC_EMBD_CLS,
38 changes: 38 additions & 0 deletions src/llama-model.cpp
@@ -1265,6 +1265,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_VISION_VIT_LAYERNORM_EPS, vparams.eps, true);
ml.get_key(LLM_KV_VISION_VIT_SELECT_LAYER, vparams.select_layer, true);
ml.get_key(LLM_KV_VISION_VIT_MAX_POS_EMBD, vparams.max_pos_embd, true);
ml.get_key(LLM_KV_VISION_VIT_SCALE_FACTOR, vparams.scale_factor, false);
{
std::string name;
ml.get_key(LLM_KV_VISION_VIT_PROJECTOR_TYPE, name, true);
@@ -3555,6 +3556,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
vit.mm_tok_embd_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd});
vit.mm_tok_embd_end_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd});

for (int i = 0; i < n_vlayer; ++i) {
auto & layer = vit.layers[i];

layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0);
layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0);
layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0);
layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0);
layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0);
layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0);

layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0);
layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0);

layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0);
layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0);
layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0);
layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0);

layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0);
layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0);
}
} break;
case LLM_ARCH_VISION_IDEFICS3:
{
int scale_factor = vit.hparams.scale_factor;
vit.projection = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd});

vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
vit.patch_bias = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd});
vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});

vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd});
vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd});

for (int i = 0; i < n_vlayer; ++i) {
auto & layer = vit.layers[i];

@@ -4085,6 +4122,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_VISION_LLAVA:
case LLM_ARCH_VISION_MOBILEVLM:
case LLM_ARCH_VISION_MINICPMV:
case LLM_ARCH_VISION_IDEFICS3:
GGML_ABORT("vision arch does not use RoPE");

// all model arches should be listed explicitly here
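The {n_vembd * scale_factor * scale_factor, n_embd} shape of the v.mmproj.fc projection follows directly from the pixel shuffle: fewer tokens, proportionally wider channels. A back-of-the-envelope check with SmolVLM-like numbers (illustrative values, not read from this diff):

image_size, patch_size = 384, 14   # SigLIP-SO400M-style vision tower
n_vembd, scale_factor, n_embd = 1152, 3, 2048

n_patches = (image_size // patch_size) ** 2  # 27 * 27 = 729, also max_pos_embd (no CLS token)
n_tokens  = n_patches // scale_factor ** 2   # 729 // 9 = 81 tokens per image reach the LLM
fc_in     = n_vembd * scale_factor ** 2      # 1152 * 9 = 10368
print(n_tokens, (fc_in, n_embd))             # projection weight is {10368, 2048}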
(2 more changed files not shown)