diff --git a/Makefile b/Makefile index 8a903d7ed5914..b9fa279739680 100644 --- a/Makefile +++ b/Makefile @@ -926,6 +926,7 @@ OBJ_LLAMA = \ src/llama-vocab.o \ src/llama-grammar.o \ src/llama-sampling.o \ + src/llama-vision.o \ src/unicode.o \ src/unicode-data.o @@ -937,6 +938,7 @@ OBJ_COMMON = \ common/ngram-cache.o \ common/sampling.o \ common/train.o \ + common/vision.o \ common/build-info.o \ common/json-schema-to-grammar.o @@ -1120,6 +1122,7 @@ src/llama.o: \ src/llama-vocab.h \ src/llama-grammar.h \ src/llama-sampling.h \ + src/llama-vision.h \ src/unicode.h \ include/llama.h \ ggml/include/ggml-cuda.h \ @@ -1152,6 +1155,17 @@ src/llama-sampling.o: \ include/llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ +src/llama-vision.o: \ + src/llama-vision.cpp \ + src/llama-vision.h \ + include/llama.h \ + ggml/include/ggml-cuda.h \ + ggml/include/ggml-metal.h \ + ggml/include/ggml.h \ + ggml/include/ggml-alloc.h \ + ggml/include/ggml-backend.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + $(LIB_LLAMA): \ $(OBJ_LLAMA) \ $(LIB_GGML) @@ -1209,6 +1223,12 @@ common/ngram-cache.o: \ common/ngram-cache.h $(CXX) $(CXXFLAGS) -c $< -o $@ +common/vision.o: \ + common/vision.cpp \ + common/vision.h \ + common/stb_image.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + $(LIB_COMMON): \ $(OBJ_COMMON) \ $(LIB_LLAMA) \ @@ -1457,7 +1477,6 @@ llama-server: \ examples/server/json-schema-to-grammar.mjs.hpp \ examples/server/loading.html.hpp \ common/json.hpp \ - common/stb_image.h \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) @@ -1480,7 +1499,6 @@ libllava.a: examples/llava/llava.cpp \ examples/llava/llava.h \ examples/llava/clip.cpp \ examples/llava/clip.h \ - common/stb_image.h \ common/base64.hpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -static -fPIC -c $< -o $@ -Wno-cast-qual diff --git a/common/common.cpp b/common/common.cpp index 8d0ed4f95a737..921928d979cf1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1474,6 +1474,27 @@ std::vector llama_tokenize( return result; } +// TODO: this function is hacky, need to be improved +std::vector llama_tokenize_with_img( + const struct llama_context * ctx, + const std::string & text, + bool add_special, + bool parse_special) { + static const std::string IMG_PLACEMENT = ""; + std::vector parts = string_split(text, IMG_PLACEMENT); + std::vector output; + for (const auto & part : parts) { + bool add_bos = &parts.front() == ∂ + auto tokens = llama_tokenize(ctx, part, add_special && add_bos, parse_special); + output.insert(output.end(), tokens.begin(), tokens.end()); + if (&parts.back() != &part) { + // add image token to middle of 2 parts + output.push_back(TOKEN_IMG_PLACEMENT); + } + } + return output; +} + std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' diff --git a/common/common.h b/common/common.h index cb87c4479ed0a..e6fa1c2d41fd9 100644 --- a/common/common.h +++ b/common/common.h @@ -378,6 +378,20 @@ static std::vector string_split(const std::string & str, char delim) { return values; } +// split string by a `std::string delim` instead of `char delim` +static std::vector string_split(std::string s, const std::string & delimiter) { + std::vector tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + 
s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + return tokens; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -447,6 +461,17 @@ std::vector llama_tokenize( bool add_special, bool parse_special = false); +const llama_token TOKEN_IMG_PLACEMENT = -1000; + +// tokenize with "placeholder" for image embedding tokens +// "" will be replaced with TOKEN_IMG_PLACEMENT +// TODO: this function is hacky, need to be improved +std::vector llama_tokenize_with_img( + const struct llama_context * ctx, + const std::string & text, + bool add_special, + bool parse_special = false); + // tokenizes a token into a piece, optionally renders special/control tokens // should work similar to Python's `tokenizer.id_to_piece` std::string llama_token_to_piece( diff --git a/common/vision.cpp b/common/vision.cpp new file mode 100644 index 0000000000000..2b37ded16fa3c --- /dev/null +++ b/common/vision.cpp @@ -0,0 +1,38 @@ +#include "vision.h" + +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" + +#include +#include + +llama_img * load_image_from_file(const char * fname) { + std::ifstream file(fname, std::ios::binary); + if (!file) { + throw std::runtime_error("Unable to open file"); + } + std::vector image_bytes = std::vector( + std::istreambuf_iterator(file), + std::istreambuf_iterator()); + // decode image to byte array + int nx, ny, nc; + auto * bytes = (unsigned char *) image_bytes.data(); + auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3); + if (!img) { + throw std::runtime_error("failed to decode image bytes"); + } + // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc); + // GGML_ASSERT(nc == 3); + // for (int y = 0; y < ny; y++) { + // for (int x = 0; x < nx; x++) { + // unsigned char * pix = img + x*nc + y*nc*nx; + // printf("%02x%02x%02x ", pix[0], pix[1], pix[2]); + // } + // printf("\n"); + // } + // printf("\n"); + llama_img * result = llama_img_init(nx, ny); + memcpy(result->data, img, nx*ny*3); + stbi_image_free(img); + return result; +} diff --git a/common/vision.h b/common/vision.h new file mode 100644 index 0000000000000..16c6325fd5ac2 --- /dev/null +++ b/common/vision.h @@ -0,0 +1,8 @@ +#pragma once + +#include "llama.h" + +#include +#include + +llama_img * load_image_from_file(const char * fname); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cd5a8c11bc18..e6b4cd5f2c5a2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11,6 +11,7 @@ import os import re import sys +from transformers import AutoConfig from enum import IntEnum from pathlib import Path from hashlib import sha256 @@ -66,6 +67,12 @@ class Model: dir_model_card: Path is_lora: bool + # for vision model + preprocessor_config: dict[str, Any] | None = None + vparams: dict[str, Any] | None = None + v_tensor_map: gguf.TensorNameMap + v_tensor_names: set[str] | None + # subclasses should define this! 
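For reference, the splitting behaviour that `llama_tokenize_with_img()` above relies on can be pictured with a small standalone sketch. The placeholder literal appears stripped in this diff (the angle-bracketed string is missing), so `<img_placement>` is assumed here, and `llama_tokenize()` is replaced by a dummy per-character tokenizer:

```cpp
#include <cstdio>
#include <string>
#include <vector>

// sentinel meaning "an image goes here", as in common/common.h above
static const int TOKEN_IMG_PLACEMENT = -1000;

// stand-in for llama_tokenize(): one fake token per character
static std::vector<int> dummy_tokenize(const std::string & text) {
    return std::vector<int>(text.begin(), text.end());
}

// split on the image placeholder and interleave the sentinel token,
// mirroring the shape of llama_tokenize_with_img() above
static std::vector<int> tokenize_with_img(const std::string & text, const std::string & placeholder) {
    std::vector<int> out;
    size_t start = 0;
    while (true) {
        size_t pos = text.find(placeholder, start);
        std::string part = text.substr(start, pos == std::string::npos ? std::string::npos : pos - start);
        std::vector<int> toks = dummy_tokenize(part);
        out.insert(out.end(), toks.begin(), toks.end());
        if (pos == std::string::npos) {
            break;
        }
        out.push_back(TOKEN_IMG_PLACEMENT); // image slot between two text parts
        start = pos + placeholder.size();
    }
    return out;
}

int main() {
    auto toks = tokenize_with_img("USER: <img_placement>\nwhat did you see?", "<img_placement>");
    for (int t : toks) printf("%d ", t);
    printf("\n");
    return 0;
}
```

Every placeholder occurrence contributes exactly one TOKEN_IMG_PLACEMENT sentinel, which the modified examples/simple/simple.cpp later expands into the image's full token range.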
model_arch: gguf.MODEL_ARCH @@ -95,6 +102,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -210,9 +218,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) - if new_name is None: + new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) + if new_name is not None: + return new_name + elif new_name_vision is not None: + return new_name_vision + else: raise ValueError(f"Can not map tensor {name!r}") - return new_name def set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) @@ -452,7 +464,22 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str] @staticmethod def load_hparams(dir_model: Path): with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + hparams = json.load(f) + if "text_config" in hparams: + text_config = hparams["text_config"] + if "_name_or_path" in text_config: + text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict() + hparams = {**text_config, **hparams} + return hparams + + @staticmethod + def load_preprocessor_config(dir_model: Path): + file_path = dir_model / "preprocessor_config.json" + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + else: + return None @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -1501,10 +1528,17 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "vision_config" in self.hparams: + self.vparams = self.hparams["vision_config"] + if self.vparams is not None: + self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"]) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1554,6 +1588,26 @@ def set_gguf_parameters(self): if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + # For vision model + if self.vparams is not None and self.preprocessor_config is not None: + self.gguf_writer.add_vision_type("clip-vit") + self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) + self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) + self.gguf_writer.add_vision_clip_architecture("llava") + self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"]) + self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"]) + self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) + 
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) + self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"]) + self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"]) + self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd) + # TODO: should not hardcode these, but they are currently missing from config.json + self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP) + self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05) + @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1568,6 +1622,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + # For vision model + if name.startswith("language_model"): + name = name.replace("language_model.", "") + if "post_layernorm" in name: + return [] # skip post_layernorm + if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 8aa7b0750cf20..ecc538256eaad 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -23,7 +23,7 @@ #include "ggml-vulkan.h" #endif -#define STB_IMAGE_IMPLEMENTATION +#include "vision.h" // without this, we get duplicated symbol error #include "stb_image.h" #include diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index c2b7267c8133e..0a28f9bf63bd2 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "vision.h" #include @@ -14,7 +15,9 @@ static void print_usage(int, char ** argv) { int main(int argc, char ** argv) { gpt_params params; - params.prompt = "Hello my name is"; + //params.prompt = "Hello my name is"; + params.prompt = "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n" + "USER:\nwhat did you see?\nASSISTANT:"; params.n_predict = 32; if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) { @@ -64,7 +67,7 @@ int main(int argc, char ** argv) { // tokenize the prompt std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, params.prompt, true); + tokens_list = ::llama_tokenize_with_img(ctx, params.prompt, true); const int n_ctx = llama_n_ctx(ctx); const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size()); @@ -84,22 +87,75 @@ int main(int argc, char ** argv) { LOG("\n"); for (auto id : tokens_list) { - LOG("%s", llama_token_to_piece(ctx, id).c_str()); + if (id == TOKEN_IMG_PLACEMENT) { + LOG(""); + } else { + LOG("%s", llama_token_to_piece(ctx, id).c_str()); + } } + LOG("\n\n"); + + // load image + llama_batch_img img_batch = llama_batch_img_init(1); + img_batch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg"); + // create a llama_batch with size 512 // we use this object to submit token data for decoding llama_batch batch = llama_batch_init(512, 0, 1); // evaluate the initial prompt - for (size_t i = 0; i < tokens_list.size(); i++) { - llama_batch_add(batch, tokens_list[i], i, { 0 }, false); + int n_cur = 0; + int i_img = 0; + for (auto id : tokens_list) { + if (id == TOKEN_IMG_PLACEMENT) { + img_batch.pos[i_img] = n_cur; + n_cur += llama_img_n_tokens(ctx, img_batch.imgs[i_img]); + i_img++; + } else { + llama_batch_add(batch, id, n_cur, { 0 }, false); + printf("pos %d tok %d --> %s\n", n_cur, id, llama_token_to_piece(ctx, id).c_str()); + n_cur++; + } } // llama_decode will output logits only for the last token of the prompt batch.logits[batch.n_tokens - 1] = true; + if (llama_encode_vision(ctx, img_batch) != 0) { + LOG("%s: llama_encode_vision() failed\n", __func__); + return 1; + } + + n_cur = 0; + { + auto t1 = ::llama_tokenize(ctx, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", false); + auto t2 = ::llama_tokenize(ctx, "\nwhat did you see?\nASSISTANT:", false); + t1.insert(t1.begin(), 1); + + n_cur = 0; + llama_batch_clear(batch); + llama_batch_add(batch, 1, 0, { 0 }, false); + llama_decode(ctx, batch); + + n_cur = t1.size(); + llama_batch_clear(batch); + llama_batch batch0 = {int32_t(576), nullptr, _test_get_img_embd(ctx), nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, }; + llama_decode(ctx, batch0); + + n_cur = 0; + llama_batch_clear(batch); + for (auto t : t1) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; } + llama_decode(ctx, batch); + + n_cur = t1.size() + 576; + llama_batch_clear(batch); + printf("pos %d\n", n_cur); + for (auto t : t2) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; } + batch.logits[batch.n_tokens - 1] = true; + } + if (llama_decode(ctx, batch) != 0) { LOG("%s: llama_decode() failed\n", __func__); return 1; @@ -107,18 +163,17 @@ int main(int argc, char ** argv) { // main loop - int n_cur = batch.n_tokens; int n_decode = 0; const auto t_main_start = ggml_time_us(); - while (n_cur <= n_predict) { + for (int i = 0; i < n_predict; i++) { // sample the next token { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1); // is it an end of generation? 
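Setting aside the temporary 576-token workaround and the `_test_get_img_embd()` call, the flow the modified example above is driving at can be summarized in one hedged sketch against the prototypes this PR adds to include/llama.h, common/common.h and common/vision.h (error handling and batch cleanup omitted):

```cpp
#include "common.h"
#include "llama.h"
#include "vision.h"

#include <string>
#include <vector>

// sketch: build and evaluate the prompt for a single image, following the
// placeholder convention of llama_tokenize_with_img() and the proposed vision API
static bool eval_prompt_with_image(llama_context * ctx, const std::string & prompt, const char * img_path) {
    std::vector<llama_token> tokens = llama_tokenize_with_img(ctx, prompt, true);

    // wrap the decoded RGB image in a single-image batch
    llama_batch_img img_batch = llama_batch_img_init(1);
    img_batch.imgs[0] = load_image_from_file(img_path);

    llama_batch batch = llama_batch_init(512, 0, 1);

    // walk the token list; each image placeholder reserves a range of positions
    int n_cur = 0;
    int i_img = 0;
    for (llama_token id : tokens) {
        if (id == TOKEN_IMG_PLACEMENT) {
            img_batch.pos[i_img] = n_cur;
            n_cur += llama_img_n_tokens(ctx, img_batch.imgs[i_img]);
            i_img++;
        } else {
            llama_batch_add(batch, id, n_cur, { 0 }, false);
            n_cur++;
        }
    }
    batch.logits[batch.n_tokens - 1] = true;

    // encode the image(s) first, then decode the text batch around them
    if (llama_encode_vision(ctx, img_batch) != 0) {
        return false;
    }
    return llama_decode(ctx, batch) == 0;
}
```

The point of img_batch.pos is that each placeholder reserves llama_img_n_tokens() consecutive positions, so the surrounding text tokens are placed around the image embeddings rather than overlapping them.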
- if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { + if (llama_token_is_eog(model, new_token_id)) { LOG("\n"); break; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2fd2e9d2be828..3e1a676c0cb73 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -173,11 +173,38 @@ class Tokenizer: MIDDLE_ID = "tokenizer.ggml.middle_token_id" EOT_ID = "tokenizer.ggml.eot_token_id" EOM_ID = "tokenizer.ggml.eom_token_id" + IMAGE_START_ID = "tokenizer.ggml.image_start_token_id" + IMAGE_END_ID = "tokenizer.ggml.image_end_token_id" class Adapter: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class Vision: + # only support vision.type = "clip-vit" for now + TYPE = "vision.type" + IMAGE_SIZE = "vision.image_size" + PATCH_SIZE = "vision.patch_size" + IMAGE_MEAN = "vision.image_mean" + IMAGE_STD = "vision.image_std" + + class Clip: + ARCHITECTURE = "vision.clip.architecture" + CONTEXT_LENGTH = "vision.clip.context_length" + EMBEDDING_LENGTH = "vision.clip.embedding_length" + BLOCK_COUNT = "vision.clip.block_count" + FEED_FORWARD_LENGTH = "vision.clip.feed_forward_length" + PROJECTION_TYPE = "vision.clip.projection_type" + PROJECTION_DIM = "vision.clip.projection_dim" + USE_GELU = "vision.clip.use_gelu" + MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings" + MAX_SLICES = "vision.clip.max_slices" + PROJECTOR_TYPE = "vision.clip.projector_type" + SELECT_LAYER = "vision.clip.select_layer" + PATCH_MERGE_TYPE = "vision.clip.patch_merge_type" + HEAD_COUNT = "vision.clip.attention.head_count" + LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon" + # # recommended mapping of model tensor names for storage in gguf # @@ -238,6 +265,8 @@ class MODEL_ARCH(IntEnum): GRANITE = auto() GRANITE_MOE = auto() CHAMELEON = auto() + # vision models + LLAVA_VISION = auto() class MODEL_TENSOR(IntEnum): @@ -345,6 +374,21 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + # vision + V_MMPROJ = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_V = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_OUTPUT = auto() + V_ENC_OUTPUT_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_DOWN = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -397,6 +441,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", + # vision + MODEL_ARCH.LLAVA_VISION: "llava", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -504,6 +550,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + # vision + MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.enc.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_K: "v.enc.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_V: "v.enc.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.enc.blk.{bid}.input_norm", + MODEL_TENSOR.V_ENC_OUTPUT: "v.enc.blk.{bid}.output", + MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm", + MODEL_TENSOR.V_ENC_FFN_UP: "v.enc.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down", + MODEL_TENSOR.V_PRE_NORM: "v.pre_norm", + 
MODEL_TENSOR.V_POST_NORM: "v.post_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1279,6 +1340,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.LLAVA_VISION: [ + MODEL_TENSOR.V_MMPROJ, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], # TODO } @@ -1351,6 +1428,15 @@ class PoolingType(IntEnum): CLS = 2 +class CLIPProjectorType(Enum): + MLP = 'mlp' + + +class CLIPPatchMergeType(Enum): + FLAT = 'flat' + SPATIAL_UNPAD = 'spatial_unpad' + + class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 5c460ef1bc260..02c2cf64e2026 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,8 @@ RopeScalingType, PoolingType, TokenType, + CLIPProjectorType, + CLIPPatchMergeType, ) from .quants import quant_shape_from_byte_shape @@ -814,6 +816,57 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_vision_type(self, value: str) -> None: + self.add_string(Keys.Vision.TYPE, value) + + def add_vision_image_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.IMAGE_SIZE, value) + + def add_vision_patch_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.PATCH_SIZE, value) + + def add_vision_clip_architecture(self, value: str) -> None: + self.add_string(Keys.Vision.Clip.ARCHITECTURE, value) + + def add_vision_clip_context_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.CONTEXT_LENGTH, value) + + def add_vision_clip_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.EMBEDDING_LENGTH, value) + + def add_vision_clip_block_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.BLOCK_COUNT, value) + + def add_vision_clip_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.FEED_FORWARD_LENGTH, value) + + def add_vision_clip_head_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value) + + def add_vision_clip_max_position_embeddings(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value) + + def add_vision_clip_projector_type(self, value: CLIPProjectorType) -> None: + self.add_string(Keys.Vision.Clip.PROJECTOR_TYPE, value.value) + + def add_vision_clip_max_slices(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.MAX_SLICES, value) + + def add_vision_clip_select_layer(self, value: int) -> None: + self.add_int32(Keys.Vision.Clip.SELECT_LAYER, value) + + def add_vision_clip_patch_merge_type(self, value: CLIPPatchMergeType) -> None: + self.add_string(Keys.Vision.Clip.PATCH_MERGE_TYPE, value.value) + + def add_vision_clip_layer_norm_epsilon(self, value: float) -> None: + self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value) + + def add_vision_clip_image_mean(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_MEAN, value) + + def add_vision_clip_image_std(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_STD, value) + def 
add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5ef91f11d312f..5ae4d65c782ea 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -679,6 +679,66 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + ), + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + ), + + MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + ), + + MODEL_TENSOR.V_ENC_OUTPUT: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + ), + + MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + ), + + MODEL_TENSOR.V_PRE_NORM: ( + "vision_tower.vision_model.pre_layrnorm", + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 4ea8a2c2b664b..e66dd0da188a7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -224,6 +224,21 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); + // represent an RGB image + // size of data must be equal to 3*nx*ny + typedef struct llama_img { + uint32_t nx; + uint32_t ny; + unsigned char * data; + } llama_img; + + // Input data for llama_vision_decode + typedef struct llama_batch_img { + int32_t n_imgs; + llama_img ** imgs; + llama_pos * pos; + } llama_batch_img; + // Input data for llama_decode // A llama_batch object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) 
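The bodies of `llama_img_init()`, `llama_img_free()` and `llama_batch_img_init()` are not visible in this excerpt. A minimal sketch consistent with the struct layout and the "size of data must be equal to 3*nx*ny" comment above is shown below; the `sketch_` prefix is a reminder that these are illustrative stand-ins, not the PR's actual implementations:

```cpp
#include <cstdint>
#include <cstdlib>

// mirrors the structs declared in include/llama.h above
typedef struct llama_img {
    uint32_t nx;
    uint32_t ny;
    unsigned char * data; // RGB, 3*nx*ny bytes
} llama_img;

typedef struct llama_batch_img {
    int32_t      n_imgs;
    llama_img ** imgs;
    int32_t    * pos;     // llama_pos is int32_t
} llama_batch_img;

// sketch: allocate an RGB image buffer of exactly 3*nx*ny bytes
static llama_img * sketch_img_init(int width, int height) {
    llama_img * img = (llama_img *) malloc(sizeof(llama_img));
    img->nx   = width;
    img->ny   = height;
    img->data = (unsigned char *) malloc(3 * (size_t) width * (size_t) height);
    return img;
}

static void sketch_img_free(llama_img * img) {
    if (!img) return;
    free(img->data);
    free(img);
}

// sketch: an image batch holds one image pointer and one position per image
static llama_batch_img sketch_batch_img_init(int n_imgs) {
    llama_batch_img batch;
    batch.n_imgs = n_imgs;
    batch.imgs   = (llama_img **) calloc(n_imgs, sizeof(llama_img *));
    batch.pos    = (int32_t *)    calloc(n_imgs, sizeof(int32_t));
    return batch;
}
```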
must have size of n_tokens @@ -875,6 +890,24 @@ extern "C" { // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // + // Vision + // + + // create new RGB image for input + LLAMA_API llama_img * llama_img_init(int width, int height); + LLAMA_API void llama_img_free(llama_img * img); + + // get number of tokens that an image occupies, used to determine the position in the batch + LLAMA_API int32_t llama_img_n_tokens(struct llama_context * ctx, llama_img * img); + + // create new image batch + LLAMA_API llama_batch_img llama_batch_img_init(int n_imgs); + LLAMA_API void llama_batch_img_free(llama_batch_img batch); + + // encode the input image batch + LLAMA_API int32_t llama_encode_vision(struct llama_context * ctx, llama_batch_img batch); + // // Vocab // @@ -1207,6 +1240,7 @@ extern "C" { LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx); + LLAMA_API float * _test_get_img_embd(struct llama_context * ctx); #ifdef __cplusplus } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46a6ad56202f7..2916e1366ef67 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(llama llama-vocab.cpp llama-grammar.cpp llama-sampling.cpp + llama-vision.cpp unicode.h unicode.cpp unicode-data.cpp diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp new file mode 100644 index 0000000000000..93f9e4b529402 --- /dev/null +++ b/src/llama-vision.cpp @@ -0,0 +1,867 @@ +#include "llama.h" +#include "llama-vision.h" +#include "llama-impl.h" + +#include // memcpy +#include +#include + +#ifndef NDEBUG +// for debugging +#include +#include +#include + +// export clip_image_u8 to bmp file for debugging +// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c +struct clip_image_size; +static int bmp_export(const struct clip_image_u8 &img, const std::string &location); +#endif + +struct clip_image_size { + int width; + int height; +}; + +// RGB uint8 image +// Memory layout: RGBRGBRGB... +struct clip_image_u8 { + int nx; + int ny; + std::vector buf; + clip_image_u8() {} + clip_image_u8(const llama_img img) { + nx = img.nx; + ny = img.ny; + buf.resize(nx*ny*3); + memcpy(buf.data(), img.data, buf.size()); + } +}; + +struct clip_image_u8_batch { + struct clip_image_u8 * data; + size_t size; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... 
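As a concrete check of the numbers that appear elsewhere in this PR (the 576 hard-coded into the simple example, the max_pos_embd value written by convert_hf_to_gguf.py), here is the patch arithmetic for a typical LLaVA-1.5 CLIP ViT-L/14 configuration; the 336/14 values are assumed defaults, not values taken from this diff:

```cpp
#include <cstdio>

// token budget of one image, matching clip_n_patches() below and the
// max_pos_embd computation in convert_hf_to_gguf.py
int main() {
    const int image_size = 336; // assumed LLaVA-1.5 / CLIP ViT-L/14 values
    const int patch_size = 14;

    const int n_patches    = (image_size / patch_size) * (image_size / patch_size); // 24*24 = 576
    const int max_pos_embd = n_patches + 1; // +1 for the class embedding token

    printf("n_patches = %d, max_pos_embd = %d\n", n_patches, max_pos_embd); // 576, 577
    return 0;
}
```

clip_n_patches() below computes the same (image_size / patch_size)^2 quantity, and the +1 accounts for the class embedding position.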
+struct clip_image_f32 { + int nx; + int ny; + std::vector buf; +}; + +using clip_image_f32_batch = std::vector; +using clip_image_f8_batch = std::vector; + +clip_projector_type projector_type_from_name(std::string & name) { + if (name == "mlp") { + return CLIP_PROJECTOR_TYPE_MLP; + } + return CLIP_PROJECTOR_TYPE_UNKNOWN; +} + +mm_patch_merge mm_patch_merge_from_name(std::string & name) { + if (name == "flat") { + return MM_PATCH_MERGE_FLAT; + } else if (name == "spatial_unpad") { + return MM_PATCH_MERGE_SPATIAL_UNPAD; + } + return MM_PATCH_MERGE_UNKNOWN; +} + +int clip_n_patches(const clip_context & ctx) { + auto & hparams = ctx.model->hparams; + int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size); + return n_patches; +} + +int clip_n_mmproj_embd(const clip_context & ctx) { + if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + return ctx.model->mm_2_b->ne[0]; + } else { + GGML_ASSERT(false && "invalid proj type"); + } +} + +/** + * Selects the best resolution from a list of possible resolutions based on the original size. + * + * @param original_size The original size of the image in the format (width, height). + * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + * @return The best fit resolution in the format (width, height). + */ +static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector& possible_resolutions) { + int original_width = original_size.width; + int original_height = original_size.height; + + clip_image_size best_fit; + int max_effective_resolution = 0; + int min_wasted_resolution = std::numeric_limits::max(); + + for (const auto& resolution : possible_resolutions) { + int width = resolution.width; + int height = resolution.height; + float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); + int downscaled_width = static_cast(original_width * scale); + int downscaled_height = static_cast(original_height * scale); + int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); + int wasted_resolution = (width * height) - effective_resolution; + // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + best_fit = resolution; + } + } + + return best_fit; +} + +static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { + auto clip = [](int x, int lower, int upper) -> int { + return std::max(lower, std::min(x, upper)); + }; + + const int nx = img.nx; + const int ny = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); + + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; + + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; + + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> 
https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = (int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + } + } + } + } + + return true; +} + +static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { + std::vector patches; + int width = image.nx; + int height = image.ny; + for (int i = 0; i < height; i += patch_size) { + for (int j = 0; j < width; j += patch_size) { + clip_image_u8 patch; + patch.nx = std::min(patch_size, width - j); + patch.ny = std::min(patch_size, height - i); + patch.buf.resize(3 * patch.nx * patch.ny); + for (int y = 0; y < patch.ny; ++y) { + for (int x = 0; x < patch.nx; ++x) { + for (int c = 0; c < 3; ++c) { + patch.buf[3 * (y * patch.nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; + } + } + } + patches.push_back(patch); + } + } + return patches; +} + +// llava-1.6 type of resize_and_pad (black) +static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & image_output, const clip_image_size & target_resolution) { + int target_width = target_resolution.width; + int target_height = target_resolution.height; + + float scale_w = static_cast(target_width) / image.nx; + float scale_h = static_cast(target_height) / image.ny; + + int new_width, new_height; + + if (scale_w < scale_h) { + new_width = target_width; + new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); + } else { + new_height = target_height; + new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); + } + + clip_image_u8 resized_image; + // bilinear_resize(image, resized_image, new_width, new_height); + bicubic_resize(image, resized_image, new_width, new_height); + + clip_image_u8 padded_image; + padded_image.nx = target_width; + padded_image.ny = target_height; + padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black + + // Calculate padding offsets + int pad_x = (target_width - new_width) / 2; + int pad_y = (target_height - new_height) / 2; + + // Copy the resized image into the center of the padded buffer + for (int y = 0; y < new_height; ++y) { + for (int x = 0; x < new_width; ++x) { + for (int c = 0; c < 3; ++c) { + padded_image.buf[3 * ((y + pad_y) * target_width + 
(x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + } + } + } + image_output = std::move(padded_image); +} + +static void normalize_image_u8_to_f32(const clip_image_u8 src, clip_image_f32 dst, const std::array & mean, const std::array & std) { + dst.nx = src.nx; + dst.ny = src.ny; + dst.buf.resize(src.buf.size()); + + for (size_t i = 0; i < src.buf.size(); ++i) { + int c = i % 3; // rgb + dst.buf[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c]; + } +} + +// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector +// res_imgs memory is being allocated here, previous allocations will be freed if found +static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) { + bool pad_to_square = true; + auto & params = ctx.model->hparams; + // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing + if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) { + pad_to_square = false; + } + + // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) + // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 + + clip_image_u8 temp; + if (pad_to_square && img.nx != img.ny) { + int longer_side = std::max(img.nx, img.ny); + temp.nx = longer_side; + temp.ny = longer_side; + temp.buf.resize(3 * longer_side * longer_side); + const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) + + // fill with background color + for (size_t i = 0; i < temp.buf.size(); i++) { + temp.buf[i] = bc[i % 3]; + } + + // copy from the input image + for (int y = 0; y < img.ny; y++) { + for (int x = 0; x < img.nx; x++) { + const int i = 3 * (y * img.nx + x); + const int j = 3 * (y * temp.nx + x); + temp.buf[j] = img.buf[i]; + temp.buf[j+1] = img.buf[i+1]; + temp.buf[j+2] = img.buf[i+2]; + } + } + } else { + if (params.image_grid_pinpoints[0] != 0) { + // "spatial_unpad" with "anyres" processing for llava-1.6 + std::vector possible_resolutions; + for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { + clip_image_size s; + s.width = params.image_grid_pinpoints[i]; + s.height = params.image_grid_pinpoints[i+1]; + possible_resolutions.push_back(s); + } + clip_image_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions); + // clip_image_save_to_bmp(*img, "input.bmp"); + resize_and_pad_image(img, temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 + // clip_image_save_to_bmp(*temp, "resized.bmp"); + + std::vector patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + + clip_image_u8 image_original_resize; + // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + patches.insert(patches.begin(), image_original_resize); + // clip_image_f32_batch_init(patches.size()); + output_imgs.resize(patches.size()); + int num = 0; + for (auto & patch : patches) { + normalize_image_u8_to_f32(patch, 
output_imgs[num], params.image_mean, params.image_std); + num++; + } + return true; + } else { + temp.nx = img.nx; + temp.ny = img.ny; + temp.buf.resize(img.buf.size()); + memcpy(temp.buf.data(), img.buf.data(), temp.buf.size()); + } + } + + const int nx = temp.nx; + const int ny = temp.ny; + // bmp_export(temp, "resized_vanilla.bmp"); + + const int nx2 = params.image_size; + const int ny2 = params.image_size; + clip_image_f32 res; + res.nx = nx2; + res.ny = ny2; + res.buf.resize(3 * nx2 * ny2); + + const float scale = std::max(nx, ny) / (float)params.image_size; + + const int nx3 = int(nx / scale + 0.5f); + const int ny3 = int(ny / scale + 0.5f); + + const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; + const auto & s3 = params.image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; + + for (int y = 0; y < ny3; y++) { + for (int x = 0; x < nx3; x++) { + for (int c = 0; c < 3; c++) { + // linear interpolation + const float sx = (x + 0.5f) * scale - 0.5f; + const float sy = (y + 0.5f) * scale - 0.5f; + + const int x0 = std::max(0, (int)std::floor(sx)); + const int y0 = std::max(0, (int)std::floor(sy)); + + const int x1 = std::min(x0 + 1, nx - 1); + const int y1 = std::min(y0 + 1, ny - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = 3 * (y0 * nx + x0) + c; + const int j01 = 3 * (y0 * nx + x1) + c; + const int j10 = 3 * (y1 * nx + x0) + c; + const int j11 = 3 * (y1 * nx + x1) + c; + + const float v00 = temp.buf[j00]; + const float v01 = temp.buf[j01]; + const float v10 = temp.buf[j10]; + const float v11 = temp.buf[j11]; + + const float v0 = v00 * (1.0f - dx) + v01 * dx; + const float v1 = v10 * (1.0f - dx) + v11 * dx; + + const float v = v0 * (1.0f - dy) + v1 * dy; + + const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); + + const int i = 3 * (y * nx3 + x) + c; + + res.buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; + } + } + } + + output_imgs.resize(1); + output_imgs[0] = std::move(res); + + return true; +} + +static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) { + auto & model = *ctx.model; + auto & hparams = ctx.model->hparams; + + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + const int patch_size = hparams.patch_size; + const float eps = hparams.eps; + const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); + const int num_positions = num_patches + (model.class_embedding ? 
1 : 0); + + LLAMA_LOG_DEBUG("%s: num_patches = %d\n", __func__, num_patches); + + struct ggml_init_params params = { + /*.mem_size =*/ ctx.buf_compute_meta.size(), + /*.mem_buffer =*/ ctx.buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // input + struct ggml_tensor * embeddings; + { + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size.width, image_size.height, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + // auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]); + + embeddings = inp; + if (model.class_embedding) { + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + ggml_set_name(embeddings, "embeddings"); + ggml_set_input(embeddings); + embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + embeddings = ggml_acc(ctx0, embeddings, inp, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + } + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + embeddings = ggml_add(ctx0, + embeddings, + ggml_get_rows(ctx0, model.position_embeddings, positions)); + } + + // pre-layernorm + if (model.pre_norm_w) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b); + } + + // loop over layers + for (int il = 0; il < (int)hparams.n_layer + hparams.select_layer; il++) { + struct ggml_tensor * cur = embeddings; + + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_in_w), + model.layers[il].norm_in_b); + } + + // self-attention + { + + struct ggml_tensor * Q = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].q_w, cur), + model.layers[il].q_b); + + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].k_w, cur), + model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].v_w, cur), + model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = 
ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_out_w), + model.layers[il].norm_out_b); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b); + + if (hparams.use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; + } + + // post-layernorm + if (model.post_norm_w) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b); + } + + // llava projector + { + embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + + struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); + ggml_set_name(patches, "patches"); + ggml_set_input(patches); + + // shape [1, 576, 1024] + // ne is whcn, ne = [1024, 576, 1, 1] + embeddings = ggml_get_rows(ctx0, embeddings, patches); + + if (hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + + embeddings = ggml_gelu(ctx0, embeddings); + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + } else { + GGML_ASSERT(false && "unsupported proj type"); + } + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + ggml_free(ctx0); + return gf; +} + +static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_batch & imgs, std::vector & output) { + int batch_size = imgs.size(); + auto & model = *ctx.model; + auto & hparams = ctx.model->hparams; + + if (hparams.arch == VISION_ARCH_LLAVA) { + GGML_ASSERT(batch_size == 1); // TODO: support multiple images + } + + clip_image_size image_size{(int)hparams.image_size, (int)hparams.image_size}; + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); + const int num_positions = num_patches + (model.class_embedding ? 
1 : 0); + + LLAMA_LOG_DEBUG("%s: image_size = %d\n", __func__, hparams.image_size); + LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions); + + // build the inference graph + ggml_cgraph * gf = clip_image_build_graph(ctx, batch_size, image_size); + + // alloc memory for graph + bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf); + if (!ok) { + LLAMA_LOG_ERROR("failed to alloc memory for graph\n"); + return -1; + } + + // set raw input + { + struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); + float * data = (float *)malloc(ggml_nbytes(inp_raw)); + + for (int i = 0; i < batch_size; i++) { + const int nx = imgs[i].nx; + const int ny = imgs[i].ny; + const int n = nx * ny; + + for (int b = 0; b < batch_size; b++) { + for (int k = 0; k < 3; k++) { + for (int y = 0; y < ny; y++) { + for (int x = 0; x < nx; x++) { + data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].buf[3 * (y * nx + x) + k]; + } + } + } + } + } + ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); + free(data); + } + + if (model.class_embedding) { + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + + void* zero_mem = malloc(ggml_nbytes(embeddings)); + memset(zero_mem, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); + free(zero_mem); + } + + { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = i; + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } + + { + struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + int* patches_data = (int*)malloc(ggml_nbytes(patches)); + for (int i = 0; i < num_patches; i++) { + patches_data[i] = i + 1; + } + ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); + free(patches_data); + } + + // compute + ggml_backend_sched_graph_compute_async(ctx.sched, gf); + + // the last node is the embedding tensor + struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(ctx.sched, embeddings); + + // copy the embeddings to the location passed by the user + size_t out_nbytes = clip_n_patches(ctx)*clip_n_mmproj_embd(ctx)*sizeof(float); + GGML_ASSERT(out_nbytes == ggml_nbytes(embeddings)); + output.resize(out_nbytes); + ggml_backend_tensor_get_async(backend_embd, embeddings, output.data(), 0, ggml_nbytes(embeddings)); + + ggml_backend_sched_synchronize(ctx.sched); + + return 0; +} + +static int32_t clip_image_encode(clip_context & ctx, const clip_image_f32 & img, std::vector & output) { + clip_image_f32_batch imgs{img}; + return clip_image_batch_encode(ctx, imgs, output); +} + +static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, std::vector & output_embd) { + clip_image_u8 img_u8(img); + clip_image_f32_batch img_res_v; + auto & hparams = ctx.model->hparams; + // bmp_export(img_u8, "test_inp.bmp"); + + if (!clip_image_preprocess(ctx, img_u8, img_res_v)) { + LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__); + return -2; + } + + switch (hparams.mm_patch_merge_type) { + case MM_PATCH_MERGE_FLAT: + { + // flat / default llava-1.5 type embedding + // n_output = clip_n_patches(ctx); + int32_t encoded = clip_image_encode(ctx, img_res_v[0], output_embd); + if (encoded != 0) { + LLAMA_LOG_ERROR("Unable to encode image\n"); + return 
encoded; + } + } break; + case MM_PATCH_MERGE_SPATIAL_UNPAD: + { + // TODO: support llava-1.6 + (void)0; + } break; + default: + GGML_ASSERT(false && "unsupported mm_patch_merge_type"); + } + + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////// +// public API + +int32_t llama_encode_vision_internal(clip_context & ctx, llama_batch_img * batch) { + if (batch->n_imgs == 0) { + return 0; + } + + // TODO: batching is not working atm, should be fixed later + const int n_embd = clip_n_mmproj_embd(ctx); + const int n_tokens_per_img = clip_n_patches(ctx); + const int n_pos = n_tokens_per_img*batch->n_imgs; + + ctx.out_embd.resize(n_embd*n_pos); + ctx.out_pos.resize(n_pos); + + for (int i = 0; i < batch->n_imgs; i++) { + std::vector output_single; + int32_t status = encode_image_with_clip(ctx, *batch->imgs[i], output_single); + if (status != 0) { + return status; + } + // copy output embeddings to result + for (int k = 0; k < n_embd*n_tokens_per_img; k++) { + ctx.out_embd[n_embd*n_tokens_per_img*i + k] = output_single[k]; + } + // fill position for all output tokens + for (int p = 0; p < n_tokens_per_img; p++) { + ctx.out_pos[n_tokens_per_img*i + p] = batch->pos[i] + p; + } + } + + return 0; +} + +void llama_vision_clear_output(clip_context & ctx) { + ctx.out_embd.clear(); + ctx.out_pos.clear(); +} + +//////////////////////////////////////////////////////////////////////////////////////// +// for debugging +#ifndef NDEBUG + +static int bmp_export(const struct clip_image_u8 &img, const std::string &location) { + const uint32_t width = img.nx; + const uint32_t height = img.ny; + // swap red and blue channel + std::vector buffer(width*height*3); + for (uint32_t y = 0; y < height; y++) { + for (uint32_t x = 0; x < width; x++) { + size_t base = x*3 + y*3*width; + buffer[base+2] = img.buf[base]; + buffer[base+1] = img.buf[base+1]; + buffer[base] = img.buf[base+2]; + } + } + const bool hasAlphaChannel = false; + + std::ofstream fout(location, std::ios::out | std::ios::binary); + + if (fout.fail()) { + return 0; + } + + //Padding + const uint8_t padding = hasAlphaChannel ? 0 : (4 - (width * 3) % 4) % 4; + + //Bitmap file header. + const char signature[2] = { 'B', 'M' }; + const uint32_t fileSize = buffer.size() * sizeof(uint8_t) + padding * (height - 1) + 14 + 124; + const uint32_t offset = 14 + 124; + + //Bitmap information header file + const uint32_t DIBSize = 124; + const int32_t bitmapWidth = width; + const int32_t bitmapHeight = height; + const uint16_t numPlanes = 1; + const uint16_t bitsPerPixel = (hasAlphaChannel) ? 32 : 24; + const uint32_t compressionMethod = (hasAlphaChannel) ? 3 : 0; //BI_RGB = 0, BI_BITFIELDS = 3 + const uint32_t bitmapSize = buffer.size() * sizeof(uint8_t); + const int32_t horizontalResolution = 2834; + const int32_t verticalResolution = 2834; + const uint32_t numColors = 0; + const uint32_t impColorCount = 0; + const uint32_t redBitmask = (hasAlphaChannel) ? 0x0000FF00 : 0; //ARGB32 pixel format + const uint32_t greenBitmask = (hasAlphaChannel) ? 0x00FF0000 : 0; + const uint32_t blueBitmask = (hasAlphaChannel) ? 0xFF000000 : 0; + const uint32_t alphaBitmask = (hasAlphaChannel) ? 
0x000000FF : 0; + + //Writing the file header and information header to the file + std::vector header(offset, 0); + header[0] = signature[0]; + header[1] = signature[1]; + +#define BMP_HEADERS(i, variableName) header[i] = variableName; header[i+1] = variableName >> 8; header[i+2] = variableName >> 16; header[i+3] = variableName >> 24; + + BMP_HEADERS(2, fileSize); + BMP_HEADERS(6, 0); + BMP_HEADERS(10, offset); + BMP_HEADERS(14, DIBSize); + BMP_HEADERS(18, bitmapWidth); + BMP_HEADERS(22, bitmapHeight); + + header[26] = (uint8_t)numPlanes; + header[27] = (uint8_t)(numPlanes >> 8); + header[28] = (uint8_t)bitsPerPixel; + header[29] = (uint8_t)(bitsPerPixel >> 8); + + BMP_HEADERS(30, compressionMethod); + BMP_HEADERS(34, (unsigned char)bitmapSize); + BMP_HEADERS(38, (unsigned char)horizontalResolution); + BMP_HEADERS(42, (unsigned char)verticalResolution); + BMP_HEADERS(46, (unsigned char)numColors); + BMP_HEADERS(50, (unsigned char)impColorCount); + BMP_HEADERS(54, (unsigned char)redBitmask); + BMP_HEADERS(58, (unsigned char)greenBitmask); + BMP_HEADERS(62, (unsigned char)blueBitmask); + BMP_HEADERS(66, alphaBitmask); + +#undef BMP_HEADERS + + fout.write((char *)header.data(), sizeof(uint8_t) * header.size()); + + //Writing the pixel array + const uint32_t bWidth = bitsPerPixel / 8 * width; + + for (int i = height - 1; i >= 0; i--) { + std::vector row(buffer.begin() + i * bWidth, buffer.begin() + i * bWidth + bWidth); + fout.write((char *)row.data(), row.size() * sizeof(uint8_t)); + fout.seekp(padding * sizeof(uint8_t), std::ios::cur); + } + + fout.close(); + return 1; +} + +#endif + diff --git a/src/llama-vision.h b/src/llama-vision.h new file mode 100644 index 0000000000000..950f497c88da1 --- /dev/null +++ b/src/llama-vision.h @@ -0,0 +1,124 @@ +#pragma once + +#include "ggml.h" + +#include +#include + +enum vision_arch { + VISION_ARCH_UNKNOWN, + VISION_ARCH_LLAVA, +}; + +enum clip_projector_type { + CLIP_PROJECTOR_TYPE_UNKNOWN, + CLIP_PROJECTOR_TYPE_MLP, +}; + +enum mm_patch_merge { + MM_PATCH_MERGE_UNKNOWN, + MM_PATCH_MERGE_FLAT, + MM_PATCH_MERGE_SPATIAL_UNPAD, +}; + +struct clip_hparams { + vision_arch arch = VISION_ARCH_UNKNOWN; + + uint32_t image_size; + uint32_t patch_size; + uint32_t hidden_size; + uint32_t n_intermediate; + uint32_t projection_dim; + uint32_t n_head; + uint32_t n_layer; + uint32_t max_pos_embd; + int32_t select_layer = 0; + bool use_gelu = false; + + float eps; + + clip_projector_type proj_type = CLIP_PROJECTOR_TYPE_UNKNOWN; + mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT; + + std::array image_mean; + std::array image_std; + + std::array image_grid_pinpoints; + int32_t image_crop_resolution; +}; + +struct clip_layer { + // attention + struct ggml_tensor * k_w = NULL; + struct ggml_tensor * k_b = NULL; + struct ggml_tensor * q_w = NULL; + struct ggml_tensor * q_b = NULL; + struct ggml_tensor * v_w = NULL; + struct ggml_tensor * v_b = NULL; + + struct ggml_tensor * output_w = NULL; + struct ggml_tensor * output_b = NULL; + + // layernorm 1 + struct ggml_tensor * norm_in_w = NULL; + struct ggml_tensor * norm_in_b = NULL; + + // ff + struct ggml_tensor * ffn_up_w = NULL; + struct ggml_tensor * ffn_up_b = NULL; + + struct ggml_tensor * ffn_down_w = NULL; + struct ggml_tensor * ffn_down_b = NULL; + + // layernorm 2 + struct ggml_tensor * norm_out_w = NULL; + struct ggml_tensor * norm_out_b = NULL; +}; + +struct clip_vision_model { + struct clip_hparams hparams; + + // embeddings + struct ggml_tensor * class_embedding = NULL; + struct ggml_tensor * 
patch_embeddings = NULL; + struct ggml_tensor * patch_bias = NULL; + struct ggml_tensor * position_embeddings = NULL; + + struct ggml_tensor * pre_norm_w = NULL; + struct ggml_tensor * pre_norm_b = NULL; + + std::vector layers; + + struct ggml_tensor * post_norm_w = NULL; + struct ggml_tensor * post_norm_b = NULL; + + struct ggml_tensor * projection = NULL; + + // LLaVA projection + struct ggml_tensor * mm_1_w = NULL; + struct ggml_tensor * mm_1_b = NULL; + struct ggml_tensor * mm_2_w = NULL; + struct ggml_tensor * mm_2_b = NULL; + + struct ggml_tensor * image_newline = NULL; +}; + +struct clip_context { + // memory buffers used to evaluate the model + std::vector buf_compute_meta; + ggml_backend_sched_t sched = nullptr; + + const clip_vision_model * model; + + // temporary output data, to be picked up by llama_decode() + std::vector out_embd; // size == n_tokens * n_embd + std::vector out_pos; // position of each token +}; + +mm_patch_merge mm_patch_merge_from_name(std::string & name); +clip_projector_type projector_type_from_name(std::string & name); +int clip_n_patches(const clip_context & ctx); +int clip_n_mmproj_embd(const clip_context & ctx); + +int32_t llama_encode_vision_internal(clip_context & ctx, llama_batch_img * batch); +void llama_vision_clear_output(clip_context & ctx); diff --git a/src/llama.cpp b/src/llama.cpp index 44afb31d74e53..b1b44aacaa993 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,6 +1,7 @@ #include "llama-impl.h" #include "llama-vocab.h" #include "llama-sampling.h" +#include "llama-vision.h" #include "unicode.h" @@ -273,6 +274,11 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_UNKNOWN, "(unknown)" }, }; +static const std::map VISION_ARCH_NAMES = { + { VISION_ARCH_LLAVA, "llava" }, + { VISION_ARCH_UNKNOWN, "(unknown)" }, +}; + enum llm_kv { LLM_KV_GENERAL_TYPE, LLM_KV_GENERAL_ARCHITECTURE, @@ -379,6 +385,28 @@ enum llm_kv { LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + + // TODO: these are vision-related KV, probably should be moved to a new enum + LLM_KV_VISION_TYPE, + LLM_KV_VISION_IMAGE_SIZE, + LLM_KV_VISION_PATCH_SIZE, + LLM_KV_VISION_IMAGE_MEAN, + LLM_KV_VISION_IMAGE_STD, + LLM_KV_VISION_CLIP_ARCHITECTURE, + LLM_KV_VISION_CLIP_CONTEXT_LENGTH, + LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, + LLM_KV_VISION_CLIP_BLOCK_COUNT, + LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, + LLM_KV_VISION_CLIP_PROJECTION_TYPE, + LLM_KV_VISION_CLIP_PROJECTION_DIM, + LLM_KV_VISION_CLIP_USE_GELU, + LLM_KV_VISION_CLIP_MAX_POS_EMBD, + LLM_KV_VISION_CLIP_MAX_SLICES, + LLM_KV_VISION_CLIP_PROJECTOR_TYPE, + LLM_KV_VISION_CLIP_SELECT_LAYER, + LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, + LLM_KV_VISION_CLIP_HEAD_COUNT, + LLM_KV_VISION_CLIP_LAYERNORM_EPS, }; static const std::map LLM_KV_NAMES = { @@ -487,6 +515,27 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + { LLM_KV_VISION_TYPE, "vision.type" }, + { LLM_KV_VISION_IMAGE_SIZE, "vision.image_size" }, + { LLM_KV_VISION_PATCH_SIZE, "vision.patch_size" }, + { LLM_KV_VISION_IMAGE_MEAN, "vision.image_mean" }, + { LLM_KV_VISION_IMAGE_STD, "vision.image_std" }, + { LLM_KV_VISION_CLIP_ARCHITECTURE, "vision.clip.architecture" }, + { LLM_KV_VISION_CLIP_CONTEXT_LENGTH, "vision.clip.context_length" }, + { LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, "vision.clip.embedding_length" }, + { LLM_KV_VISION_CLIP_BLOCK_COUNT, "vision.clip.block_count" }, + { LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, "vision.clip.feed_forward_length" }, + { LLM_KV_VISION_CLIP_PROJECTION_TYPE, 
"vision.clip.projection_type" }, + { LLM_KV_VISION_CLIP_PROJECTION_DIM, "vision.clip.projection_dim" }, + { LLM_KV_VISION_CLIP_USE_GELU, "vision.clip.use_gelu" }, + { LLM_KV_VISION_CLIP_MAX_POS_EMBD, "vision.clip.max_position_embeddings" }, + { LLM_KV_VISION_CLIP_MAX_SLICES, "vision.clip.max_slices" }, + { LLM_KV_VISION_CLIP_PROJECTOR_TYPE, "vision.clip.projector_type" }, + { LLM_KV_VISION_CLIP_SELECT_LAYER, "vision.clip.select_layer" }, + { LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, "vision.clip.patch_merge_type" }, + { LLM_KV_VISION_CLIP_HEAD_COUNT, "vision.clip.attention.head_count" }, + { LLM_KV_VISION_CLIP_LAYERNORM_EPS, "vision.clip.attention.layer_norm_epsilon" }, }; struct LLM_KV { @@ -608,6 +657,23 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, }; +enum vision_tensor { + VISION_TENSOR_MMPROJ, + VISION_TENSOR_ENC_EMBD_CLS, + VISION_TENSOR_ENC_EMBD_PATCH, + VISION_TENSOR_ENC_EMBD_POS, + VISION_TENSOR_ENC_ATTN_Q, + VISION_TENSOR_ENC_ATTN_K, + VISION_TENSOR_ENC_ATTN_V, + VISION_TENSOR_ENC_INPUT_NORM, + VISION_TENSOR_ENC_OUTPUT, + VISION_TENSOR_ENC_OUTPUT_NORM, + VISION_TENSOR_ENC_FFN_UP, + VISION_TENSOR_ENC_FFN_DOWN, + VISION_TENSOR_PRE_NORM, + VISION_TENSOR_POST_NORM, +}; + static const std::map> LLM_TENSOR_NAMES = { { LLM_ARCH_LLAMA, @@ -1530,6 +1596,28 @@ static const std::map> LLM_TENSOR_NA }, }; +static const std::map> VISION_TENSOR_NAMES = { + { + VISION_ARCH_LLAVA, + { + { VISION_TENSOR_MMPROJ, "v.mmproj_%d" }, + { VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" }, + { VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { VISION_TENSOR_PRE_NORM, "v.pre_norm" }, + { VISION_TENSOR_POST_NORM, "v.post_norm" }, + } + } +}; + static llm_arch llm_arch_from_string(const std::string & name) { for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { @@ -1540,56 +1628,66 @@ static llm_arch llm_arch_from_string(const std::string & name) { return LLM_ARCH_UNKNOWN; } -// helper to handle gguf constants -// usage: -// -// const auto tn = LLM_TN(LLM_ARCH_LLAMA); -// -// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" -// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" -// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" -// -struct LLM_TN { - LLM_TN(llm_arch arch) : arch(arch) {} +template +struct BASE_TN { + Tname arch; + std::map> name_mapping; - llm_arch arch; + BASE_TN(Tname arch, std::map> name_mapping) : arch(arch), name_mapping(name_mapping) {} - std::string operator()(llm_tensor tensor) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES.at(arch).at(tensor); + return name_mapping.at(arch).at(tensor); } - std::string operator()(llm_tensor tensor, const std::string & suffix) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string 
operator()(Ttensor tensor, const std::string & suffix) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix; + return name_mapping.at(arch).at(tensor) + "." + suffix; } - std::string operator()(llm_tensor tensor, int bid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, int bid) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid); + return ::format(name_mapping.at(arch).at(tensor).c_str(), bid); } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, const std::string & suffix, int bid) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix; + return ::format(name_mapping.at(arch).at(tensor).c_str(), bid) + "." + suffix; } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, const std::string & suffix, int bid, int xid) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix; + return ::format(name_mapping.at(arch).at(tensor).c_str(), bid, xid) + "." 
+ suffix; } }; +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN : BASE_TN { + LLM_TN(llm_arch arch) : BASE_TN(arch, LLM_TENSOR_NAMES) {} +}; + +struct VISION_TN : BASE_TN { + VISION_TN(vision_arch arch) : BASE_TN(arch, VISION_TENSOR_NAMES) {} +}; + // // gguf helpers // @@ -2908,6 +3006,9 @@ struct llama_model { std::vector layers; + bool has_vision = false; + clip_vision_model clip; + llama_split_mode split_mode; int main_gpu; int n_gpu_layers; @@ -3403,6 +3504,9 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + + // vision + clip_context clip; }; struct llama_lora_weight { @@ -6123,6 +6227,58 @@ static void llm_load_hparams( default: (void)0; } + // vision model + auto & vparams = model.clip.hparams; + std::string vision_type; + ml.get_key(LLM_KV_VISION_TYPE, vision_type, false); + if (vision_type == "clip-vit") { + model.has_vision = true; + ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true); + ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, vparams.image_std, 3, true); + ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, vparams.hidden_size, true); + ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, vparams.n_layer, true); + ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true); + ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true); + ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true); + ml.get_key(LLM_KV_VISION_CLIP_SELECT_LAYER, vparams.select_layer, true); + { + std::string name; + ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, name, true); + vparams.proj_type = projector_type_from_name(name); + if (vparams.proj_type == CLIP_PROJECTOR_TYPE_UNKNOWN) { + throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str())); + } + } + { + std::string name; + ml.get_key(LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, name, false); + vparams.mm_patch_merge_type = mm_patch_merge_from_name(name); + } + { + std::string arch; + ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true); + for (auto & it : VISION_ARCH_NAMES) { + if (arch == it.second) { + vparams.arch = it.first; + break; + } + } + } + } else if (!vision_type.empty()) { + throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); + } + + // arch-specific CLIP hparams + switch (vparams.arch) { + case VISION_ARCH_LLAVA: + { + ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true); + } break; + default: (void)0; + } + model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { @@ -8811,7 +8967,70 @@ static bool llm_load_tensors( } } break; default: - throw std::runtime_error("unknown architecture"); + throw std::runtime_error("unknown llm architecture"); + } + } + + // load tensors for vision model + auto & vparams = model.clip.hparams; + if (model.has_vision) { + const int64_t n_layer = vparams.n_layer; + const int64_t n_embd = vparams.hidden_size; + const int64_t n_ff = vparams.n_intermediate; + const int64_t max_pos_embd 
= vparams.max_pos_embd; + const int64_t n_channel = 3; // always RGB + const int64_t patch_size = vparams.patch_size; + const auto tn = VISION_TN(vparams.arch); + + ggml_context * ctx_vision = ctx_map.at(model.buft_input.buft); // TODO: make dedicated buft for vision + auto ctx_for_layer = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); }; + + model.clip.layers.resize(n_layer); + + switch (vparams.arch) { + case VISION_ARCH_LLAVA: + { + model.clip.mm_1_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 1), {n_embd, n_ff}); + model.clip.mm_1_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 1), {n_ff}); + model.clip.mm_2_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 2), {n_ff, n_ff}); + model.clip.mm_2_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 2), {n_ff}); + + model.clip.class_embedding = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS ), {n_embd}); + model.clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_embd}); + model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd}); + + model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd}); + model.clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd}); + model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + auto & layer = model.clip.layers[i]; + + layer.k_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd}); + layer.k_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_embd}); + layer.v_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd}); + layer.v_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_embd}); + layer.q_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.q_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_embd}); + + layer.ffn_up_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_ff}); + layer.ffn_down_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , i), {n_embd}); + + layer.norm_in_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_embd}); + layer.norm_in_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_embd}); + layer.norm_out_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_embd}); + layer.norm_out_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_embd}); + + layer.output_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_embd, n_embd}); + layer.output_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_embd}); + } + } break; + 
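// note: to support another vision architecture beyond VISION_ARCH_LLAVA, extend vision_arch and
// VISION_ARCH_NAMES, add its tensor names to VISION_TENSOR_NAMES, read its hparams in the vision
// block of llm_load_hparams(), and add a matching case to this switch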
default: + throw std::runtime_error("unknown vision architecture"); } } @@ -19434,6 +19653,14 @@ struct llama_context * llama_new_context_with_model( } } + // initialize vision context + if (model->has_vision) { + ctx->clip.model = &model->clip; + ctx->clip.sched = ctx->sched; + const size_t max_nodes = llama_model_max_nodes(*model); + ctx->clip.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + } + return ctx; } @@ -19883,6 +20110,8 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { void llama_kv_cache_clear(struct llama_context * ctx) { llama_kv_cache_clear(ctx->kv_self); + // clear vision embeddings output + llama_vision_clear_output(ctx->clip); } bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -20999,9 +21228,48 @@ int32_t llama_encode( return ret; } +float * _test_get_img_embd(struct llama_context * ctx) { return ctx->clip.out_embd.data(); } int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { + // hacky vision implementation, for testing only + if (!ctx->clip.out_embd.empty()) { + // int8_t * logits = new int8_t [ctx->clip.out_pos.size()]; + // int32_t * n_seq_id = new int32_t [ctx->clip.out_pos.size()]; + // llama_seq_id ** seq_id = new llama_seq_id *[ctx->clip.out_pos.size()]; + // llama_seq_id seq_id_0 = 0; + // printf("out_pos %d\n", ctx->clip.out_pos.size()); + // llama_batch ibatch = { + // /*n_tokens =*/ static_cast(ctx->clip.out_pos.size()), + // /*tokens =*/ nullptr, + // /*embd =*/ ctx->clip.out_embd.data(), + // /*pos =*/ ctx->clip.out_pos.data(), + // /*n_seq_id =*/ n_seq_id, + // /*seq_id =*/ seq_id, + // /*logits =*/ logits, + // /*all_pos_0 =*/ 0, + // /*all_pos_1 =*/ 0, + // /*all_seq_id =*/ 0, + // }; + // for (size_t i = 0; i < ctx->clip.out_pos.size(); i++) { + // ibatch.n_seq_id[i] = 1; + // ibatch.seq_id [i] = &seq_id_0; + // ibatch.logits [i] = 0; + // } + // llama_decode_internal(*ctx, ibatch); + // delete[] logits; + // delete[] n_seq_id; + // delete[] seq_id; + // llama_vision_clear_output(ctx->clip); + + //int n_eval = ctx->clip.out_pos.size(); + //int n_past = ctx->clip.out_pos[0]; + //printf("n_eval %d, n_past %d\n", n_eval, n_past); + //llama_batch ibatch = {int32_t(n_eval), nullptr, ctx->clip.out_embd.data(), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; + //llama_decode_internal(*ctx, ibatch); + //llama_vision_clear_output(ctx->clip); + } + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); @@ -21577,6 +21845,49 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); } +// +// vision +// + +llama_img * llama_img_init(int width, int height) { + llama_img * img = new llama_img(); + img->nx = width; + img->ny = height; + if (width > 0 && height > 0) { + img->data = (unsigned char *)malloc(width*height*3); + } + return img; +} + +void llama_img_free(llama_img * img) { + if (img->data) free(img->data); + delete img; +} + +int32_t llama_img_n_tokens(struct llama_context * ctx, llama_img * img) { + GGML_UNUSED(img); // reserved for future usage + return clip_n_patches(ctx->clip); +} + +llama_batch_img llama_batch_img_init(int n_imgs) { + llama_batch_img batch; + batch.n_imgs = n_imgs; + if (n_imgs > 0) { + batch.imgs = (llama_img **)malloc(n_imgs*sizeof(llama_img *)); + batch.pos = (llama_pos * 
)malloc(n_imgs*sizeof(llama_pos )); + } + return batch; +} + +void llama_batch_img_free(llama_batch_img batch) { + if (batch.imgs) free(batch.imgs); + if (batch.pos ) free(batch.pos ); +} + +int32_t llama_encode_vision(struct llama_context * ctx, llama_batch_img batch) { + return llama_encode_vision_internal(ctx->clip, &batch); +} + // // model split //
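// rough usage sketch of the new vision API (illustrative only, not part of this patch):
// example_encode_image is a made-up helper name, and it assumes a model converted with the
// vision.* GGUF metadata; it shows the intended call sequence only, since consuming
// TOKEN_IMG_PLACEMENT and the stored image embeddings inside llama_decode() is still TODO
#include "llama.h"
#include "common.h"
#include "vision.h" // common/vision.h: load_image_from_file()

#include <cstdio>
#include <vector>

static void example_encode_image(llama_context * lctx, const char * img_path, const char * prompt) {
    // decode the image file into an RGB llama_img (3 bytes per pixel)
    llama_img * img = load_image_from_file(img_path);

    // number of embedding tokens this image will occupy (clip_n_patches under the hood)
    const int n_img_tokens = llama_img_n_tokens(lctx, img);
    printf("image uses %d tokens\n", n_img_tokens);

    // batch with a single image placed at position 0 of the sequence
    llama_batch_img ibatch = llama_batch_img_init(1);
    ibatch.imgs[0] = img;
    ibatch.pos [0] = 0;

    // run the CLIP encoder + mm projector; the resulting embeddings and positions are stored
    // in the context, to be picked up by a later llama_decode() call
    if (llama_encode_vision(lctx, ibatch) != 0) {
        fprintf(stderr, "failed to encode image\n");
    }

    // tokenize the prompt; the image placeholder marker (IMG_PLACEMENT in common.cpp) is replaced
    // by TOKEN_IMG_PLACEMENT, which marks where the image embeddings should be inserted
    std::vector<llama_token> tokens = llama_tokenize_with_img(lctx, prompt, true);
    (void) tokens; // feeding the mixed text/image tokens through llama_decode() is still TODO

    llama_batch_img_free(ibatch);
    llama_img_free(img);
}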