From 872876592c45ab3fb5a933a4f51ad114805bed78 Mon Sep 17 00:00:00 2001 From: Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> Date: Fri, 15 Mar 2024 09:31:11 +0100 Subject: [PATCH] [Neural Speed] Enable StableLM2-1.6B & StableLM2-Zephyr-1.6B & StableLM-3B (#156) Co-authored-by: intellinjun --- docs/supported_models.md | 12 + neural_speed/__init__.py | 2 + neural_speed/application/CMakeLists.txt | 7 +- neural_speed/application/main_pybind.cpp | 6 +- neural_speed/application/whisper_pybind.cpp | 2 +- neural_speed/convert/__init__.py | 8 +- neural_speed/convert/convert_stablelm.py | 325 +++++++++++++ neural_speed/models/CMakeLists.txt | 1 + neural_speed/models/model_utils/model_types.h | 5 +- neural_speed/models/stablelm/stablelm.cpp | 428 ++++++++++++++++++ neural_speed/models/stablelm/stablelm.h | 53 +++ .../models/stablelm/stablelm_utils.cpp | 212 +++++++++ scripts/convert.py | 29 +- scripts/run.py | 20 +- tests/model-test/cpp_graph_inference.sh | 6 +- 15 files changed, 1095 insertions(+), 21 deletions(-) create mode 100644 neural_speed/convert/convert_stablelm.py create mode 100644 neural_speed/models/stablelm/stablelm.cpp create mode 100644 neural_speed/models/stablelm/stablelm.h create mode 100644 neural_speed/models/stablelm/stablelm_utils.cpp diff --git a/docs/supported_models.md b/docs/supported_models.md index 7db9c6b77..683cf8174 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -259,6 +259,18 @@ Neural Speed supports the following models: Latest 2048 + + + StableLM-3B, + StableLM2-1_6B + StableLM2-Zephyr-1_6B + ✅ + + + ✅ + + + Latest Whisper-tiny, diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py index 8116b2086..322a69dcd 100644 --- a/neural_speed/__init__.py +++ b/neural_speed/__init__.py @@ -69,6 +69,8 @@ def __import_package(self, model_type): import neural_speed.qwen_cpp as cpp_model elif model_type == "phi": import neural_speed.phi_cpp as cpp_model + elif model_type == "stablelm": + import neural_speed.stablelm_cpp as cpp_model elif model_type == "whisper": import neural_speed.whisper_cpp as cpp_model elif model_type == "mixtral": diff --git a/neural_speed/application/CMakeLists.txt b/neural_speed/application/CMakeLists.txt index 46a3c44cb..cde77862e 100644 --- a/neural_speed/application/CMakeLists.txt +++ b/neural_speed/application/CMakeLists.txt @@ -70,6 +70,7 @@ compile_quant(quant_mistral quant_model.cpp mistral llama) compile_quant(quant_mixtral quant_model.cpp mixtral llama) compile_quant(quant_qwen quant_model.cpp qwen qwen) compile_quant(quant_phi quant_model.cpp phi phi) +compile_quant(quant_stablelm quant_model.cpp stablelm stablelm) compile_quant(quant_whisper quant_whisper.cpp whisper whisper) # all models running @@ -93,8 +94,9 @@ set(mymap_polyglot 13) set(mymap_mistral 14) set(mymap_qwen 15) set(mymap_phi 16) -set(mymap_whisper 17) -set(mymap_mixtral 18) +set(mymap_stablelm 17) +set(mymap_whisper 18) +set(mymap_mixtral 19) @@ -131,6 +133,7 @@ compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan) compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama) compile_run(run_qwen main_run.cpp main_pybind.cpp qwen qwen) compile_run(run_phi main_run.cpp main_pybind.cpp phi phi) +compile_run(run_stablelm main_run.cpp main_pybind.cpp stablelm stablelm) compile_run(run_mixtral main_run.cpp main_pybind.cpp mixtral llama) # speech recognition diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp index 9149cd439..532df4b04 100644 --- 
a/neural_speed/application/main_pybind.cpp +++ b/neural_speed/application/main_pybind.cpp @@ -911,10 +911,14 @@ PYBIND11_MODULE(phi_cpp, m) #elif MODEL_NAME_ID == 17 -PYBIND11_MODULE(whisper_cpp, m) +PYBIND11_MODULE(stablelm_cpp, m) #elif MODEL_NAME_ID == 18 +PYBIND11_MODULE(whisper_cpp, m) + +#elif MODEL_NAME_ID == 19 + PYBIND11_MODULE(mixtral_cpp, m) #endif diff --git a/neural_speed/application/whisper_pybind.cpp b/neural_speed/application/whisper_pybind.cpp index 6e978ef16..c3e5b31f7 100644 --- a/neural_speed/application/whisper_pybind.cpp +++ b/neural_speed/application/whisper_pybind.cpp @@ -454,7 +454,7 @@ void Model::inference(const std::string& fname_inp) { return; } -#if MODEL_NAME_ID == 17 +#if MODEL_NAME_ID == 18 PYBIND11_MODULE(whisper_cpp, m) #endif diff --git a/neural_speed/convert/__init__.py b/neural_speed/convert/__init__.py index 3cc4f2301..40287b0f6 100644 --- a/neural_speed/convert/__init__.py +++ b/neural_speed/convert/__init__.py @@ -29,14 +29,12 @@ } -def convert_model(model, outfile, outtype="f32", model_hub="huggingface", use_quantized_model=False): +def convert_model(model, outfile, outtype="f32", format="NE", model_hub="huggingface", use_quantized_model=False): if model_hub == "modelscope": from modelscope import AutoConfig - config = AutoConfig.from_pretrained(model, trust_remote_code=True) else: from transformers import AutoConfig - config = AutoConfig.from_pretrained(model, trust_remote_code=True) - + config = AutoConfig.from_pretrained(model, trust_remote_code=True) model_type = model_maps.get(config.model_type, config.model_type) if use_quantized_model: @@ -47,6 +45,8 @@ def convert_model(model, outfile, outtype="f32", model_hub="huggingface", use_qu cmd.extend(["python", path]) cmd.extend(["--outfile", outfile]) cmd.extend(["--outtype", outtype]) + if model_type in {"phi", "stablelm"}: + cmd.extend(["--format", format]) cmd.extend(["--model_hub", model_hub]) cmd.extend([model]) diff --git a/neural_speed/convert/convert_stablelm.py b/neural_speed/convert/convert_stablelm.py new file mode 100644 index 000000000..af00ba4fc --- /dev/null +++ b/neural_speed/convert/convert_stablelm.py @@ -0,0 +1,325 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Convert Hugging Face fine-tuned gpt-neox-like models to ne format +# +# Usage: +# +# python3 models/convert-h5-to-ne.py +# +# This script is similar to "convert-pt-to-ne.py" +# +import os +import struct +import numpy as np +from pathlib import Path +import argparse +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, + Union) +from transformers import AutoModelForCausalLM, AutoTokenizer +import gguf + +# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. 
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + +def stablelm_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams): + print("stablelm.gguf converting: ") + list_vars = model.state_dict() + n_rot = int(hparams["partial_rotary_factor"] * hparams["hidden_size"] / hparams["num_attention_heads"]) + for name in list_vars.keys(): + print(name, list_vars[name].shape, list_vars[name].dtype) + + print(hparams) + + gguf_file = fname_out + '.gguf' + gguf_writer = gguf.GGUFWriter(gguf_file, "stablelm") + + gguf_writer.add_uint32('magic', 0x67676d66) + gguf_writer.add_uint32('version', 1) + gguf_writer.add_uint32('n_vocab', hparams["vocab_size"]) + gguf_writer.add_embedding_length(hparams["hidden_size"]) + gguf_writer.add_head_count(hparams["num_attention_heads"]) + gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) + + gguf_writer.add_block_count(hparams["num_hidden_layers"]) + gguf_writer.add_rope_dimension_count(n_rot) + gguf_writer.add_uint32('ftype', ftype) + gguf_writer.add_context_length(hparams["max_position_embeddings"]) + gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + + gguf_writer.add_bos_token_id(hparams["bos_token_id"]) + gguf_writer.add_eos_token_id(hparams["eos_token_id"]) + gguf_writer.add_pad_token_id(hparams["pad_token_id"] if hparams["pad_token_id"] else 0) + gguf_writer.add_sep_token_id(hparams["sep_token_id"] if hparams["sep_token_id"] else 0) + + def write_vocab_gguf(dir_model, hparams, gguf_writer): + tokens: list[bytearray] = [] + toktypes: list[int] = [] + + tokenizer = AutoTokenizer.from_pretrained(dir_model) + vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) + assert max(tokenizer.vocab.values()) < vocab_size + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + for i in range(vocab_size): + if i not in reverse_vocab: + pad_token = f"[PAD{i}]".encode('utf-8') + tokens.append(bytearray(pad_token)) + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + gguf_writer.add_tokenizer_model("gpt2") + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) + special_vocab.add_to_gguf(gguf_writer) + + write_vocab_gguf(dir_model, hparams, gguf_writer) + + # tensor info + print("gguf: get tensor metadata") + for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + + print("Processing variable: " + name + " with shape: ", data.shape) + if 'inv_freq' in name: + continue + + n_dims = len(data.shape) + + # 
ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if ftype != 0: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + gguf_writer.add_tensor(name, data) + + print("gguf: write header") + gguf_writer.write_header_to_file() + print("gguf: write metadata") + gguf_writer.write_kv_data_to_file() + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + + gguf_writer.close() + + print("Done. Output file: " + gguf_file) + print("") + +def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): + n_rot = int(hparams["partial_rotary_factor"] * hparams["hidden_size"] / hparams["num_attention_heads"]) + model.eval() + for p in model.parameters(): + p.requires_grad = False + hparams = model.config.to_dict() + vocab_size = hparams["vocab_size"] + print("Model loaded: ", dir_model) + + fout = open(fname_out, "wb") + + # 0x67676d6c is unversioned ne + # 0x67676d66 is versioned ggmf (requires token scores) + ne_file_magic = 0x67676d66 + #ne_file_version = 0x00000001 # v1 + + fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex + fout.write(struct.pack("i", 1)) + fout.write(struct.pack("i", hparams["vocab_size"])) + fout.write(struct.pack("i", hparams["hidden_size"])) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", hparams["num_attention_heads"])) + fout.write(struct.pack("i", hparams["num_key_value_heads"])) # multi-query attention + fout.write(struct.pack("i", hparams["num_hidden_layers"])) + fout.write(struct.pack("i", n_rot)) + fout.write(struct.pack("i", ftype)) + fout.write(struct.pack("i", hparams["max_position_embeddings"])) + fout.write(struct.pack("f", 0.0)) + fout.write(struct.pack("f", 0.0)) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) + fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) + + fout.write(struct.pack("i", 0)) # multi_query_group_num + fout.write(struct.pack("i", hparams["intermediate_size"])) # ffn_hidden_size + fout.write(struct.pack("i", 0)) # inner_hidden_size for ChatGLM2 + + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps + fout.write(struct.pack("f", hparams["rope_theta"])) # freq_base + fout.write(struct.pack("f", 1.0)) # freq_scale, was removed in config.json (by default=1.0) + fout.write(struct.pack("f", 1.0)) # rope_scaling_factor, was removed in config.json (by default=1.0) + + fout.write(struct.pack("i", 0)) # original_max_position_embeddings + fout.write(struct.pack("i", 0)) # use_yarn + + fout.write(struct.pack("i", hparams["bos_token_id"])) + fout.write(struct.pack("i", hparams["eos_token_id"])) + fout.write(struct.pack("i", hparams["pad_token_id"] if hparams["pad_token_id"] else 0)) + fout.write(struct.pack("i", hparams["sep_token_id"] if hparams["sep_token_id"] else 0)) + + for i in range(vocab_size): + if i < vocab_size: + text = tokenizer.decode([i]).encode('utf-8') + fout.write(struct.pack("i", len(text))) + fout.write(text) + fout.write(struct.pack("f", 0.0 - i)) + else: + text = tokenizer.decode([vocab_size - 1]).encode('utf-8') + fout.write(struct.pack("i", len(text))) + fout.write(text) + 
fout.write(struct.pack("f", -10000)) + + list_vars = model.state_dict() + + print(hparams) + + for name in list_vars.keys(): + # No gradients for these + list_vars[name].requires_grad = False + src = name + print(src, ' -> ', name) + data = list_vars[src].squeeze().numpy() + data = data.astype(np.float32) + + n_dims = len(data.shape) + print(name, n_dims, data.shape) + + # default type is fp32 + ftype_cur = 0 + if ftype == 1 and n_dims > 1: + print(" Converting to float16", data.shape, data[:3, :3].tolist()) + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist()) + data = data.astype(np.float32) + + # header + str = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + print(str) + fout.write(str) + + # data + data.tofile(fout) + + fout.close() + + print("Done. Output file: " + fname_out) + print("") + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="Convert a model to an NE or GGUF compatible file") + parser.add_argument( + "--outtype", + choices=["f32", "f16"], + help="output format (default: based on input)" + ) + parser.add_argument( + "--outfile", + type=Path, + help="path to write to; default: based on input" + ) + parser.add_argument( + "--model_hub", + choices=["huggingface","modelscope"], + default="huggingface", + help="hub to load model" + ) + parser.add_argument( + "--format", + type=str, + default="NE", + choices=["NE", "GGUF"], + help="convert to the GGUF or NE format" + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file" + ) + + args = parser.parse_args(args_in) + + dir_model = args.model.as_posix() + fname_out = args.outfile.as_posix() + + # possible data types + # ftype == 0 -> float32 + # ftype == 1 -> float16 + ftype = 0 + if args.outtype == "f16": + ftype = 1 + if args.model_hub == "modelscope": + from modelscope import AutoModelForCausalLM, AutoTokenizer + else: + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + print("Loading model: ", dir_model) + model = AutoModelForCausalLM.from_pretrained(dir_model, trust_remote_code=True) + hparams = model.config.to_dict() + if args.format == "GGUF": + stablelm_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams) + else: + stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) + + + +if __name__ == '__main__': + main() diff --git a/neural_speed/models/CMakeLists.txt b/neural_speed/models/CMakeLists.txt index f62b64e41..a0bcf1f1a 100644 --- a/neural_speed/models/CMakeLists.txt +++ b/neural_speed/models/CMakeLists.txt @@ -35,3 +35,4 @@ add_model(whisper whisper/whisper.cpp whisper/whisper_utils.cpp ${MODEL_UTILS_SO add_model(chatglm chatglm/chatglm.cpp chatglm/chatglm_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(chatglm2 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(phi phi/phi.cpp phi/phi_utils.cpp ${MODEL_UTILS_SOURCE}) +add_model(stablelm stablelm/stablelm.cpp stablelm/stablelm_utils.cpp ${MODEL_UTILS_SOURCE}) diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index 96eacb2e0..619130004 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -83,6 
+83,7 @@ enum model_archs { MODEL_CHATGLM, MODEL_QWEN, MODEL_PHI, + MODEL_STABLELM, MODEL_WHISPER }; @@ -483,8 +484,8 @@ class model_name_to_arch { {"dolly", MODEL_GPTNEOX}, {"polyglot", MODEL_GPTNEOX}, {"starcoder", MODEL_STARCODER}, {"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2}, {"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA}, - {"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"whisper", MODEL_WHISPER}, - {"mixtral", MODEL_LLAMA}}; + {"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"stablelm", MODEL_STABLELM}, + {"whisper", MODEL_WHISPER}, {"mixtral", MODEL_LLAMA}}; }; #ifdef __cplusplus diff --git a/neural_speed/models/stablelm/stablelm.cpp b/neural_speed/models/stablelm/stablelm.cpp new file mode 100644 index 000000000..4b0dc9935 --- /dev/null +++ b/neural_speed/models/stablelm/stablelm.cpp @@ -0,0 +1,428 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/data_types.h" +#include "core/ne.h" +#include "core/ne_layers.h" +#include "core/ne_bestla.h" +#include "core/layers/mha_dense.h" +#include "models/model_utils/model_config.h" +#include "models/model_utils/model_utils.h" +#include "models/model_utils/util.h" + +// evaluate the transformer +// +// - lctx: model context +// - tokens: new batch of tokens to process +// - n_past: the offset to which the kv is cached to +// - n_total: the number of tokens evaluated so far (including evicted tokens if there is any) +// - n_threads: number of threads to use +// + +static bool stablelm_model_eval_internal(model_context* ctx, const model_input* inputs, const int n_input, + const int n_threads) { + const int64_t t_start_us = ne_time_us(); + model_context& lctx = *ctx; + + // static batching for now + const int N = inputs->n_tokens; + const int n_past = inputs->n_past; + const int n_total = inputs->n_total; + const bool shift_roped_k = lctx.shift_roped_k; + const bool is_ring_full = shift_roped_k && n_total > n_past; + NE_ASSERT(("Shift-RoPE-K to be implemented for the neox-mode RoPE!", !is_ring_full)); + const int batch_size = lctx.batch_size; + MODEL_ASSERT(batch_size == n_input); + const int kv_n_ctx_block = lctx.kv_n_ctx_block; + + const auto& model = lctx.model; + const auto& hparams = model.hparams; + + const auto& kv_self = model.kv_self; + + MODEL_ASSERT(!!kv_self.ctx); + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = lctx.n_ctx; + const int n_keep = lctx.n_keep; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_rot; + const int head_dim = n_embd / n_head; + + auto& mem_per_token = lctx.mem_per_token; + auto& buf_compute = lctx.buf_compute; + + struct ne_init_params params = { + /*.mem_size =*/buf_compute.size, + /*.mem_buffer =*/buf_compute.addr, + /*.no_alloc 
=*/false, + }; + + struct ne_context* ctx0 = ne_init(params); + + // for big progptneoxs, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + ne_cgraph gf = {}; + gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads; + + const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA; + kv_cache_info_t kv_cache_info = {}; + if (run_mha_reordered) { + NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_BTLA)); + attn_shape_t attn_shape = { + /* .batch_size = */ 1, + /* .head_num = */ n_head, + /* .heads_kv = */ n_head, + /* .head_size = */ head_dim, + /* .sl_q = */ N, // Note: make sure that bestla reordered attn supports next token inference + /* .sl_kv = */ n_past + N, + }; + + NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead", + bestla_reordered_attn_fp32_support(&attn_shape))); + kv_shape_t kv_shape{ + /* .heads_kv = */ static_cast(n_head), + /* .head_size = */ static_cast(head_dim), + /* .sl_kv_max = */ static_cast(n_ctx), + }; + bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); + } + struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size); + ne_set_name(embd, "embd"); + for (int i = 0; i < batch_size; ++i) { + memcpy(static_cast(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd)); + } + + struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd); + + for (int il = 0; il < n_layer; ++il) { + struct ne_tensor* cur; + + lctx.use_buf(ctx0, 0); + + { + // layer_norm + { + cur = ne_norm(ctx0, inpL, hparams.norm_eps); + + // cur = cur*attention_norm(broadcasted) + cur = ne_mul(ctx0, cur, model.layers[il].norm[0]); + cur = ne_add(ctx0, cur, model.layers[il].norm[1]); + } + + // Compute QKV + struct ne_tensor* Qcur; + struct ne_tensor* Kcur; + struct ne_tensor* Vcur; + if (n_layer == 24) { // Stablelm2 1.6B & Stablelm2 Zephyr 1.6B + Qcur = + ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), model.layers[il].attn[1]), + head_dim, n_head, N, 1); + Kcur = + ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), model.layers[il].attn[3]), + head_dim, n_head, N, 1); + Vcur = + ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[4], cur), model.layers[il].attn[5]), + head_dim, n_head, N, 1); + } else { // Stablelm 3B + Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_dim, n_head, N, 1); + Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_dim, n_head, N, 1); + Vcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_dim, n_head, N, 1); + } + + // using mode = 2 for GPT-NeoX mode + struct ne_tensor* Qcur_Part = ne_view_4d(ctx0, ne_permute(ctx0, Qcur, 0, 2, 1, 3), n_rot, n_head, N, 1, + Qcur->nb[1], Qcur->nb[2], Qcur->nb[3], 0); + Qcur_Part = ne_rope_inplace(ctx0, Qcur_Part, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); + ne_build_forward_expand(&gf, Qcur_Part); + ne_set_name(Qcur, "Qcur"); + + struct ne_tensor* Kcur_Part = ne_view_4d(ctx0, ne_permute(ctx0, Kcur, 0, 2, 1, 3), n_rot, n_head, N, 1, + Kcur->nb[1], Kcur->nb[2], Kcur->nb[3], 0); + Kcur_Part = ne_rope_inplace(ctx0, Kcur_Part, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); + ne_build_forward_expand(&gf, Kcur_Part); + ne_set_name(Kcur, "kcur"); + const float attn_scale = 1.0f / sqrtf(static_cast(head_dim)); 
+ + // store key and value to memory + ne_tensor* Kcur_temp; + if (!run_mha_reordered) { + { + std::vector Kcur_bs(batch_size); + std::vector Vcur_bs(batch_size); + std::vector k_bs(batch_size); + std::vector v_bs(batch_size); + for (int i = 0; i < batch_size; ++i) { + // batch K + Kcur_bs[i] = ne_permute(ctx0, + ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim, + ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N, + i * ne_element_size(Kcur) * n_embd * N), + 0, 2, 1, 3); + Kcur_temp = Kcur_bs[i]; + ne_set_name(Kcur_bs[i], "kcur_bs"); + k_bs[i] = ne_view_4d( + ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim, + ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, + ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block + + i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k))); + + // batch V + Vcur_bs[i] = ne_permute(ctx0, + ne_reshape_4d(ctx0, + ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd, + i * ne_element_size(Vcur) * n_embd * N), + head_dim, n_head, N, 1), + 1, 2, 0, 3); + v_bs[i] = + ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v), + n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd, + ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block + + i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v))); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i])); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i])); + } + } + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3); + ne_set_name(Q, "Q"); + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ne_tensor* K = + ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim, + ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, + il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block); + ne_set_name(K, "K"); + // K * Q + struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ne_tensor* KQ_scaled = + ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(static_cast((n_embd) / n_head)))); + + // KQ_masked = mask_past(KQ_scaled) + struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ne_tensor* V = + ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v), + n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd, + il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block); + + // KQV = transpose(V) * KQ_soft_max + struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC)); + } else { + const auto seq_kv = 
n_past + N; + const auto k_size = kv_cache_info.k_bytes; + const auto v_size = kv_cache_info.v_bytes; + + // store key and value to memory + { + const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor + head_dim, n_ctx, n_head, // ne + 0, 0, // nb (bestla managed) + il * k_size); // offset + ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false)); + const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor + head_dim, n_ctx, n_head, // ne + 0, 0, // nb (bestla managed) + il * v_size); // offset + ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false)); + } + + struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); + ne_set_name(Q, "Q"); + + struct ne_tensor* K = + ne_view_3d(ctx0, kv_self.k, // tensor + head_dim, seq_kv, n_head, // ne + kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (bestla managed) + il * k_size); // offset + *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // us nb0 for layout + ne_set_name(K, "K"); + struct ne_tensor* V = + ne_view_3d(ctx0, kv_self.v, // tensor + seq_kv, head_dim, n_head, // ne + kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (bestla managed) + il * v_size); // offset + *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout + ne_set_name(V, "V"); + + ne_attn_flags_t attn_flags = 0; + if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases + struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); + cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0); + } + + // projection + { + if (n_layer == 24) { // Stablelm2 1.6B & Stablelm2 Zephyr 1.6B + cur = ne_mul_mat(ctx0, model.layers[il].attn[6], cur); + } else { // Stablelm 3B + cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); + } + } + } + lctx.use_buf(ctx0, 1); + + cur = ne_add(ctx0, cur, inpL); + inpL = cur; + + // FFN Block + { + // Post Attention norm + { + cur = ne_norm(ctx0, cur, hparams.norm_eps); + cur = ne_mul(ctx0, cur, model.layers[il].norm[2]); + cur = ne_add(ctx0, cur, model.layers[il].norm[3]); + } + + if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data, + model.layers[il].ffn[2]->data, N, cur->ne[0], + model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) { + cur = ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], model.layers[il].ffn[2], cur); + } else { + struct ne_tensor* tmp = ne_mul_mat(ctx0, model.layers[il].ffn[2], cur); + cur = ne_mul_mat(ctx0, model.layers[il].ffn[0], cur); + cur = ne_silu(ctx0, cur); + cur = ne_mul(ctx0, cur, tmp); + cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur); + } + } + + // input for next layer + inpL = ne_add(ctx0, cur, inpL); + ne_set_name(inpL, "inpL"); + } + + lctx.use_buf(ctx0, 0); + // used at the end to optionally extract the embeddings + struct ne_tensor* embeddings = nullptr; + // norm + { + inpL = ne_norm(ctx0, inpL, hparams.norm_eps); + inpL = ne_add(ctx0, ne_mul(ctx0, inpL, model.others[1]), model.others[2]); + } + + // lm_head + inpL = ne_mul_mat(ctx0, model.others[3], inpL); + + lctx.use_buf(ctx0, -1); + + // logits -> probs + // inpL = ne_soft_max_inplace(ctx0, inpL); + + // run the computation + ne_build_forward_expand(&gf, inpL); + ne_graph_compute(ctx0, &gf); + + if (ns_log_level() == 0 || ns_log_level() == 2) { + ne_graph_profiling(&gf); + } + + // update kv token count + lctx.model.kv_self.n = n_past + N; + + // extract 
logits + { + auto& logits_out = lctx.logits; + + size_t bs_stride = n_vocab * N; + if (lctx.logits_all) { + logits_out.resize(n_vocab * N * batch_size); + for (int i = 0; i < batch_size; ++i) { + memcpy(logits_out.data() + i * bs_stride, reinterpret_cast(ne_get_data(inpL)) + (i * bs_stride), + sizeof(float) * n_vocab * N); + } + } else { + // return result for just the last token + logits_out.resize(n_vocab * batch_size); + for (int i = 0; i < batch_size; ++i) { + memcpy(logits_out.data() + (i * n_vocab), + reinterpret_cast(ne_get_data(inpL)) + (i * bs_stride) + (n_vocab * (N - 1)), + sizeof(float) * n_vocab); + } + } + } + + // extract embeddings + if (!lctx.embedding.empty()) { + auto& embedding_out = lctx.embedding; + + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), reinterpret_cast(ne_get_data(embeddings)) + (n_embd * (N - 1)), + sizeof(float) * n_embd); + } + + if (mem_per_token == 0) { + mem_per_token = ne_used_mem(ctx0) / N; + } + + ne_free(ctx0); + + // measure the performance only for the single-token evals + int64_t time_interval = ne_time_us() - t_start_us; + if (N == 1) { + lctx.t_eval_us += time_interval; + lctx.n_eval++; + } else if (N > 1) { + lctx.t_p_eval_us += time_interval; + lctx.n_p_eval += N; + } + lctx.eval_times.push_back(time_interval); + + return true; +} + +int model_eval(struct model_context* ctx, const model_input* inputs, const int n_input, int n_threads) { + if (!stablelm_model_eval_internal(ctx, inputs, n_input, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + // get a more accurate load time, upon first eval + + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ne_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + return 0; +} \ No newline at end of file diff --git a/neural_speed/models/stablelm/stablelm.h b/neural_speed/models/stablelm/stablelm.h new file mode 100644 index 000000000..3df5b75cb --- /dev/null +++ b/neural_speed/models/stablelm/stablelm.h @@ -0,0 +1,53 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef STABLELM_H +#define STABLELM_H + +#include "models/model_utils/model_files.h" +#include "models/model_utils/model_types.h" + +enum stablelm_model { + STABLELM_UNKNOWN, + STABLELM_1_6B, + STABLELM_3B, +}; + +static const model_scratch stablelm_mem_req(int n_layers) { + switch (n_layers) { + case 24: + return {512ull * MB, 512ull * MB, 1026ull * MB}; // StableLM2-1.6B & StableLM2-Zephyr-1.6B + case 32: + return {1024ull * MB, 1024ull * MB, 1026ull * MB}; // StableLM-3B + default: + MODEL_ASSERT(false); + } +} + +class stablelm : public IModel { + private: + model_archs name = MODEL_STABLELM; + std::unique_ptr ml; + uint32_t n_layer, n_embd, n_ff, n_vocab; + int n_ctx, n_gpu_layer; + bool use_mmap, use_mlock, vocab_only; + model_scratch scratch; + + public: + void init(const char* path_model, model_context* ctx, int n_gpu_layers, bool use_mmap_, bool use_mlock_, + bool vocab_only_) override; + void load(model_context* ctx, model_progress_callback progress_callback, void* progress_callback_user_data) override; +}; + +#endif // STABLELM_H diff --git a/neural_speed/models/stablelm/stablelm_utils.cpp b/neural_speed/models/stablelm/stablelm_utils.cpp new file mode 100644 index 000000000..e12877995 --- /dev/null +++ b/neural_speed/models/stablelm/stablelm_utils.cpp @@ -0,0 +1,212 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/data_types.h" +#include "core/ne.h" +#include "core/ne_layers.h" +#include "models/stablelm/stablelm.h" +#include "models/model_utils/model_config.h" +#include "models/model_utils/model_files.h" +#include "models/model_utils/model_types.h" +#include "models/model_utils/quant_utils.h" +#include "models/model_utils/util.h" +#include "models/models.h" +void model_load_internal(const std::string& fname, model_archs arch, model_context* ctx, int n_gpu_layers, + bool use_mmap, bool use_mlock, bool vocab_only, model_progress_callback progress_callback, + void* progress_callback_user_data) { + std::unique_ptr ms(new stablelm()); + ms->init(fname.c_str(), ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only); + ms->load(ctx, progress_callback, progress_callback_user_data); + model_context& lctx = *ctx; + lctx.support_bestla_kv = true; +} + +void stablelm::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_, + bool vocab_only_) { + model_context& lctx = *ctx; + n_gpu_layer = n_gpu_layer_; + use_mmap = use_mmap_; + use_mlock = use_mlock_; + vocab_only = vocab_only_; + auto& model = lctx.model; + ml.reset(new model_model_loader(path_model, use_mmap, vocab_only)); + lctx.vocab = std::move(ml->file_loaders.at(0)->vocab); + model.hparams = ml->file_loaders.at(0)->hparams; + model_file_version file_version = ml->file_loaders.at(0)->file_version; + auto& hparams = model.hparams; + n_ff = hparams.ffn_hidden_size; + fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); + fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size()); + fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: max_seq_len = %u\n", __func__, hparams.max_seq_len); + n_embd = hparams.n_embd; + n_vocab = hparams.n_vocab; + n_layer = hparams.n_layer; + n_embd = hparams.n_embd; + scratch = stablelm_mem_req(n_layer); + model.scratchs = scratch; +} + +#define MODEL_BACKEND_OFFLOAD NE_BACKEND_CPU +void stablelm::load(model_context* ctx, model_progress_callback progress_callback, void* progress_callback_user_data) { + model_context& lctx = *ctx; + auto& model = lctx.model; + auto& ne_ctx = model.ctx; + + size_t ctx_size; + size_t mmapped_size; + ml->calc_sizes(&ctx_size, &mmapped_size); + fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + + // create the ne context + lctx.model.buf.resize(ctx_size); + if (use_mlock) { + lctx.model.mlock_buf.init(lctx.model.buf.addr); + lctx.model.mlock_buf.grow_to(lctx.model.buf.size); + } + + struct ne_init_params params = { + /*.mem_size =*/lctx.model.buf.size, + /*.mem_buffer =*/lctx.model.buf.addr, + /*.no_alloc =*/ml->use_mmap, + }; + + model.ctx = ne_init(params); + if (!model.ctx) { + throw format("ne_init() failed"); + } + + ml->ne_ctx = ne_ctx; + + const int i_gpu_start = n_layer - n_gpu_layer; + model.layers.resize(n_layer); + size_t vram_total = 0; + + // Embedding layer + Normalization layer + lm_head layer + model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); + model.others[1] = 
ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU); + model.others[2] = ml->get_tensor("model.norm.bias", {n_embd}, NE_BACKEND_CPU); + model.others[3] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, + n_gpu_layer > static_cast(n_layer) ? MODEL_BACKEND_OFFLOAD : NE_BACKEND_CPU); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD; + auto& layer = model.layers[i]; + std::string layers_i = "model.layers." + std::to_string(i); + + // Norm + layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend); + layer.norm[1] = ml->get_tensor(layers_i + ".input_layernorm.bias", {n_embd}, backend); + + // qkv GEMM + out proj GEMM + if (ml->verify_tensor(layers_i + ".self_attn.q_proj.bias")) { // Stablelm2 1.6B & Stablelm2 Zephyr 1.6B + layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend); + layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend); + layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend); + layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend); + layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend); + layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend); + layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend); + } else { // Stablelm 3B + layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend); + layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend); + layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend); + layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend); + } + + // Post Attention norm + layer.norm[2] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend); + layer.norm[3] = ml->get_tensor(layers_i + ".post_attention_layernorm.bias", {n_embd}, backend); + + // ffn GEMM + layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend); + layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend); + layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend); + + if (backend != NE_BACKEND_CPU) { + if (ml->verify_tensor(layers_i + ".self_attn.q_proj.bias")) { + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.norm[2]) + + ne_nbytes(layer.norm[3]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.attn[4]) + + ne_nbytes(layer.attn[5]) + ne_nbytes(layer.attn[6]) + ne_nbytes(layer.ffn[0]) + + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + } else { + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.norm[2]) + + ne_nbytes(layer.norm[3]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.ffn[0]) + + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + } + } + } + + // print memory requirements + // this is the total memory required to run the inference + const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory + scratch.scratch0 + 
scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); + + (void)n_gpu_layer; + + // populate `tensors_by_name` + for (model_load_tensor& lt : ml->tensors_map.tensors) { + model.tensors_by_name.emplace_back(lt.name, lt.ne_tensor); + } + + ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : nullptr); + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + + model.mapping = std::move(ml->mapping); +} + +#undef MODEL_BACKEND_OFFLOAD +class stablelm_quant_layer : public quant_layer_base { + public: + quant_params_internal get_layer_config(std::string layername, std::vector ne, ne_type type) override { + bool quantize = layername.rfind("weight") == layername.size() - 6; // ends with 'weight'? + if (layername == "model.embed_tokens.weight") { + // special layer process, can be loaded by config file + return quant_params_internal(); // return q4_0 to cover the usage of getrow + } + quantize &= (ne.size() == 2); + if (quantize) { + return mGCfg; // use global quant config + } else { + return quant_params_internal{quant_bits::count}; // non-quant + } + } +}; +REGISTER_QUANT_LAYER_CLASS(stablelm); diff --git a/scripts/convert.py b/scripts/convert.py index ef8c5a780..bf1fa1be1 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -32,9 +32,30 @@ def main(args_in: Optional[List[str]] = None) -> None: type=str, help="Access token ID for models that require it (LLaMa2, etc..)", ) - parser.add_argument("--outfile", type=Path, required=True, help="path to write to") - parser.add_argument("model", type=Path, help="directory containing model file or model id") - parser.add_argument("--use_quantized_model", action="store_true", help="use quantized model: awq/gptq/autoround") + parser.add_argument( + "--outfile", + type=Path, + required=True, + help="path to write to" + ) + parser.add_argument( + "--format", + type=str, + default="NE", + choices=["NE", "GGUF"], + help="Convert to the GGUF or NE format" + ) + parser.add_argument( + "--use_quantized_model", + action="store_true", + help="use quantized model: awq/gptq/autoround" + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file or model id" + ) + args = parser.parse_args(args_in) if args.model.exists(): @@ -47,7 +68,7 @@ def main(args_in: Optional[List[str]] = None) -> None: print("You are required to input an access token ID for {}, please add it in option --token or download model weights locally".format(args.model)) sys.exit(f"{e}") - convert_model(dir_model, args.outfile, args.outtype, use_quantized_model=args.use_quantized_model) + convert_model(dir_model, args.outfile, args.outtype, format=args.format, use_quantized_model=args.use_quantized_model) if __name__ == "__main__": diff --git a/scripts/run.py b/scripts/run.py index 631e0716e..4552136e8 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -74,6 +74,13 @@ def main(args_in: Optional[List[str]] = None) -> None: help="Data type of Gemm computation: int8/bf16/fp32 (default: int8)", default="int8", ) + parser.add_argument( + "--format", + type=str, + default="NE", + choices=["NE", "GGUF"], + help="Convert to the GGUF or NE format" + ) parser.add_argument( "--use_ggml", action="store_true", @@ -180,8 +187,10 @@ def main(args_in: Optional[List[str]] = None) -> None: # 1. 
convert path = Path(parent_path, "convert.py") + outfile = f"gguf_{model_type}_f32" if str(args.format) == "GGUF" else f"ne_{model_type}_f32.bin" convert_cmd = ["python", path] - convert_cmd.extend(["--outfile", Path(work_path, "ne_{}_f32.bin".format(model_type))]) + convert_cmd.extend(["--format", str(args.format)]) + convert_cmd.extend(["--outfile", Path(work_path, outfile)]) convert_cmd.extend(["--outtype", "f32"]) convert_cmd.append(dir_model) print("Convert model ...") @@ -189,12 +198,11 @@ def main(args_in: Optional[List[str]] = None) -> None: # 2. quantize path = Path(parent_path, "quantize.py") + quant_file = f"gguf_{model_type}_{args.weight_dtype}.gguf" if str(args.format) == "GGUF" else f"ne_{model_type}_{args.weight_dtype}.bin" quant_cmd = ["python", path] quant_cmd.extend(["--model_name", model_type]) - quant_cmd.extend(["--model_file", Path(work_path, "ne_{}_f32.bin".format(model_type))]) - quant_cmd.extend( - ["--out_file", - Path(work_path, "ne_{}_{}.bin".format(model_type, args.weight_dtype, args.group_size))]) + quant_cmd.extend(["--model_file", Path(work_path, outfile + ".gguf" if str(args.format) == "GGUF" else outfile)]) + quant_cmd.extend(["--out_file", Path(work_path, quant_file)]) quant_cmd.extend(["--weight_dtype", args.weight_dtype]) quant_cmd.extend(["--group_size", str(args.group_size)]) quant_cmd.extend(["--scale_dtype", args.scale_dtype]) @@ -210,7 +218,7 @@ def main(args_in: Optional[List[str]] = None) -> None: path = Path(parent_path, "inference.py") infer_cmd = ["python", path] infer_cmd.extend(["--model_name", model_type]) - infer_cmd.extend(["-m", Path(work_path, "ne_{}_{}.bin".format(model_type, args.weight_dtype, args.group_size))]) + infer_cmd.extend(["-m", Path(work_path, quant_file)]) infer_cmd.extend(["--prompt", args.prompt]) infer_cmd.extend(["--file", args.file]) infer_cmd.extend(["--n_predict", str(args.n_predict)]) diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh index fe416a3e1..46971447e 100644 --- a/tests/model-test/cpp_graph_inference.sh +++ b/tests/model-test/cpp_graph_inference.sh @@ -155,10 +155,10 @@ model_name_map["qwen-7b"]="Qwen/Qwen-7B-Chat" model_name_map["magicoder"]="ise-uiuc/Magicoder-S-DS-6.7B" model_name_map["whisper"]="openai/whisper-tiny" model_name_map["phi2"]="microsoft/phi-2" +model_name_map["stablelm"]="stabilityai/stablelm-2-1_6b" model_name_map["qwen-1_5"]="Qwen/Qwen1.5-7B-Chat" model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1" - function main() { conda_env="$1" model="$2" @@ -270,6 +270,10 @@ function main() { quant_script="./build/bin/quant_phi" convert_script="${convert_script}/convert_phi.py" infer_cmd="./build/bin/run_phi" + elif [[ "${model}" == "stablelm" ]]; then + quant_script="./build/bin/quant_stablelm" + convert_script="${convert_script}/convert_stablelm.py" + infer_cmd="./build/bin/run_stablelm" elif [[ "${model}" == "mixtral" ]]; then quant_script="./build/bin/quant_mixtral" convert_script="${convert_script}/convert_mixtral.py"
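
Editorial note (not part of the patch): three small decisions recur across this diff — the converter derives the rotary dimension from `partial_rotary_factor`, the model code distinguishes the 24-layer StableLM2 variants (QKV projections with biases) from the 32-layer StableLM-3B (no biases), and `scripts/run.py` names its intermediate files differently for the new `--format {NE,GGUF}` flag. The following is a minimal Python sketch of those three rules as a reading aid; the hparams values at the bottom are illustrative assumptions (roughly what stablelm-2-1_6b's config.json reports), not values taken from the diff.

# Sketch only: mirrors logic from convert_stablelm.py, stablelm.cpp/stablelm_utils.cpp,
# and scripts/run.py in this patch. Not shipped code.

def rotary_dims(hparams: dict) -> int:
    """Rotary dims per head, as computed in convert_stablelm.py:
    n_rot = partial_rotary_factor * hidden_size / num_attention_heads."""
    head_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
    return int(hparams["partial_rotary_factor"] * head_dim)

def variant(hparams: dict) -> str:
    """Mirrors the n_layer == 24 check in stablelm.cpp and the
    q_proj.bias presence check in stablelm_utils.cpp."""
    if hparams["num_hidden_layers"] == 24:
        return "StableLM2-1.6B / StableLM2-Zephyr-1.6B (QKV with biases, 7 attn tensors)"
    return "StableLM-3B (QKV without biases, 4 attn tensors)"

def output_files(model_type: str, fmt: str, weight_dtype: str) -> tuple:
    """On-disk names produced by the scripts/run.py changes
    (the GGUF converter appends '.gguf' to its f32 output itself)."""
    if fmt == "GGUF":
        return (f"gguf_{model_type}_f32.gguf", f"gguf_{model_type}_{weight_dtype}.gguf")
    return (f"ne_{model_type}_f32.bin", f"ne_{model_type}_{weight_dtype}.bin")

if __name__ == "__main__":
    # Assumed example hparams for stablelm-2-1_6b; check the model's config.json.
    hparams = {
        "hidden_size": 2048,
        "num_attention_heads": 32,
        "num_hidden_layers": 24,
        "partial_rotary_factor": 0.25,
    }
    print(rotary_dims(hparams))   # e.g. 16 rotary dims out of a 64-dim head
    print(variant(hparams))
    print(output_files("stablelm", "GGUF", "int4"))  # weight_dtype value is an example

With the patch applied, a conversion would be invoked along the lines of `python scripts/convert.py --outfile ne_stablelm_f32.bin --outtype f32 --format NE stabilityai/stablelm-2-1_6b`; the GGUF path differs only in passing `--format GGUF`, with the `.gguf` suffix handled inside `convert_stablelm.py`.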