diff --git a/common/common.cpp b/common/common.cpp
index 1591790e6df4c..21003343e4740 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -789,6 +789,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.model = argv[i];
         return true;
     }
+    if (arg == "-hl" || arg == "--hot-lora") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.hot_lora = argv[i];
+        return true;
+    }
     if (arg == "-md" || arg == "--model-draft") {
         if (++i >= argc) {
             invalid_param = true;
             return true;
         }
@@ -2435,6 +2443,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_ubatch        = params.n_ubatch;
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    const char* c_string = params.hot_lora.c_str();
+    strncpy(cparams.hot_lora, c_string, sizeof(cparams.hot_lora) - 1);
+    cparams.hot_lora[sizeof(cparams.hot_lora) - 1] = '\0'; // Ensure null-termination
+
     cparams.seed            = params.seed;
     cparams.logits_all      = params.logits_all;
     cparams.embeddings      = params.embedding;
diff --git a/common/common.h b/common/common.h
index 2345d855eed3c..cd9d6370cf47f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -100,6 +100,7 @@ struct gpt_params {
 
     std::string model       = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
+    std::string hot_lora    = ""; // lora model path for hot swapping
    std::string model_alias  = "unknown"; // model alias
     std::string model_url   = ""; // model url to download
     std::string hf_repo     = ""; // HF repo
diff --git a/data/hot-lora.txt b/data/hot-lora.txt
new file mode 100644
index 0000000000000..e88891d2f5eaf
--- /dev/null
+++ b/data/hot-lora.txt
@@ -0,0 +1,2 @@
+
+test data to train adapter
\ No newline at end of file
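
// Illustrative sketch, not part of the patch: how the new `-hl` / `--hot-lora` option is
// expected to flow from the command line into the context parameters (the adapter path
// below is a placeholder).

gpt_params params;
params.hot_lora = "lora-adapter.bin"; // what gpt_params_find_arg() stores for -hl / --hot-lora
llama_context_params cparams = llama_context_params_from_gpt_params(params);
// cparams.hot_lora now holds a NUL-terminated copy of the path (truncated to 255 chars);
// llama_new_context_with_model() only loads an adapter when this string is non-empty.
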
diff --git a/ggml.c b/ggml.c
index 1fc77743bc7b9..a9cb2bc73b48e 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4313,6 +4313,52 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
     return NULL;
 }
 
+//////// LORA
+
+struct lora_tensor_pair* build_lora_weights_map(struct ggml_context* ctx) {
+    struct lora_tensor_pair* pair = malloc(sizeof(struct lora_tensor_pair));
+    if (!pair) return NULL;
+    pair->pairs    = NULL;
+    pair->count    = 0;
+    pair->capacity = 0;
+
+    struct ggml_object * obj = ctx->objects_begin;
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+            struct ggml_tensor * tensor = (struct ggml_tensor *)(mem_buffer + obj->offs);
+            char * tensor_name = tensor->name;
+
+            if (strlen(tensor_name) > 6 && (strcmp(tensor_name + strlen(tensor_name) - 6, ".loraA") == 0 ||
+                                            strcmp(tensor_name + strlen(tensor_name) - 6, ".loraB") == 0)) {
+                if (pair->count == pair->capacity) {
+                    pair->capacity = pair->capacity > 0 ? pair->capacity * 2 : 4;
+                    pair->pairs = realloc(pair->pairs, pair->capacity * sizeof(struct lora_tensor_info));
+                }
+
+                pair->pairs[pair->count].name   = strdup(tensor_name);
+                pair->pairs[pair->count].tensor = tensor;
+                pair->count++;
+            }
+        }
+        obj = obj->next;
+    }
+
+    return pair;
+}
+
+void free_lora_tensor_pair(struct lora_tensor_pair* pair) {
+    if (!pair) return;
+    for (int i = 0; i < pair->count; i++) {
+        free(pair->pairs[i].name);
+    }
+    free(pair->pairs);
+    free(pair);
+}
+
+//////// LORA
+
 ////////////////////////////////////////////////////////////////////////////////
 
 // ggml_dup
@@ -5285,6 +5331,7 @@ struct ggml_tensor * ggml_group_norm_inplace(
     return ggml_group_norm_impl(ctx, a, n_groups, true);
 }
 
+
 // ggml_mul_mat
 
 struct ggml_tensor * ggml_mul_mat(
diff --git a/ggml.h b/ggml.h
index 13502a3622fc4..d843699084840 100644
--- a/ggml.h
+++ b/ggml.h
@@ -835,6 +835,25 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
 
+    struct lora_tensor_info {
+        char* name;
+        struct ggml_tensor* tensor;
+    };
+
+    struct lora_tensor_pair {
+        struct lora_tensor_info* pairs; // Dynamic array of tensor pairs
+        int count;
+        int capacity;
+    };
+
+    // Function to build tensor pairs
+    struct lora_tensor_pair* build_lora_weights_map(struct ggml_context* ctx);
+
+    // Cleanup function for lora_tensor_pair
+    void free_lora_tensor_pair(struct lora_tensor_pair* pair);
+
+
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
     GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
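
// Illustrative sketch, not part of the patch: intended use of the two ggml helpers added
// above. `lora_ctx` stands for a ggml_context that already holds the adapter tensors
// (as created by load_lora() further down in this patch).

static void print_lora_tensors(struct ggml_context * lora_ctx) {
    struct lora_tensor_pair * pair = build_lora_weights_map(lora_ctx);
    if (pair == NULL) {
        return;
    }
    for (int i = 0; i < pair->count; i++) {
        // every entry is an adapter tensor whose name ends in ".loraA" or ".loraB"
        printf("%s\n", pair->pairs[i].name);
    }
    free_lora_tensor_pair(pair); // frees the copied names and the array, not the tensors
}
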
diff --git a/llama.cpp b/llama.cpp
index 8b675ea993a38..986dae59cc07e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -119,6 +119,256 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
 // helpers
 //
 
+///////// LORA
+
+
+struct lora_weights {
+    ggml_tensor* loraA;
+    ggml_tensor* loraB;
+};
+
+struct export_lora_params {
+    std::string fn_model_base;
+    std::string fn_model_out;
+    std::vector<struct lora_info> lora;
+    int n_threads;
+};
+
+static struct export_lora_params get_default_export_lora_params() {
+    struct export_lora_params result;
+    result.fn_model_base = "";
+    result.fn_model_out  = "";
+    result.n_threads = GGML_DEFAULT_N_THREADS;
+    return result;
+}
+
+struct lora_info {
+    std::string filename;
+    float scale;
+};
+// TODO lora_data should maybe sub lora_weights
+struct lora_data {
+    struct lora_info info;
+    std::vector<uint8_t> data;
+    struct ggml_context * ctx;
+    // the backend to perform the computation (CPU, CUDA, METAL)
+    ggml_backend_t backend = NULL;
+
+    // the backend buffer that stores the tensor data of a and b
+    ggml_backend_buffer_t buffer;
+
+    uint32_t lora_r;
+    uint32_t lora_alpha;
+};
+
+struct llama_file_lora {
+    // use FILE * so we don't have to re-open the file to mmap
+    FILE * fp;
+    size_t size;
+
+    llama_file_lora(const char * fname, const char * mode) {
+        fp = std::fopen(fname, mode);
+        if (fp == NULL) {
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
+        }
+    }
+
+    size_t tell() const {
+#ifdef _WIN32
+        __int64 ret = _ftelli64(fp);
+#else
+        long ret = std::ftell(fp);
+#endif
+        GGML_ASSERT(ret != -1); // this really shouldn't fail
+        return (size_t) ret;
+    }
+
+    void seek(size_t offset, int whence) {
+#ifdef _WIN32
+        int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+        int ret = std::fseek(fp, (long) offset, whence);
+#endif
+        GGML_ASSERT(ret == 0); // same
+    }
+
+    void read_raw(void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        std::size_t ret = std::fread(ptr, size, 1, fp);
+        if (ferror(fp)) {
+            die_fmt("read error: %s", strerror(errno));
+        }
+        if (ret != 1) {
+            die("unexpectedly reached end of file");
+        }
+    }
+
+    std::uint32_t read_u32() {
+        std::uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
+    std::string read_string(std::uint32_t len) {
+        std::vector<char> chars(len);
+        read_raw(chars.data(), len);
+        return std::string(chars.data(), len);
+    }
+
+    void write_raw(const void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        size_t ret = std::fwrite(ptr, size, 1, fp);
+        if (ret != 1) {
+            die_fmt("write error: %s", strerror(errno));
+        }
+    }
+
+    void write_u32(std::uint32_t val) {
+        write_raw(&val, sizeof(val));
+    }
+
+    bool eof() {
+        return tell() >= size;
+    }
+
+    ~llama_file_lora() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+};
+
+static void free_lora(struct lora_data * lora) {
+    if (lora->ctx != NULL) {
+        ggml_free(lora->ctx);
+    }
+    delete lora;
+}
+
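+// load_lora: read a GGLA adapter file (magic LLAMA_FILE_MAGIC_GGLA, version 1) into a
+// lora_data instance. The header carries lora_r and lora_alpha; each tensor record is
+// n_dims, name length, ggml type, ne[n_dims], the name bytes, padding to a 32-byte
+// boundary, then the raw tensor data. Tensors are created in a no_alloc ggml context,
+// copied into a backend buffer, and every ".loraA" tensor is transposed so it can be
+// passed straight to ggml_mul_mat later.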
+static struct lora_data * load_lora(struct lora_info * info) {
+    struct lora_data * result = new struct lora_data;
+    result->info       = *info;
+    result->ctx        = NULL;
+    result->backend    = NULL;
+    result->buffer     = NULL;
+    result->lora_r     = 1;
+    result->lora_alpha = 1;
+
+#ifdef GGML_USE_CUDA
+    fprintf(stderr, "%s: using CUDA backend\n", __func__);
+    result->backend = ggml_backend_cuda_init(0); // init device 0
+    if (!result->backend) {
+        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    fprintf(stderr, "%s: using Metal backend\n", __func__);
+    result->backend = ggml_backend_metal_init();
+    if (!result->backend) {
+        fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+    }
+#endif
+
+    // if there is no GPU backend, fall back to the CPU backend
+    if (!result->backend) {
+        result->backend = ggml_backend_cpu_init();
+    }
+
+    struct llama_file_lora file(info->filename.c_str(), "rb");
+    if (file.fp == NULL) {
+        fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
+            info->filename.c_str());
+        free_lora(result);
+        return NULL;
+    }
+
+    struct ggml_init_params params_ggml;
+    params_ggml.mem_size   = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
+    params_ggml.mem_buffer = NULL;
+    params_ggml.no_alloc   = true;
+    result->ctx = ggml_init(params_ggml);
+
+    uint32_t magic = file.read_u32();
+    if (magic != LLAMA_FILE_MAGIC_GGLA) {
+        die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
+    }
+    uint32_t version = file.read_u32();
+    if (version != 1) {
+        die_fmt("unexpected lora file version '%u' in '%s'", (unsigned) version, info->filename.c_str());
+    }
+    result->lora_r     = file.read_u32();
+    result->lora_alpha = file.read_u32();
+    // read tensor infos from file
+    std::vector<char> name_buf;
+    std::vector<struct ggml_tensor *> tensors;
+    std::vector<size_t> tensors_offset;
+    size_t total_nbytes_pad = 0;
+    while (!file.eof()) {
+        int64_t ne[4] = {1,1,1,1};
+        uint32_t n_dims  = file.read_u32();
+        uint32_t namelen = file.read_u32();
+        uint32_t type    = file.read_u32();
+        for (uint32_t k = 0; k < n_dims; ++k) {
+            ne[k] = (int64_t)file.read_u32();
+        }
+        name_buf.clear();
+        name_buf.resize(namelen + 1, '\0');
+        file.read_raw(name_buf.data(), namelen);
+        file.seek((0-file.tell()) & 31, SEEK_CUR);
+        size_t offset = file.tell();
+        struct ggml_tensor * tensor = ggml_new_tensor(result->ctx, (enum ggml_type) type, n_dims, ne);
+        ggml_set_name(tensor, name_buf.data());
+        size_t nbytes     = ggml_nbytes(tensor);
+        size_t nbytes_pad = ggml_nbytes_pad(tensor);
+        total_nbytes_pad += nbytes_pad;
+        tensors.push_back(tensor);
+        tensors_offset.push_back(offset);
+        file.seek(nbytes, SEEK_CUR);
+    }
+
+    result->buffer = ggml_backend_alloc_ctx_tensors(result->ctx, result->backend);
+    if (!result->buffer) {
+        LLAMA_LOG_ERROR("%s: failed to allocate buffer for lora tensors\n", __func__);
+    }
+    // read tensor data
+    result->data.resize(total_nbytes_pad);
+    for (size_t i = 0; i < tensors.size(); ++i) {
+        struct ggml_tensor * tensor = tensors[i];
+        size_t offset = tensors_offset[i];
+        size_t nbytes = ggml_nbytes(tensor);
+        file.seek(offset, SEEK_SET);
+        std::vector<uint8_t> read_buf;
+        read_buf.resize(nbytes);
+        file.read_raw(read_buf.data(), nbytes);
+        ggml_backend_tensor_set(tensor, read_buf.data(), 0, nbytes);
+        // Transpose lora matrix A
+        std::string original_name(tensor->name);
+        if (std::string(tensor->name).find(".loraA") != std::string::npos) {
+            tensor = ggml_cont(result->ctx,
+                ggml_transpose(result->ctx, tensor)
+            );
+            ggml_set_name(tensor, original_name.c_str());
+        }
+    }
+    return result;
+}
+
+///////// LORA
+
 static size_t utf8_len(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
     uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@@ -2295,6 +2545,9 @@ struct llama_context {
     }
 
     llama_cparams cparams;
+    std::map<std::string, lora_weights> lora_weights_map; // only one LoRA adapter at the moment
+    lora_data llama_lora_data;
+    float lora_scale = 1.0f;
 
     std::vector<ggml_backend_t> backends;
 #ifdef GGML_USE_METAL
@@ -2370,6 +2623,37 @@ struct llama_context {
     struct llama_control_vector cvec;
 };
 
+
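+// ggml_mul_mat_lora: drop-in replacement for ggml_mul_mat that patches the result with
+// the loaded LoRA adapter. For a weight W with adapter matrices A ("<name>.loraA",
+// already transposed at load time) and B ("<name>.loraB"):
+//
+//     out = W*cur + lora_scale * (B * (A * cur))
+//
+// If there is no adapter entry for W->name, the result is exactly ggml_mul_mat(W, cur).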
+static ggml_tensor * ggml_mul_mat_lora(
+        llama_context * lctx,
+        ggml_context  * ctx0,
+        ggml_tensor   * weight,
+        ggml_tensor   * cur) {
+    ggml_tensor * mm = ggml_mul_mat(ctx0, weight, cur);
+
+    if (lctx == nullptr) {
+        // callers that rely on the default lctx argument get the plain matmul
+        return mm;
+    }
+
+    auto it = lctx->lora_weights_map.find(weight->name);
+    if (it == lctx->lora_weights_map.end()) {
+        return mm;
+    }
+
+    ggml_tensor * loraA = it->second.loraA;
+    ggml_tensor * loraB = it->second.loraB;
+
+    ggml_tensor * t_lora = ggml_mul_mat(ctx0,
+        loraB,
+        ggml_mul_mat(ctx0, loraA, cur)
+    );
+
+    if (lctx->lora_scale != 1.0f) {
+        t_lora = ggml_scale(ctx0, t_lora, lctx->lora_scale);
+    }
+
+    ggml_tensor * t_patch = ggml_add(ctx0, mm, t_lora);
+    return t_patch;
+}
+
 static size_t llama_get_device_count(const llama_model & model) {
     size_t count = 1;
 #if defined(GGML_USE_CUDA)
@@ -3712,7 +3996,7 @@ struct llama_model_loader {
         std::vector<no_init<uint8_t>> read_buf;
         std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
 
-
+        // Allocate tensor data to the buffer
         for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
             const auto * weight = get_weight(ggml_get_name(cur));
             if (weight == nullptr) {
@@ -6770,8 +7054,9 @@ static struct ggml_tensor * llm_build_ffn(
          llm_ffn_op_type   type_op,
          llm_ffn_gate_type type_gate,
          const llm_build_cb & cb,
-         int   il) {
-    struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur;
+         int   il,
+         struct llama_context * lctx = nullptr) {
+    struct ggml_tensor * tmp = up ? ggml_mul_mat_lora(lctx, ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -6783,12 +7068,12 @@ static struct ggml_tensor * llm_build_ffn(
     switch (type_gate) {
         case LLM_FFN_SEQ:
             {
-                cur = ggml_mul_mat(ctx, gate, tmp);
+                cur = ggml_mul_mat_lora(lctx, ctx, gate, tmp);
                 cb(cur, "ffn_gate", il);
             } break;
         case LLM_FFN_PAR:
             {
-                cur = ggml_mul_mat(ctx, gate, cur);
+                cur = ggml_mul_mat_lora(lctx, ctx, gate, cur);
                 cb(cur, "ffn_gate", il);
             } break;
     }
@@ -6836,7 +7121,7 @@ static struct ggml_tensor * llm_build_ffn(
         cb(cur, "ffn_gate_par", il);
     }
 
-    cur = ggml_mul_mat(ctx, down, cur);
+    cur = ggml_mul_mat_lora(lctx, ctx, down, cur);
     if (down_b) {
         cb(cur, "ffn_down", il);
     }
@@ -7447,21 +7732,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = ggml_mul_mat_lora(&lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = ggml_mul_mat_lora(&lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = ggml_mul_mat_lora(&lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -7510,7 +7795,8 @@ struct llm_build_context {
                         model.layers[il].ffn_gate, model.layers[il].ffn_gate_b,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
                         NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il,
+                        &lctx);
                 cb(cur, "ffn_out", il);
             } else {
                 // MoE branch
@@ -15870,6 +16156,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_seq_max          =*/ 1,
         /*.n_threads          =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch    =*/ GGML_DEFAULT_N_THREADS,
+        /*.hot_lora           =*/ "",
         /*.rope_scaling_type  =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type       =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base     =*/ 0.0f,
@@ -16025,6 +16312,29 @@ void llama_free_model(struct llama_model * model) {
     delete model;
 }
 
+
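+// get_lora_weights_map: group the adapter tensors by base weight name, so that
+// "<weight>.loraA" and "<weight>.loraB" end up in a single lora_weights entry keyed by
+// "<weight>" -- the key that ggml_mul_mat_lora() looks up via weight->name at graph
+// build time.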
+static std::map<std::string, lora_weights> get_lora_weights_map(struct ggml_context * ctx) {
+    struct lora_tensor_pair * pair = build_lora_weights_map(ctx);
+    std::map<std::string, lora_weights> map;
+
+    if (pair) {
+        for (int i = 0; i < pair->count; i++) {
+            std::string name(pair->pairs[i].name);
+            std::string base_name = name.substr(0, name.size() - 6);
+            std::string suffix    = name.substr(name.size() - 6);
+
+            if (suffix == ".loraA") {
+                map[base_name].loraA = pair->pairs[i].tensor;
+            } else if (suffix == ".loraB") {
+                map[base_name].loraB = pair->pairs[i].tensor;
+            }
+        }
+        free_lora_tensor_pair(pair);
+    }
+
+    return map;
+}
+
 struct llama_context * llama_new_context_with_model(
         struct llama_model * model,
         struct llama_context_params params) {
@@ -16056,6 +16366,31 @@ struct llama_context * llama_new_context_with_model(
 
     llama_context * ctx = new llama_context(*model);
 
+    /// LORA load start
+    struct export_lora_params * lora_params = new struct export_lora_params;
+    struct lora_info lora;
+    lora.filename = params.hot_lora;
+    if (strlen(params.hot_lora) > 0) {
+
+        lora.scale = 1.0f; // redundant as already inside lora_context, but should be here for multiple loras?
+        lora_params->lora.push_back(lora);
+        // load all loras (only 1 supported here)
+        std::vector<struct lora_data *> loras;
+        for (size_t i = 0; i < lora_params->lora.size(); ++i) {
+            struct lora_data * llama_lora_data = load_lora(&lora_params->lora[i]);
+            if (llama_lora_data != NULL) {
+                loras.push_back(llama_lora_data);
+            }
+        }
+        if (loras.size() == 0) {
+            fprintf(stderr, "warning: no lora adapters will be applied.\n");
+        } else {
+            // assign data and build the mapping (index 0, as only one lora is supported for now)
+            ctx->llama_lora_data  = *loras[0];
+            ctx->lora_weights_map = get_lora_weights_map((ctx->llama_lora_data).ctx);
+        }
+    }
+    /// LORA load end
+
     const auto & hparams = model->hparams;
     auto & cparams = ctx->cparams;
diff --git a/llama.h b/llama.h
index 62908261f2791..d593eb45c9dab 100644
--- a/llama.h
+++ b/llama.h
@@ -45,6 +45,9 @@
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1
 
+#define die(msg)          do { fputs("error: " msg "\n", stderr);                exit(1); } while (0)
+#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -289,6 +292,7 @@ extern "C" {
         uint32_t n_seq_max;       // max number of sequences (i.e. distinct states for recurrent models)
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
+        char hot_lora[256];       // path to the hot lora file
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id