Skip to content

Commit

Permalink
add support of Grok
Browse files Browse the repository at this point in the history
  • Loading branch information
Judd committed Mar 28, 2024
1 parent 4fdccf3 commit 912bacc
Show file tree
Hide file tree
Showing 6 changed files with 596 additions and 12 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
* Cohere (`CohereForCausalLM`)
* [x] [C4AI Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)

* Grok-1
* [x] [Base](https://huggingface.co/xai-org/grok-1)

About [Grok-1](./docs/grok.md).

* Text Embedding (`XLMRobertaModel`)
* [x] [BCE-Embedding](https://huggingface.co/maidalun1020/bce-embedding-base_v1)

Expand Down
258 changes: 256 additions & 2 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import struct
import sys
import io
import pickle
from pathlib import Path
from enum import Enum
from pathlib import Path
from typing import IO, Any, Iterable, List, Optional, Tuple
import numpy as np
import math
from attr import dataclass

import torch
from torch import nn
Expand Down Expand Up @@ -92,6 +94,8 @@ class ModelType(Enum):

CohereCommand = 0x1400

Grok1 = 0x1500

BCE_Embedding = 0x10000100
BCE_ReRanker = 0x10000101

Expand Down Expand Up @@ -205,7 +209,7 @@ def load_all_model_files(model_files) -> Dict:
r[k] = v
yield r

def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_pp):
def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_pp, loader_fun = None):
tensor_info = []
converted_names = []

Expand All @@ -214,7 +218,10 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_
state_dict_cache = {}
remaining: List = weight_names.copy()

for state_dict in load_all_model_files(model_files):
if loader_fun is None:
loader_fun = load_all_model_files

for state_dict in loader_fun(model_files):
this_round = {}
state_dict = state_dict_pp(config, state_dict)

Expand Down Expand Up @@ -2240,6 +2247,247 @@ def get_weight_names(config):
r = LlamaConverter.get_weight_names(config)
return r[:-1]

@dataclass
class QuantizedWeight8bit:
def __init__(self):
import jax
import jax.numpy as jnp
import jnp.array

self.weight: jnp.array
self.scales: jnp.array

@property
def shape(self):
return self.weight.shape

class Grok1Converter(BaseConverter):
MODEL_TYPE = ModelType.Grok1
tensor_map = []
file_to_name = {}
experts = []

@classmethod
def state_dict_pp(cls, config, state_dict):
new_dict = {}

for name in state_dict:
tensor: torch.Tensor = state_dict[name]
if name.endswith('embed_tokens.weight'):
new_dict['model.embed_tokens.weight'] = tensor * config.embedding_multiplier_scale
elif 'multi_head_attention' in name:
old_name = name.replace('multi_head_attention', 'self_attn')
if name.endswith('k_proj.weight'):
new_dict[old_name] = permute(tensor, config.num_key_value_heads)
elif name.endswith('q_proj.weight'):
new_dict[old_name] = permute(tensor, config.num_attention_heads)
else:
new_dict[old_name] = tensor
elif 'experts' in name:
new_dict[name] = tensor
else:
old_name = ''
mapping = {
'language_model.norm.weight': 'model.norm.weight',
'rms_norm.weight': 'rms_norm.weight',
'rms_norm_1.weight': 'rms_norm_1.weight',
'rms_norm_2.weight': 'rms_norm_2.weight',
'rms_norm_3.weight': 'rms_norm_3.weight',
'router.weight': 'router.weight',
}

for k in mapping.keys():
if name.endswith(k):
old_name = name.replace(k, mapping[k])
break

if old_name == '':
raise Exception(f'unhandled tensor {name}')

new_dict[old_name] = tensor

return new_dict

@staticmethod
def dump_config(f, config, ggml_type):
assert config.hidden_act == 'gelu', "hidden_act == 'gelu'"

config.hidden_act = 'silu'
LlamaConverter.dump_config(f, config, ggml_type)
config_values = [
config.num_key_value_heads,
config.num_experts,
config.num_selected_experts,
]
f.write(struct.pack("i" * len(config_values), *config_values))
f.write(struct.pack("<f", config.rope_theta))
f.write(struct.pack("<f", config.output_multiplier_scale))

@staticmethod
def get_weight_names(config):
weight_names = ["model.embed_tokens.weight"]
for i in range(config.num_hidden_layers):
for j in range(config.num_experts):
weight_names += [
f"model.layers.{i}.experts.{j}.w1.weight",
f"model.layers.{i}.experts.{j}.w2.weight",
f"model.layers.{i}.experts.{j}.w3.weight",
]

weight_names += [
f"model.layers.{i}.self_attn.k_proj.weight",
f"model.layers.{i}.self_attn.o_proj.weight",
f"model.layers.{i}.self_attn.q_proj.weight",
f"model.layers.{i}.self_attn.v_proj.weight",
f"model.layers.{i}.rms_norm.weight",
f"model.layers.{i}.rms_norm_1.weight",
f"model.layers.{i}.rms_norm_2.weight",
f"model.layers.{i}.rms_norm_3.weight",
f"model.layers.{i}.router.weight",
]

weight_names += [
"model.norm.weight",
]

return weight_names

@staticmethod
def load_tensor_file(tensor_name, fn) -> Any:
tensor_dict = {}
new_dict = {}

# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
def convert_weight(v):
dtype = torch.float32
if hasattr(v , 'scales'):
weight = torch.from_numpy(np.asarray(v.weight).astype(np.float32)).to(dtype)
scale =torch.from_numpy(np.asarray(v.scales).astype(np.float32)).to(dtype)
# row parallel layers have sharded scale
if len(scale.shape) >= 2 and scale.shape[-2] != 1:
scale = scale[..., None, :]
weight = weight.view(*weight.shape[:-2], 8, -1, weight.shape[-1])
weight = (weight * scale).view(*weight.shape[:-3], -1, weight.shape[-1])
else:
weight = weight * scale
else:
weight = torch.from_numpy(np.asarray(v).astype(np.float32)).to(dtype)

# Transpose linear matrix
if len(weight.shape) >= 2 and 'embed_tokens.weight' not in tensor_name:
weight = weight.transpose(-1, -2).contiguous()

if tensor_name.endswith('router.weight'):
new_dict[tensor_name] = weight[Grok1Converter.experts]
elif 'experts' not in tensor_name:
new_dict[tensor_name] = weight
else:
# split moe
for i in range(len(Grok1Converter.experts)):
new_key_i = tensor_name.replace('experts', f'experts.{i}')
new_dict[new_key_i] = weight[Grok1Converter.experts[i]]

with open(fn, 'rb') as f:
r = pickle.load(f)
tensor_dict[tensor_name] = r

convert_weight(r)

return new_dict

@staticmethod
def load_tensor_files(tensor_files) -> Dict:
for (t, f) in tensor_files:
print(f)
yield Grok1Converter.load_tensor_file(t, f)

@classmethod
def convert(cls, config, model_files_path, vocab: Any, ggml_type, save_path):

Grok1Converter.experts = config.experts

map = ['language_model.embed_tokens.weight',
'language_model.norm.weight']

# caution: alphabet order must not be changed!
for i in range(config.num_hidden_layers):
map += [
f"model.layers.{i}.experts.w1.weight",
f"model.layers.{i}.experts.w2.weight",
f"model.layers.{i}.experts.w3.weight",
f"model.layers.{i}.multi_head_attention.k_proj.weight",
f"model.layers.{i}.multi_head_attention.o_proj.weight",
f"model.layers.{i}.multi_head_attention.q_proj.weight",
f"model.layers.{i}.multi_head_attention.v_proj.weight",
f"model.layers.{i}.rms_norm.weight",
f"model.layers.{i}.rms_norm_1.weight",
f"model.layers.{i}.rms_norm_2.weight",
f"model.layers.{i}.rms_norm_3.weight",
f"model.layers.{i}.router.weight",
]

order = list(range(len(map)))
order.sort(key=lambda i: map[i])

for i in range(len(map)):
idx = order.index(i)
fn = model_files_path + f'/tensor{idx:05}_000'
info = (map[i], fn)
Grok1Converter.tensor_map.append(info)
Grok1Converter.file_to_name[fn] = map[i]

# convert all weights to fp16
with open(save_path, "wb") as f:
f.write(b"ggml") # magic
f.write(struct.pack("ii", cls.MODEL_TYPE.value, cls.FILE_VERSION))
Grok1Converter.dump_config(f, config, ggml_type)
vocab.write_vocab(f)

weight_names = Grok1Converter.get_weight_names(config)
dump_state_dict(f, weight_names, Grok1Converter.tensor_map, ggml_type, config, Grok1Converter.state_dict_pp, loader_fun=Grok1Converter.load_tensor_files)

print(f"{Grok1Converter.MODEL_TYPE.name} GGML model saved to {save_path}")

def convert_grok_1_base(args, vocab, ggml_type):
def ffn_size(emb_size, widening_factor):
_ffn_size = int(widening_factor * emb_size) * 2 // 3
_ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8
return _ffn_size

grok1_config = {
'vocab_size': 128 * 1024,
'hidden_act': 'gelu',
'pad_token_id': 0,
'eos_token_id': 2,
'max_position_embeddings': 8192,
'output_multiplier_scale': 0.5773502691896257,
'embedding_multiplier_scale': 78.38367176906169,
'hidden_size': 48 * 128,
'intermediate_size': -1,
'num_attention_heads': 48,
'num_key_value_heads': 8,
'num_hidden_layers': 64,
'num_selected_experts': 2,
'rope_theta': 10000,
'attn_output_multiplier': 0.08838834764831845,
}

grok1_config['intermediate_size'] = ffn_size(grok1_config['hidden_size'], 8)

grok1_config['experts'] = list(range(8))
if args.experts != '':
grok1_config['experts'] = [int(x, 0) for x in args.experts.split(',')]

grok1_config['num_experts'] = len(grok1_config['experts'])

if grok1_config['num_experts'] < 2:
raise Exception(f"at least 2 experts")

print(f"experts to export: {grok1_config['experts']}")

Grok1Converter.convert(AttributeDict(grok1_config), args.model_name_or_path, vocab, ggml_type, args.save_path)
return

def load_vocab(path: Path) -> Any:

def load_spm(p: Path) -> Any:
Expand Down Expand Up @@ -2329,11 +2577,17 @@ def main():
parser.add_argument("-o", "--save_path", type=Path)
parser.add_argument("-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"])
parser.add_argument("--vocab_dir", type=str, default='')
parser.add_argument("--experts", type=str, default='')
args = parser.parse_args()

ggml_type = GGMLType[args.type.upper()]

vocab = load_vocab(Path(args.model_name_or_path) if args.vocab_dir == '' else Path(args.vocab_dir))

if args.arch.lower() == 'grok-1-base':
convert_grok_1_base(args, vocab, ggml_type)
return

model_files = load_some_model(Path(args.model_name_or_path))

#if args.lora_model_name_or_path is not None:
Expand Down
50 changes: 50 additions & 0 deletions docs/grok.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

# About Grok-1

Disclaimer: I am not sure if the implementation is correct or not, because I don't have enough compute resource to run the full model.

## Convert the model

To convert the base model, `jax` is needed:

```sh
pip install jax[cpu]

```

Download the [base model](https://huggingface.co/xai-org/grok-1) and [repository](https://github.com/xai-org/grok-1).

Use `convert.py` to convert it, for example quantized to Q4_0:

```sh
python convert.py -i /path/to/model/ckpt-0 --vocab_dir /path/to/repository -o grok.bin -a Grok-1-Base -t q4_0
```

**Bonus**: Use `--experts` to export a subset of experts, such as `--experts 0,1,2,3` for the first 4 experts.
The converted model will have less parameters but performance will degrade significantly.
At least 2 experts are required. Remember that `NUM_EXPERTS` in `grok.cpp` should be the actual number of experts.

## Test

Below is a test run with the first 4 experts:

```sh
./bin/main -m ../grok-1-4_q4_0.bin -i --temp 0 --max_length 1024

________ __ __ __ __ ___
/ ____/ /_ ____ _/ /_/ / / / / |/ /_________ ____
/ / / __ \/ __ `/ __/ / / / / /|_/ // ___/ __ \/ __ \
/ /___/ / / / /_/ / /_/ /___/ /___/ / / // /__/ /_/ / /_/ /
\____/_/ /_/\__,_/\__/_____/_____/_/ /_(_)___/ .___/ .___/
You are served by Grok-1, /_/ /_/
with 161064425472 (83.8B effect.) parameters.
You > what is your name?
A.I. >
what is your age?
what is your weight?
...
```
Loading

0 comments on commit 912bacc

Please sign in to comment.