Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPT support in llama.cpp #3417

Merged
merged 17 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
b49792b
CUDA: added support for ggml_clamp (see also: https://github.com/gger…
jploski Sep 30, 2023
15236e8
mpt : added an implementation based (mostly) on falcon integration, m…
jploski Sep 30, 2023
84e30e8
mpt : protect against "clip_qkv": null in mpt-7b
jploski Sep 30, 2023
00e8c5c
mpt : quick fix to avoid "Strange model" warning when quantizing MPT …
jploski Sep 30, 2023
1be89c4
mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out f…
jploski Sep 30, 2023
26c253e
mpt : standardized all tensor names to follow GGUF spec
jploski Sep 30, 2023
df072d2
mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GE…
jploski Sep 30, 2023
90e7d6d
mpt : fixed comment s/gptneox/mpt/
jploski Oct 2, 2023
4708012
mpt : remove tabs, trailing whitespace
jploski Oct 2, 2023
1364bcd
mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) a…
jploski Oct 3, 2023
7d6a24a
mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to co…
jploski Oct 6, 2023
292363e
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
cebtenzzre Oct 9, 2023
ad3c2f3
comment out n_past instead of marking it unused
cebtenzzre Oct 9, 2023
1a454eb
mpt : removed hardcoded +178 from convert script in favor of utilizin…
jploski Oct 9, 2023
32172f1
mpt : remove unused tokenizer_json in convert script
cebtenzzre Oct 9, 2023
96cf3f5
ggml : remove obsolete n_past assert in ggml_alibi
ggerganov Oct 10, 2023
9b66378
llama : print clam_kqv and max_alibi_bias hparams
ggerganov Oct 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions convert-mpt-hf-to-gguf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
#!/usr/bin/env python3
# HF gptneox--> gguf conversion
jploski marked this conversation as resolved.
Show resolved Hide resolved

from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]

if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py


def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
return dict(zip(bs, (chr(n) for n in cs)))
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved


def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1

if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
help="output format - use 0 for float32, 1 for float16",
)
return parser.parse_args()

args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
hparams = json.load(f)

if hparams["architectures"][0] != "MPTForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit()

# get number of model parts
num_parts = count_model_parts(dir_model)

ARCH=gguf.MODEL_ARCH.MPT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["n_layers"]

gguf_writer.add_name(dir_model.name)
gguf_writer.add_context_length(hparams["max_seq_len"])
gguf_writer.add_embedding_length(hparams["d_model"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
gguf_writer.add_head_count(hparams["n_heads"])
gguf_writer.add_layer_norm_eps(1e-05)
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
jploski marked this conversation as resolved.
Show resolved Hide resolved
gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []

tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():
print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
sys.exit(1)

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

with open(tokenizer_json_file, "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)

print("gguf: get gpt2 tokenizer vocab")

# MPT token embedding tensors have dimension 50432, but there are only 50254
# tokens in the vocab, presumably to accomodate some "reserved" tokens;
# this is causing problems down the line in llama.cpp, so we extend the vocab_size:

vocab_size = len(tokenizer_json["model"]["vocab"]) + 178
goerch marked this conversation as resolved.
Show resolved Hide resolved
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(vocab_size):
if i in reverse_vocab:
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
try:
text.append(byte_decoder[c])
except KeyError:
text.extend(c.encode('utf-8'))
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token. (It's normal for MPT.)")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)

tokens.append(text)

gguf_writer.add_token_list(tokens)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

for name in model_part.keys():
data = model_part[name]

old_dtype = data.dtype

# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)

data = data.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Cannot map tensor '" + name + "'")
continue # for the sake of compatibility with some old published models, don't quit
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))


# if new_name == "wte.weight" and data.shape[0] == 50432 and vocab_size == 50254:
# data = data[0:vocab_size,:]

gguf_writer.add_tensor(new_name, data)

# note: MPT output is tied to (same as) wte in original model;
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
if new_name == "wte.weight":
gguf_writer.add_tensor("output.weight", data)

print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")
44 changes: 44 additions & 0 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
Expand Down Expand Up @@ -4555,6 +4556,16 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i];
}

static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}

static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
Expand Down Expand Up @@ -5436,6 +5447,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}

static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}

template<typename T>
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
Expand Down Expand Up @@ -6353,6 +6369,24 @@ inline void ggml_cuda_op_scale(
(void) src1_dd;
}

inline void ggml_cuda_op_clamp(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const float min = ((float *) dst->op_params)[0];
const float max = ((float *) dst->op_params)[1];

clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());

(void) src1;
(void) dst;
(void) src1_dd;
}

static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
const int64_t nrows0 = ggml_nrows(src0);

Expand Down Expand Up @@ -6906,6 +6940,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}

static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
}

static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
Expand Down Expand Up @@ -7330,6 +7368,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_scale;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;
}
func = ggml_cuda_clamp;
break;
case GGML_OP_CPY:
if (!any_on_device) {
return false;
Expand Down
14 changes: 14 additions & 0 deletions gguf-py/gguf/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,19 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.MPT: {
MODEL_TENSOR.TOKEN_EMBD: "wte",
MODEL_TENSOR.OUTPUT_NORM: "norm_f",
# note: MPT output is tied to (same as) wte in original model;
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.norm_1",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.norm_2",
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn.Wqkv",
jploski marked this conversation as resolved.
Show resolved Hide resolved
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn.out_proj",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn.down_proj",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn.up_proj",
},
MODEL_ARCH.GPT2: {
# TODO
},
Expand Down Expand Up @@ -231,6 +244,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 falcon
"transformer.norm_f", # mpt
"model.norm", # llama-hf baichuan
"norm", # llama-pth
),
Expand Down
Loading