From bff05e4754777ac63f845fc002016dbc9994a089 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Sat, 2 Mar 2024 11:47:15 +0300 Subject: [PATCH 01/15] Drop interleaved placement in QKV --- lit_gpt/lora.py | 45 +----- lit_gpt/model.py | 92 ++++++++----- scripts/convert_hf_checkpoint.py | 198 +++++++++++++-------------- scripts/convert_lit_checkpoint.py | 167 +++++++++++----------- tests/test_convert_hf_checkpoint.py | 57 ++++++++ tests/test_convert_lit_checkpoint.py | 42 ++++-- tests/test_lora.py | 4 +- tests/test_model.py | 8 +- 8 files changed, 335 insertions(+), 278 deletions(-) diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index bfc7adc122..0f679fb821 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -262,20 +262,13 @@ def __init__( # Compute the indices # Indices are needed to properly pad weight updates with zeros in `zero_pad` method. - q_per_kv = self.n_head // self.n_query_groups - total_qkv = q_per_kv + 2 - head_size = out_features // (self.n_query_groups * total_qkv) - ind = range(out_features) self.lora_ind = [] if enable_q: - q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2] - self.lora_ind.extend(q_ind) + self.lora_ind.extend(range(0, self.linear.in_features)) if enable_k: - k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2] - self.lora_ind.extend(k_ind) + self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size)) if enable_v: - v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1] - self.lora_ind.extend(v_ind) + self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features)) self.reset_parameters() def zero_pad(self, x: torch.Tensor) -> torch.Tensor: @@ -291,27 +284,6 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: ________________________________________ | query | key | value | ---------------------------------------- - For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped - queries are adjacent to their associated key and value weights. - For example, suppose we have n_head = 12 with 3 query groups. - Then along the embedding dimension the interleaved weights would look like - - [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V], - - where each Q, K, and V has size head_size. - - In this case, the previously-described weight update applies separately to each - individual block, so the update will take the form - - [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...], - [.............................................................................], - [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]] - ↑ ↑ ↑ ↑ ↑ ↑ - ________________________________________________________________________________ - | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ... - -------------------------------------------------------------------------------- - Note that in the above diagram, the size of each q block will equal q_per_kv - times the size of each k and v block. 
Args: x: tensor with weights update that will be padded with zeros if necessary @@ -373,7 +345,8 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) return torch.cat( - [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T) + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], + dim=1, # (B, C_output', T) ) # (B, C_output, T) def get_lora_AB(self) -> torch.Tensor: @@ -385,9 +358,7 @@ def get_lora_AB(self) -> torch.Tensor: lora = self.conv1d( self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).squeeze( - 0 - ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) return self.zero_pad(lora * self.scaling) # (256, 128) after zero_pad (384, 128) def merge(self) -> None: @@ -426,9 +397,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: after_B = self.conv1d( after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).transpose( - -2, -1 - ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) return pretrained + lora diff --git a/lit_gpt/model.py b/lit_gpt/model.py index ed33664fa2..f87cb5910f 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -173,15 +173,17 @@ def forward( class CausalSelfAttention(nn.Module): def __init__(self, config: Config) -> None: super().__init__() - shape = (config.n_head + 2 * config.n_query_groups) * config.head_size - # key, query, value projections for all heads, but in a batch - self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) + # key, query and value projections for all heads, but in a batch + self.attn = nn.Linear( + config.n_embd, + (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries + bias=config.bias, + ) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) - # disabled by default - self.kv_cache: Optional[KVCache] = None + self.kv_cache: Optional[KVCache] = None self.config = config def forward( @@ -192,46 +194,68 @@ def forward( mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - qkv = self.attn(x) - - # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) - q_per_kv = self.config.n_head // self.config.n_query_groups - total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value - qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - - # split batched computation into three - q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - - # maybe repeat k and v if for the non multi-head attention cases - # training: flash attention requires it - # inference: multi-query would require a full kv cache so avoid it to limit its memory usage 
- if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) - v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) - - q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - + # Notation: + # - B | batch size + # - T | time-step (sequence length) + # - C | model's embeddings size (n_embd) + # - C* | attentions's embeddings size + # - nh_(q,k,v) | number of heads for query, key and value + # - hs | head size + + B, T, C = x.size() + + # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` + # instead of individually multiplying the input `x` with the respective weight matrices. + qkv = self.attn(x) # (B, T, 3xC*) + + # Define query, key and value sizes. + # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`). + query_size = self.config.n_head * self.config.head_size + key_size = value_size = self.config.n_query_groups * self.config.head_size + # Split qkv into query, key and value matrices. + q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) + + # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the + # embedding size (C) into num_heads (nh) and head_size (hs). + q = q.view(B, T, self.config.n_head, self.config.head_size) # (B, T, nh_q, hs) + k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs) + v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs) + + # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are + # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # of size `hs`. + q = q.transpose(1, 2) # (B, nh_q, T, hs) + k = k.transpose(1, 2) # (B, nh_k, T, hs) + v = v.transpose(1, 2) # (B, nh_v, T, hs) + + # Unlike standard positional embeddings rotary embeddings must be applied at every layer. q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_k, T, hs) + # Apply kv-cache during inference. if input_pos is not None: if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) + # Grouped queries: balance the number of heads across all three matrices. + # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. + if self.config.n_query_groups != self.config.n_head and self.config.n_query_groups != 1: + q_per_kv = self.config.n_head // self.config.n_query_groups + k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + + # Efficient attention using Flash Attention CUDA kernels. 
+ # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) y = self.scaled_dot_product_attention(q, k, v, mask) - y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, self.config.head_size * self.config.n_head) - # output projection - return self.proj(y) + # Output projection. + return self.proj(y) # (B, T, C) def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index b033a49346..51ab64f7e7 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -2,6 +2,7 @@ import gc import json +import re import sys from collections import defaultdict from dataclasses import asdict @@ -21,6 +22,7 @@ def copy_weights_gpt_neox( + config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, @@ -47,24 +49,23 @@ def copy_weights_gpt_neox( "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", "embed_out.weight": "lm_head.weight", } - - for name, param in hf_weights.items(): - if "gpt_neox.layers" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template] + if to_name is None: + continue + to_name = to_name.format(layer_idx) + param = load_param(param, from_name, dtype) + if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): + # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] + param = torch.cat(qkv_split(param, config)) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_falcon( - model_name: str, + config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, @@ -81,14 +82,14 @@ def copy_weights_falcon( "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size - if "7b" in model_name: + if "7b" in config.name: weight_map.update( { "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", } ) - elif "40b" in model_name or "180B" in model_name: + elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", @@ -100,13 +101,13 @@ def copy_weights_falcon( else: raise NotImplementedError - for name, param in hf_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, dtype) + if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): + # Reassemble [q, k, v, q, k, v, ...] 
--> [q, q, ..., k, k, ..., v, v, ...] + param = torch.cat(qkv_split(param, config)) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -122,15 +123,15 @@ def copy_weights_hf_llama( ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", - "model.layers.{}.input_layernorm.weight": "transformer.h.{l}.norm_1.weight", - "model.layers.{}.input_layernorm.bias": "transformer.h.{l}.norm_1.bias", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, - "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight", - "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{l}.norm_2.bias", + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", "model.norm.weight": "transformer.ln_f.weight", "model.norm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", @@ -138,43 +139,34 @@ def copy_weights_hf_llama( if config._mlp_class == "LLaMAMoE": weight_map.update( { - "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{l}.mlp.gate.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{l}.mlp.experts.{e}.fc_1.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", + "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{}.mlp.gate.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{}.mlp.experts.{}.proj.weight", } ) elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { - "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", - "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", - "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", } ) else: raise NotImplementedError - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - e = None - if "block_sparse_moe.experts" in name: - from_name, e = layer_template(from_name, 5) - qkv = qkv_weights.setdefault(l, [None, None, None]) - if "q_proj" in name: - qkv[0] = param - elif "k_proj" in name: - qkv[1] = param - elif "v_proj" in name: - qkv[2] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l=l, e=e) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) + for from_name, param in 
hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -182,25 +174,21 @@ def copy_weights_hf_llama( if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] - # convert separate q, k, v matrices into an interleaved qkv - for i, (q, k, v) in list(qkv_weights.items()): - if q is None or k is None or v is None: - # split across different .bin files - continue - q = load_param(q, f"layer {i} q", dtype) - k = load_param(k, f"layer {i} k", dtype) - v = load_param(v, f"layer {i} v", dtype) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv - del qkv_weights[i] + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # qkv is splitted across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + def copy_weights_phi( - config: Config, qkv_weights: dict, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], @@ -234,20 +222,17 @@ def copy_weights_phi( "lm_head.bias": "lm_head.bias", } - for name, param in hf_weights.items(): - if name.startswith("model.layers."): - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, defaultdict(dict)) - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(layer_idx, defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(layer_idx) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -256,27 +241,22 @@ def copy_weights_phi( for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype) k = load_param(qkv["k_proj"], f"layer {i} k 
{weight_type}", dtype) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) + qkv = torch.cat((q, k, v)) state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv del qkv_weights[i][weight_type] -def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: - split = layer_name.split(".") - number = int(split[idx]) - split[idx] = "{}" - from_name = ".".join(split) - return from_name, number +def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]: + pattern = r"\.(\d+)\." + if not (search_res := re.findall(pattern, layer_name)): + return layer_name, -1 + layer_name_template = re.sub(pattern, ".{}.", layer_name, count=num_matches) + return layer_name_template, *(int(x) for x in search_res[:num_matches]) def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: @@ -290,6 +270,24 @@ def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: return param +def qkv_split( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_per_kv = config.n_head // config.n_query_groups + qs = [] + ks = [] + vs = [] + for chunk in torch.chunk(param, config.n_query_groups): + split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) + qs.append(split[0]) + ks.append(split[1]) + vs.append(split[2]) + q = torch.cat(qs) + k = torch.cat(ks) + v = torch.cat(vs) + return q, k, v + + @torch.inference_mode() def convert_hf_checkpoint( *, @@ -309,7 +307,7 @@ def convert_hf_checkpoint( json.dump(config_dict, json_config) if "falcon" in model_name: - copy_fn = partial(copy_weights_falcon, model_name) + copy_fn = partial(copy_weights_falcon, config) elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} @@ -317,9 +315,9 @@ def convert_hf_checkpoint( elif "phi" in model_name: # holder to reconstitute the split q, k, v qkv_weights = {} - copy_fn = partial(copy_weights_phi, config, qkv_weights) + copy_fn = partial(copy_weights_phi, qkv_weights) else: - copy_fn = copy_weights_gpt_neox + copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py index 8a3b101a7d..30b079dc2c 100644 --- a/scripts/convert_lit_checkpoint.py +++ b/scripts/convert_lit_checkpoint.py @@ -19,7 +19,7 @@ def copy_weights_falcon( - model_name: str, + config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, @@ -35,14 +35,14 @@ def copy_weights_falcon( "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size - if "7b" in model_name: + if "7b" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", } ) - elif "40b" in model_name or "180B" in model_name: + elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", 
@@ -54,19 +54,22 @@ def copy_weights_falcon( else: raise NotImplementedError - for name, param in lit_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): + # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] + qs, ks, vs = qkv_split(param, config) + cycled = [t for group in zip(qs, ks, vs) for t in group] + param = torch.cat(cycled) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gpt_neox( + config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, @@ -90,13 +93,15 @@ def copy_weights_gpt_neox( "lm_head.weight": "embed_out.weight", } - for name, param in lit_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): + # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] + qs, ks, vs = qkv_split(param, config) + cycled = [t for group in zip(qs, ks, vs) for t in group] + param = torch.cat(cycled) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -111,11 +116,11 @@ def copy_weights_llama( ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", - "transformer.h.{}.norm_1.weight": "model.layers.{l}.input_layernorm.weight", - "transformer.h.{}.norm_1.bias": "model.layers.{l}.input_layernorm.bias", - "transformer.h.{}.attn.proj.weight": "model.layers.{l}.self_attn.o_proj.weight", - "transformer.h.{}.norm_2.weight": "model.layers.{l}.post_attention_layernorm.weight", - "transformer.h.{}.norm_2.bias": "model.layers.{l}.post_attention_layernorm.bias", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", "transformer.ln_f.weight": "model.norm.weight", "transformer.ln_f.bias": "model.norm.bias", "lm_head.weight": "lm_head.weight", @@ -123,48 +128,40 @@ def copy_weights_llama( if config._mlp_class == "LLaMAMoE": weight_map.update( { - "transformer.h.{}.mlp.gate.weight": "model.layers.{l}.block_sparse_moe.gate.weight", - "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", - "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", - "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", + 
"transformer.h.{}.mlp.gate.weight": "model.layers.{}.block_sparse_moe.gate.weight", + "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.block_sparse_moe.experts.{}.w1.weight", + "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.block_sparse_moe.experts.{}.w3.weight", + "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.block_sparse_moe.experts.{}.w2.weight", } ) elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { - "transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight", - "transformer.h.{}.mlp.fc_2.weight": "model.layers.{l}.mlp.up_proj.weight", - "transformer.h.{}.mlp.proj.weight": "model.layers.{l}.mlp.down_proj.weight", + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", } ) else: raise NotImplementedError - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith(".attn.attn.weight"): - from_name, l = layer_template(name, 2) - q = "model.layers.{}.self_attn.q_proj.weight".format(l) - k = "model.layers.{}.self_attn.k_proj.weight".format(l) - v = "model.layers.{}.self_attn.v_proj.weight".format(l) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.attn.weight"): + to_names = ( + "model.layers.{}.self_attn.q_proj.weight".format(*ids), + "model.layers.{}.self_attn.k_proj.weight".format(*ids), + "model.layers.{}.self_attn.v_proj.weight".format(*ids), + ) + params = [torch.cat(w) for w in qkv_split(param, config)] else: - if "transformer.h" in name: - from_name, l = layer_template(name, 2) - e = None - if "mlp.experts" in name: - from_name, e = layer_template(from_name, 5) - to_name = weight_map[from_name] - to_name = to_name.format(l=l, e=e) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -191,28 +188,22 @@ def copy_weights_phi( "lm_head.weight": "lm_head.weight", "lm_head.bias": "lm_head.bias", } - - for name, param in lit_weights.items(): - if name.endswith((".attn.attn.weight", ".attn.attn.bias")): - from_name, l = layer_template(name, 2) - weight_type = name.split(".")[-1] # weight or bias - q = f"model.layers.{l}.self_attn.q_proj.{weight_type}" - k = f"model.layers.{l}.self_attn.k_proj.{weight_type}" - v = f"model.layers.{l}.self_attn.v_proj.{weight_type}" - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): + weight_type = from_name.split(".")[-1] # weight or 
bias + to_names = ( + f"model.layers.{{}}.self_attn.q_proj.{weight_type}".format(layer_idx), + f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx), + f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx), + ) + params = [torch.cat(w) for w in qkv_split(param, config)] else: - if "transformer.h" in name: - from_name, l = layer_template(name, 2) - to_name = weight_map[from_name] - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(layer_idx),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -220,20 +211,18 @@ def copy_weights_phi( def qkv_split( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q_per_kv = config.n_head // config.n_query_groups - qs = [] - ks = [] - vs = [] - for chunk in torch.chunk(param, config.n_query_groups): - split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) - qs.append(split[0]) - ks.append(split[1]) - vs.append(split[2]) - q = torch.cat(qs) - k = torch.cat(ks) - v = torch.cat(vs) - return q, k, v +) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]: + q, k, v = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + qs = q.split(config.n_head // config.n_query_groups * config.head_size) + ks = k.split(config.head_size) + vs = v.split(config.head_size) + return qs, ks, vs def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: @@ -248,14 +237,14 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path config = Config.from_json(config_path) if "falcon" in config.name: - copy_fn = partial(copy_weights_falcon, config.name) + copy_fn = partial(copy_weights_falcon, config) elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) elif "phi" in config.name: copy_fn = partial(copy_weights_phi, config) else: - copy_fn = copy_weights_gpt_neox + copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 97d42b4f9e..76de12681a 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -120,3 +120,60 @@ def test_convert_hf_checkpoint(tmp_path): config = Config.from_json(tmp_path / "lit_config.json") assert isinstance(config, Config) + + +def test_qkv_split(): + from lit_gpt import Config + from scripts.convert_hf_checkpoint import qkv_split + + # MHA + config = Config(n_embd=4, n_head=4) + qkv = torch.tensor( + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23], + [24, 25, 26, 27], + [28, 29, 30, 31], + [32, 33, 34, 35], + [36, 37, 38, 39], + [40, 41, 42, 43], + [44, 45, 46, 47], + ] + ) + q, k, v = qkv_split(qkv, config) + torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [12, 13, 14, 15], [24, 25, 26, 27], [36, 37, 38, 39]])) + torch.testing.assert_close(k, torch.tensor([[4, 5, 6, 7], [16, 17, 18, 19], [28, 29, 30, 31], [40, 41, 42, 43]])) + torch.testing.assert_close(v, torch.tensor([[8, 9, 10, 
11], [20, 21, 22, 23], [32, 33, 34, 35], [44, 45, 46, 47]])) + + # GQA + config = Config(n_embd=4, n_head=4, n_query_groups=2) + qkv = torch.tensor( + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23], + [24, 25, 26, 27], + [28, 29, 30, 31], + ] + ) + q, k, v = qkv_split(qkv, config) + torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [16, 17, 18, 19], [20, 21, 22, 23]])) + torch.testing.assert_close(k, torch.tensor([[8, 9, 10, 11], [24, 25, 26, 27]])) + torch.testing.assert_close(v, torch.tensor([[12, 13, 14, 15], [28, 29, 30, 31]])) + + # MQA + config = Config(n_embd=4, n_head=4, n_query_groups=1) + qkv = torch.tensor( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ) + q, k, v = qkv_split(qkv, config) + torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]])) + torch.testing.assert_close(k, torch.tensor([[16, 17, 18, 19]])) + torch.testing.assert_close(v, torch.tensor([[20, 21, 22, 23]])) diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 7e0fca8c63..f0b0e77ca1 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -60,7 +60,7 @@ def test_against_falcon_40b(): ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} - copy_to_theirs("40b", theirs_state_dict, ours_state_dict) + copy_to_theirs(ours_config, theirs_state_dict, ours_state_dict) theirs_model = FalconForCausalLM(theirs_config) # assign must be set to True for torch.testing.assert_close to pass @@ -100,7 +100,7 @@ def test_against_original_gpt_neox(): ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} - copy_to_theirs(theirs_state_dict, ours_state_dict) + copy_to_theirs(ours_config, theirs_state_dict, ours_state_dict) theirs_model = GPTNeoXForCausalLM(theirs_config) # strict=False because we don't save the rotary embeddings inv frequency keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) @@ -455,6 +455,8 @@ def test_check_conversion_supported_lora(): def test_qkv_split(): + from torch import tensor + from lit_gpt import Config from scripts.convert_lit_checkpoint import qkv_split @@ -477,9 +479,27 @@ def test_qkv_split(): ] ) q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [12, 13, 14, 15], [24, 25, 26, 27], [36, 37, 38, 39]])) - torch.testing.assert_close(k, torch.tensor([[4, 5, 6, 7], [16, 17, 18, 19], [28, 29, 30, 31], [40, 41, 42, 43]])) - torch.testing.assert_close(v, torch.tensor([[8, 9, 10, 11], [20, 21, 22, 23], [32, 33, 34, 35], [44, 45, 46, 47]])) + torch.testing.assert_close( + q, (tensor([[0, 1, 2, 3]]), tensor([[4, 5, 6, 7]]), tensor([[8, 9, 10, 11]]), tensor([[12, 13, 14, 15]])) + ) + torch.testing.assert_close( + k, + ( + tensor([[16, 17, 18, 19]]), + tensor([[20, 21, 22, 23]]), + tensor([[24, 25, 26, 27]]), + tensor([[28, 29, 30, 31]]), + ), + ) + torch.testing.assert_close( + v, + ( + tensor([[32, 33, 34, 35]]), + tensor([[36, 37, 38, 39]]), + tensor([[40, 41, 42, 43]]), + tensor([[44, 45, 46, 47]]), + ), + ) # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) @@ -496,9 +516,9 @@ def test_qkv_split(): ] ) q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [16, 17, 18, 19], [20, 21, 22, 23]])) - torch.testing.assert_close(k, 
torch.tensor([[8, 9, 10, 11], [24, 25, 26, 27]])) - torch.testing.assert_close(v, torch.tensor([[12, 13, 14, 15], [28, 29, 30, 31]])) + torch.testing.assert_close(q, (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), tensor([[8, 9, 10, 11], [12, 13, 14, 15]]))) + torch.testing.assert_close(k, (tensor([[16, 17, 18, 19]]), tensor([[20, 21, 22, 23]]))) + torch.testing.assert_close(v, (tensor([[24, 25, 26, 27]]), tensor([[28, 29, 30, 31]]))) # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) @@ -506,6 +526,6 @@ def test_qkv_split(): [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] ) q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]])) - torch.testing.assert_close(k, torch.tensor([[16, 17, 18, 19]])) - torch.testing.assert_close(v, torch.tensor([[20, 21, 22, 23]])) + torch.testing.assert_close(q, (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]),)) + torch.testing.assert_close(k, (tensor([[16, 17, 18, 19]]),)) + torch.testing.assert_close(v, (tensor([[20, 21, 22, 23]]),)) diff --git a/tests/test_lora.py b/tests/test_lora.py index 88fe72c08c..0af64aa55d 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -102,7 +102,7 @@ def test_lora_mqa_gqa(): for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) - lora_ind = [0, 1, 6, 7, 12, 13, 18, 19, 4, 5, 10, 11, 16, 17, 22, 23] + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] assert attn.linear.weight.shape == (24, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (16, 2) @@ -144,7 +144,7 @@ def test_lora_mqa_gqa(): for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) - lora_ind = [0, 1, 2, 3, 8, 9, 10, 11, 6, 7, 14, 15] + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] assert attn.linear.weight.shape == (16, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (12, 2) diff --git a/tests/test_model.py b/tests/test_model.py index 0182dde9f1..c9d8214df2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -75,7 +75,7 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua state_dict = {} theirs_model = GPTNeoXForCausalLM(theirs_config).to(device) # load the hf initialization into our model - copy_weights_gpt_neox(state_dict, theirs_model.state_dict()) + copy_weights_gpt_neox(ours_config, state_dict, theirs_model.state_dict()) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -135,7 +135,7 @@ def test_against_hf_falcon(kwargs, device, dtype): theirs_model = FalconForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} - copy_weights_falcon(kwargs["name"], state_dict, theirs_state_dict) + copy_weights_falcon(ours_config, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -311,7 +311,7 @@ def test_against_hf_phi_1_5(device, dtype): theirs_model = PhiForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} - copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_phi({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -373,7 +373,7 @@ def test_against_hf_phi_2(device, dtype): theirs_model = PhiForCausalLM(theirs_config).to(device) theirs_state_dict = 
theirs_model.state_dict() state_dict = {} - copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_phi({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) From 0ed697fb46363af802b92899936bb69257bef264 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Tue, 5 Mar 2024 22:15:05 +0300 Subject: [PATCH 02/15] Update test for test_llama2_70b_conversion --- tests/test_convert_hf_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 76de12681a..1f49be5622 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -18,9 +18,9 @@ def test_llama2_70b_conversion(): "model.layers.0.mlp.up_proj.weight": (28672, 8192), "model.layers.0.post_attention_layernorm.weight": (8192,), "model.layers.0.self_attn.k_proj.weight": (1024, 8192), - "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.0.self_attn.q_proj.weight": (8192, 8192), "model.layers.0.self_attn.v_proj.weight": (1024, 8192), + "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.1.input_layernorm.weight": (8192,), "model.layers.1.mlp.down_proj.weight": (8192, 28672), "model.layers.1.mlp.gate_proj.weight": (28672, 8192), @@ -56,9 +56,9 @@ def test_llama2_70b_conversion(): weight_map = {k: torch.empty(s) for k, s in shapes.items()} copy_weights_hf_llama(config, qkv_weights, holder, weight_map) - # we are only testing 5 layers - assert len(qkv_weights) == 5 - # there are no loaded qkv weights + # NOTE: there are 5 layers, but only in the first layer we have `q`, `k` and `v` + assert len(qkv_weights) == 1 + # # there are no loaded qkv weights assert all(v is None for qkv in qkv_weights.values() for v in qkv) # the shapes are correct holder = {k: tuple(t.shape) for k, t in holder.items()} From b2864703024ce217f6ff1c432fc9aecd2f40a000 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 7 Mar 2024 17:41:17 +0300 Subject: [PATCH 03/15] Correct shapes for KV-cache --- lit_gpt/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lit_gpt/model.py b/lit_gpt/model.py index c0431913d2..20bcd43759 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -274,8 +274,7 @@ def build_kv_cache( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> "KVCache": - heads = 1 if self.config.n_query_groups == 1 else self.config.n_head - v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size) if rope_cache_length is None: if self.config.rotary_percentage != 1.0: raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") @@ -283,7 +282,7 @@ def build_kv_cache( else: k_shape = ( batch_size, - heads, + self.config.n_query_groups, max_seq_length, rope_cache_length + self.config.head_size - self.config.rope_n_elem, ) From 60f2c93aece506d6c23757e0aff81e33699388f5 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 7 Mar 2024 18:39:26 +0300 Subject: [PATCH 04/15] Always do .repeat_interleave for grouped queries in training mode. 
--- lit_gpt/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lit_gpt/model.py b/lit_gpt/model.py index 20bcd43759..1e2458fbf5 100644 --- a/lit_gpt/model.py +++ b/lit_gpt/model.py @@ -241,8 +241,9 @@ def forward( k, v = self.kv_cache(input_pos, k, v) # Grouped queries: balance the number of heads across all three matrices. + # NOTE: flash attention requires it in training mode. # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. - if self.config.n_query_groups != self.config.n_head and self.config.n_query_groups != 1: + if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): q_per_kv = self.config.n_head // self.config.n_query_groups k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) From 0f41eb486ca098a35946fc7b2edc9043f15d04a6 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 8 Mar 2024 20:51:04 +0300 Subject: [PATCH 05/15] Test_convert_hf_checkpoint: test all branches --- tests/test_convert_hf_checkpoint.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 1f49be5622..9838e09c63 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -101,16 +101,23 @@ def test_llama2_70b_conversion(): } -def test_convert_hf_checkpoint(tmp_path): +@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2")) +def test_convert_hf_checkpoint(tmp_path, model_name): + import torch + from scripts.convert_hf_checkpoint import convert_hf_checkpoint with pytest.raises(ValueError, match="to contain .bin"): - convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m") + convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) bin_file = tmp_path / "foo.bin" bin_file.touch() with mock.patch("scripts.convert_hf_checkpoint.lazy_load") as load: - convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m") + # bypass if-statement for weight tying + if model_name == "Llama-2-7b-hf": + load.return_value = {"model.embed_tokens.weight": torch.rand((10, 10))} + convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) + # convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) load.assert_called_with(bin_file) assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "lit_config.json", "lit_model.pth"} From e25fb2c26962dfc0b6c6d0d3146f63413f013c56 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 8 Mar 2024 20:56:24 +0300 Subject: [PATCH 06/15] test_convert_lit_checkpoint check all branches --- tests/test_convert_lit_checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index f0b0e77ca1..3d6da71c4e 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -14,11 +14,12 @@ wd = Path(__file__).parent.parent.absolute() -def test_convert_lit_checkpoint(tmp_path): +@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2")) +def test_convert_lit_checkpoint(tmp_path, model_name): from lit_gpt import GPT, Config from scripts.convert_lit_checkpoint import convert_lit_checkpoint - ours_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128) + 
ours_config = Config.from_name(model_name, block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "foo.ckpt" config_path = tmp_path / "foo.json" From 61c3265c54c8b7b314c77558b5bdfcfca8ec9728 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 8 Mar 2024 22:09:00 +0300 Subject: [PATCH 07/15] qkv_reassemble instead of qkv_split --- scripts/convert_hf_checkpoint.py | 8 +- scripts/convert_lit_checkpoint.py | 24 +++--- tests/test_convert_hf_checkpoint.py | 123 ++++++++++++++++++--------- tests/test_convert_lit_checkpoint.py | 77 ----------------- 4 files changed, 102 insertions(+), 130 deletions(-) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 51ab64f7e7..bf3e8046a9 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -58,7 +58,7 @@ def copy_weights_gpt_neox( param = load_param(param, from_name, dtype) if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] - param = torch.cat(qkv_split(param, config)) + param = reassemble_qkv(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -107,7 +107,7 @@ def copy_weights_falcon( param = load_param(param, from_name, dtype) if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] - param = torch.cat(qkv_split(param, config)) + param = reassemble_qkv(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -270,7 +270,7 @@ def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: return param -def qkv_split( +def reassemble_qkv( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q_per_kv = config.n_head // config.n_query_groups @@ -285,7 +285,7 @@ def qkv_split( q = torch.cat(qs) k = torch.cat(ks) v = torch.cat(vs) - return q, k, v + return torch.cat((q, k, v)) @torch.inference_mode() diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py index fca25d3f6d..08066b102c 100644 --- a/scripts/convert_lit_checkpoint.py +++ b/scripts/convert_lit_checkpoint.py @@ -60,9 +60,9 @@ def copy_weights_falcon( param = load_param(param, from_name, None) if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] - qs, ks, vs = qkv_split(param, config) - cycled = [t for group in zip(qs, ks, vs) for t in group] - param = torch.cat(cycled) + qs, ks, vs = qkv_split(param, config, split_into_heads=True) + interleaved = [t for group in zip(qs, ks, vs) for t in group] + param = torch.cat(interleaved) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -99,9 +99,9 @@ def copy_weights_gpt_neox( param = load_param(param, from_name, None) if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] 
- qs, ks, vs = qkv_split(param, config) - cycled = [t for group in zip(qs, ks, vs) for t in group] - param = torch.cat(cycled) + qs, ks, vs = qkv_split(param, config, split_into_heads=True) + interleaved = [t for group in zip(qs, ks, vs) for t in group] + param = torch.cat(interleaved) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -156,7 +156,7 @@ def copy_weights_llama( "model.layers.{}.self_attn.k_proj.weight".format(*ids), "model.layers.{}.self_attn.v_proj.weight".format(*ids), ) - params = [torch.cat(w) for w in qkv_split(param, config)] + params = qkv_split(param, config) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) @@ -198,7 +198,7 @@ def copy_weights_phi( f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx), f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx), ) - params = [torch.cat(w) for w in qkv_split(param, config)] + params = qkv_split(param, config) else: to_names = (weight_map[name_template].format(layer_idx),) params = (param,) @@ -210,8 +210,10 @@ def copy_weights_phi( def qkv_split( - param: Union[torch.Tensor, NotYetLoadedTensor], config: Config -) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]: + param: Union[torch.Tensor, NotYetLoadedTensor], + config: Config, + split_into_heads: bool = False, +) -> Union[torch.Tensor, Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]]: q, k, v = param.split( ( config.n_head * config.head_size, @@ -219,6 +221,8 @@ def qkv_split( config.n_query_groups * config.head_size, ) ) + if not split_into_heads: + return q, k, v qs = q.split(config.n_head // config.n_query_groups * config.head_size) ks = k.split(config.head_size) vs = v.split(config.head_size) diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 9838e09c63..776897284d 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -129,58 +129,103 @@ def test_convert_hf_checkpoint(tmp_path, model_name): assert isinstance(config, Config) -def test_qkv_split(): +def test_reassemble_qkv(): from lit_gpt import Config - from scripts.convert_hf_checkpoint import qkv_split + from scripts.convert_hf_checkpoint import reassemble_qkv # MHA config = Config(n_embd=4, n_head=4) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11], - [12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23], - [24, 25, 26, 27], - [28, 29, 30, 31], - [32, 33, 34, 35], - [36, 37, 38, 39], - [40, 41, 42, 43], - [44, 45, 46, 47], + [0, 1, 2, 3], # query + [16, 17, 18, 19], # key + [32, 33, 34, 35], # value + [4, 5, 6, 7], # query + [20, 21, 22, 23], # key + [36, 37, 38, 39], # value + [8, 9, 10, 11], # query + [24, 25, 26, 27], # key + [40, 41, 42, 43], # value + [12, 13, 14, 15], # query + [28, 29, 30, 31], # key + [44, 45, 46, 47], # value ] ) - q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [12, 13, 14, 15], [24, 25, 26, 27], [36, 37, 38, 39]])) - torch.testing.assert_close(k, torch.tensor([[4, 5, 6, 7], [16, 17, 18, 19], [28, 29, 30, 31], [40, 41, 42, 43]])) - torch.testing.assert_close(v, torch.tensor([[8, 9, 10, 11], [20, 21, 22, 23], [32, 33, 34, 35], [44, 45, 46, 47]])) + qkv = reassemble_qkv(qkv_interleaved, config) + torch.testing.assert_close( + qkv, + torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], 
# key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # key + [28, 29, 30, 31], # key + [32, 33, 34, 35], # value + [36, 37, 38, 39], # value + [40, 41, 42, 43], # value + [44, 45, 46, 47], # value + ] + ), + ) # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11], - [12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23], - [24, 25, 26, 27], - [28, 29, 30, 31], + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [16, 17, 18, 19], # key + [24, 25, 26, 27], # value + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [20, 21, 22, 23], # key + [28, 29, 30, 31], # value ] ) - q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [16, 17, 18, 19], [20, 21, 22, 23]])) - torch.testing.assert_close(k, torch.tensor([[8, 9, 10, 11], [24, 25, 26, 27]])) - torch.testing.assert_close(v, torch.tensor([[12, 13, 14, 15], [28, 29, 30, 31]])) + qkv = reassemble_qkv(qkv_interleaved, config) + torch.testing.assert_close( + qkv, + torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # value + [28, 29, 30, 31], # value + ] + ), + ) - # MQA + # # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) - qkv = torch.tensor( - [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + qkv_interleaved = torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # value + ], + ) + qkv = reassemble_qkv(qkv_interleaved, config) + torch.testing.assert_close( + qkv, + torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # value + ] + ), ) - q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]])) - torch.testing.assert_close(k, torch.tensor([[16, 17, 18, 19]])) - torch.testing.assert_close(v, torch.tensor([[20, 21, 22, 23]])) diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 3d6da71c4e..e213cddc9f 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -453,80 +453,3 @@ def test_check_conversion_supported_lora(): lit_weights = {"some.key.name": ANY, "error.key.lora": ANY} with pytest.raises(ValueError, match=r"LoRA.*cannot be converted"): check_conversion_supported(lit_weights=lit_weights) - - -def test_qkv_split(): - from torch import tensor - - from lit_gpt import Config - from scripts.convert_lit_checkpoint import qkv_split - - # MHA - config = Config(n_embd=4, n_head=4) - qkv = torch.tensor( - [ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11], - [12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23], - [24, 25, 26, 27], - [28, 29, 30, 31], - [32, 33, 34, 35], - [36, 37, 38, 39], - [40, 41, 42, 43], - [44, 45, 46, 47], - ] - ) - q, k, v = qkv_split(qkv, config) - torch.testing.assert_close( - q, (tensor([[0, 1, 2, 3]]), tensor([[4, 5, 6, 7]]), tensor([[8, 9, 10, 11]]), tensor([[12, 13, 14, 15]])) - ) - torch.testing.assert_close( - k, - ( - tensor([[16, 17, 18, 19]]), - tensor([[20, 21, 22, 23]]), - tensor([[24, 25, 26, 27]]), - tensor([[28, 29, 30, 31]]), - ), - ) - 
torch.testing.assert_close( - v, - ( - tensor([[32, 33, 34, 35]]), - tensor([[36, 37, 38, 39]]), - tensor([[40, 41, 42, 43]]), - tensor([[44, 45, 46, 47]]), - ), - ) - - # GQA - config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv = torch.tensor( - [ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11], - [12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23], - [24, 25, 26, 27], - [28, 29, 30, 31], - ] - ) - q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]), tensor([[8, 9, 10, 11], [12, 13, 14, 15]]))) - torch.testing.assert_close(k, (tensor([[16, 17, 18, 19]]), tensor([[20, 21, 22, 23]]))) - torch.testing.assert_close(v, (tensor([[24, 25, 26, 27]]), tensor([[28, 29, 30, 31]]))) - - # MQA - config = Config(n_embd=4, n_head=4, n_query_groups=1) - qkv = torch.tensor( - [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] - ) - q, k, v = qkv_split(qkv, config) - torch.testing.assert_close(q, (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]),)) - torch.testing.assert_close(k, (tensor([[16, 17, 18, 19]]),)) - torch.testing.assert_close(v, (tensor([[20, 21, 22, 23]]),)) From 0e8b18fc368ab3a6973171d3e0f578868c448164 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Fri, 8 Mar 2024 22:51:00 +0300 Subject: [PATCH 08/15] convert_lit: test for qkv_reassemble --- lit_gpt/lora.py | 11 +-- scripts/convert_hf_checkpoint.py | 6 +- scripts/convert_lit_checkpoint.py | 37 +++++----- tests/test_chat.py | 2 +- tests/test_convert_hf_checkpoint.py | 19 +++-- tests/test_convert_lit_checkpoint.py | 102 +++++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 35 deletions(-) diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index 0f679fb821..1065da43ac 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -345,8 +345,7 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) return torch.cat( - [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], - dim=1, # (B, C_output', T) + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T) ) # (B, C_output, T) def get_lora_AB(self) -> torch.Tensor: @@ -358,7 +357,9 @@ def get_lora_AB(self) -> torch.Tensor: lora = self.conv1d( self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + ).squeeze( + 0 + ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) return self.zero_pad(lora * self.scaling) # (256, 128) after zero_pad (384, 128) def merge(self) -> None: @@ -397,7 +398,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: after_B = self.conv1d( after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) return pretrained + lora diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index bf3e8046a9..4e4cb09eab 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -58,7 +58,7 @@ def copy_weights_gpt_neox( param = 
load_param(param, from_name, dtype) if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] - param = reassemble_qkv(param, config) + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -107,7 +107,7 @@ def copy_weights_falcon( param = load_param(param, from_name, dtype) if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] - param = reassemble_qkv(param, config) + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -270,7 +270,7 @@ def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: return param -def reassemble_qkv( +def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q_per_kv = config.n_head // config.n_query_groups diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py index 08066b102c..05ed46e81c 100644 --- a/scripts/convert_lit_checkpoint.py +++ b/scripts/convert_lit_checkpoint.py @@ -4,7 +4,7 @@ import sys from functools import partial from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor @@ -60,9 +60,7 @@ def copy_weights_falcon( param = load_param(param, from_name, None) if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] - qs, ks, vs = qkv_split(param, config, split_into_heads=True) - interleaved = [t for group in zip(qs, ks, vs) for t in group] - param = torch.cat(interleaved) + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -99,9 +97,7 @@ def copy_weights_gpt_neox( param = load_param(param, from_name, None) if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] 
- qs, ks, vs = qkv_split(param, config, split_into_heads=True) - interleaved = [t for group in zip(qs, ks, vs) for t in group] - param = torch.cat(interleaved) + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -156,7 +152,13 @@ def copy_weights_llama( "model.layers.{}.self_attn.k_proj.weight".format(*ids), "model.layers.{}.self_attn.v_proj.weight".format(*ids), ) - params = qkv_split(param, config) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: to_names = (weight_map[name_template].format(*ids),) params = (param,) @@ -198,7 +200,13 @@ def copy_weights_phi( f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx), f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx), ) - params = qkv_split(param, config) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: to_names = (weight_map[name_template].format(layer_idx),) params = (param,) @@ -209,11 +217,7 @@ def copy_weights_phi( state_dict[to_name] = param -def qkv_split( - param: Union[torch.Tensor, NotYetLoadedTensor], - config: Config, - split_into_heads: bool = False, -) -> Union[torch.Tensor, Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]]: +def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: q, k, v = param.split( ( config.n_head * config.head_size, @@ -221,12 +225,11 @@ def qkv_split( config.n_query_groups * config.head_size, ) ) - if not split_into_heads: - return q, k, v qs = q.split(config.n_head // config.n_query_groups * config.head_size) ks = k.split(config.head_size) vs = v.split(config.head_size) - return qs, ks, vs + interleaved = [t for group in zip(qs, ks, vs) for t in group] + return torch.cat(interleaved) def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: diff --git a/tests/test_chat.py b/tests/test_chat.py index a8fe346a49..df99db2576 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -112,7 +112,7 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te assert generate_mock.mock_calls == [ call(ANY, tensor_like, 128, temperature=2.0, top_k=2, stop_tokens=([tokenizer_mock.return_value.eos_id],)) ] - # # only the generated result is printed to stdout + # only the generated result is printed to stdout assert out.getvalue() == ">> Reply: foo bar baz\n" assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue() diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 776897284d..665343d389 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -17,8 +17,8 @@ def test_llama2_70b_conversion(): "model.layers.0.mlp.gate_proj.weight": (28672, 8192), "model.layers.0.mlp.up_proj.weight": (28672, 8192), "model.layers.0.post_attention_layernorm.weight": (8192,), - "model.layers.0.self_attn.k_proj.weight": (1024, 8192), "model.layers.0.self_attn.q_proj.weight": (8192, 8192), + "model.layers.0.self_attn.k_proj.weight": (1024, 8192), "model.layers.0.self_attn.v_proj.weight": (1024, 8192), "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.1.input_layernorm.weight": (8192,), @@ -58,7 +58,7 @@ def test_llama2_70b_conversion(): # NOTE: there are 5 layers, but only in the 
first layer we have `q`, `k` and `v` assert len(qkv_weights) == 1 - # # there are no loaded qkv weights + # there are no loaded qkv weights assert all(v is None for qkv in qkv_weights.values() for v in qkv) # the shapes are correct holder = {k: tuple(t.shape) for k, t in holder.items()} @@ -117,7 +117,6 @@ def test_convert_hf_checkpoint(tmp_path, model_name): if model_name == "Llama-2-7b-hf": load.return_value = {"model.embed_tokens.weight": torch.rand((10, 10))} convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) - # convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) load.assert_called_with(bin_file) assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "lit_config.json", "lit_model.pth"} @@ -129,9 +128,9 @@ def test_convert_hf_checkpoint(tmp_path, model_name): assert isinstance(config, Config) -def test_reassemble_qkv(): +def test_qkv_reassemble(): from lit_gpt import Config - from scripts.convert_hf_checkpoint import reassemble_qkv + from scripts.convert_hf_checkpoint import qkv_reassemble # MHA config = Config(n_embd=4, n_head=4) @@ -151,7 +150,7 @@ def test_reassemble_qkv(): [44, 45, 46, 47], # value ] ) - qkv = reassemble_qkv(qkv_interleaved, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( qkv, torch.tensor( @@ -186,7 +185,7 @@ def test_reassemble_qkv(): [28, 29, 30, 31], # value ] ) - qkv = reassemble_qkv(qkv_interleaved, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( qkv, torch.tensor( @@ -203,7 +202,7 @@ def test_reassemble_qkv(): ), ) - # # MQA + # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) qkv_interleaved = torch.tensor( [ @@ -213,9 +212,9 @@ def test_reassemble_qkv(): [12, 13, 14, 15], # query [16, 17, 18, 19], # key [20, 21, 22, 23], # value - ], + ] ) - qkv = reassemble_qkv(qkv_interleaved, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( qkv, torch.tensor( diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index e213cddc9f..beff3d51c3 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -453,3 +453,105 @@ def test_check_conversion_supported_lora(): lit_weights = {"some.key.name": ANY, "error.key.lora": ANY} with pytest.raises(ValueError, match=r"LoRA.*cannot be converted"): check_conversion_supported(lit_weights=lit_weights) + + +def test_qkv_reassemble(): + from lit_gpt import Config + from scripts.convert_lit_checkpoint import qkv_reassemble + + # MHA + config = Config(n_embd=4, n_head=4) + qkv_interleaved = torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # key + [28, 29, 30, 31], # key + [32, 33, 34, 35], # value + [36, 37, 38, 39], # value + [40, 41, 42, 43], # value + [44, 45, 46, 47], # value + ] + ) + qkv = qkv_reassemble(qkv_interleaved, config) + torch.testing.assert_close( + qkv, + torch.tensor( + [ + [0, 1, 2, 3], # query + [16, 17, 18, 19], # key + [32, 33, 34, 35], # value + [4, 5, 6, 7], # query + [20, 21, 22, 23], # key + [36, 37, 38, 39], # value + [8, 9, 10, 11], # query + [24, 25, 26, 27], # key + [40, 41, 42, 43], # value + [12, 13, 14, 15], # query + [28, 29, 30, 31], # key + [44, 45, 46, 47], # value + ] + ), + ) + + # GQA + config = Config(n_embd=4, n_head=4, n_query_groups=2) + qkv_interleaved = torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 
10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # value + [28, 29, 30, 31], # value + ] + ) + qkv = qkv_reassemble(qkv_interleaved, config) + torch.testing.assert_close( + qkv, + torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [16, 17, 18, 19], # key + [24, 25, 26, 27], # value + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [20, 21, 22, 23], # key + [28, 29, 30, 31], # value + ] + ), + ) + + # MQA + config = Config(n_embd=4, n_head=4, n_query_groups=1) + qkv_interleaved = torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # value + ] + ) + qkv = qkv_reassemble(qkv_interleaved, config) + torch.testing.assert_close( + qkv, + torch.tensor( + [ + [0, 1, 2, 3], # query + [4, 5, 6, 7], # query + [8, 9, 10, 11], # query + [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # value + ] + ), + ) From b636a58cce0b6c610c94c1abc8070dc70b5539b6 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Sun, 17 Nov 2024 22:22:44 +0300 Subject: [PATCH 09/15] Fix the test --- tests/test_adapter.py | 2 +- tests/test_adapter_v2.py | 2 +- tests/test_convert_lit_checkpoint.py | 3 +-- tests/test_lora.py | 7 ++++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index da422f6288..a17e140796 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -345,7 +345,7 @@ def test_against_original_gemma_2(model_name, device, dtype): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index aec205155d..ec43a77eaa 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -297,7 +297,7 @@ def test_against_original_gemma_2(model_name): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index e15edb97ff..2168454f11 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -19,8 +19,6 @@ from transformers.models.phi.modeling_phi import PhiForCausalLM from transformers.models.phi3.configuration_phi3 import Phi3Config from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM -from litgpt import Config -from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble from litgpt import GPT, Config from litgpt.scripts.convert_lit_checkpoint import ( @@ -31,6 +29,7 @@ copy_weights_gpt_neox, copy_weights_llama, copy_weights_phi, + qkv_reassemble, ) from tests.conftest import RunIf diff --git a/tests/test_lora.py b/tests/test_lora.py index 8c4d00bfdc..58f7398fb6 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -835,11 +835,12 @@ def test_lora_model_fsdp_init(): def test_zero_pad_cpu_and_mocked_mps(): - in_features = 128 - out_features = 
384 head_size = 64 n_head = 12 n_query_groups = 3 + in_features = 128 + kv_embed_dim = in_features // (n_head // n_query_groups) + out_features = in_features + 2 * kv_embed_dim enable_lora = [True, False, True] r = 4 @@ -855,7 +856,7 @@ def test_zero_pad_cpu_and_mocked_mps(): batch_size = 64 seq_len = 64 - embed_dim = 320 + embed_dim = 160 x = torch.randn(batch_size, seq_len, embed_dim) result_cpu = model.zero_pad(x) From 4371c068b237929528bce10c8f554076df758bb1 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Sun, 17 Nov 2024 23:31:44 +0300 Subject: [PATCH 10/15] Handle legacy checkpoints --- litgpt/adapter_v2.py | 9 ++- litgpt/lora.py | 10 +++- litgpt/model.py | 14 ++++- litgpt/scripts/convert_hf_checkpoint.py | 70 +++++++++++++----------- litgpt/scripts/convert_lit_checkpoint.py | 3 + tests/test_adapter.py | 26 ++++++++- tests/test_adapter_v2.py | 25 ++++++++- tests/test_lora.py | 38 ++++++++++++- tests/test_model.py | 26 ++++++++- 9 files changed, 179 insertions(+), 42 deletions(-) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index f5e6069343..bec5ca12b2 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -21,6 +21,7 @@ from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig from litgpt.model import KVCache +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -163,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) + self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) @@ -197,6 +198,12 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa # For compatibility with older checkpoints if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + + for attr in ("weight", "bias"): + key = f"{prefix}attn.linear.{attr}" + if key in state_dict: + state_dict[f"{prefix}qkv.linear.{attr}"] = qkv_reassemble(state_dict.pop(key), self.config) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/lora.py b/litgpt/lora.py index daccdfe61d..5d5ff95c2e 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -58,6 +58,7 @@ from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention from litgpt.model import KVCache +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -581,7 +582,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = LoRAQKVLinear( + self.qkv = LoRAQKVLinear( in_features=config.n_embd, out_features=shape, r=config.lora_r, @@ -620,7 +621,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa "proj.weight": 
"proj.linear.weight", "proj.bias": "proj.linear.bias", } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + + for attr in ("weight", "bias"): + key = f"{prefix}attn.linear.{attr}" + if key in state_dict: + state_dict[f"{prefix}qkv.linear.{attr}"] = qkv_reassemble(state_dict.pop(key), self.config) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/model.py b/litgpt/model.py index 401a372a5e..d25f7d8dda 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -7,13 +7,14 @@ """ import math -from typing import Any, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from typing_extensions import Self from litgpt.config import Config +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble class GPT(nn.Module): @@ -251,7 +252,7 @@ class CausalSelfAttention(nn.Module): def __init__(self, config: Config, block_idx: int) -> None: super().__init__() # key, query and value projections for all heads, but in a batch - self.attn = nn.Linear( + self.qkv = nn.Linear( config.n_embd, (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries bias=config.bias, @@ -403,6 +404,15 @@ def build_kv_cache( ) return KVCache(k_shape, v_shape, device=device, dtype=dtype) + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with legacy checkpoints.""" + for attr in ("weight", "bias"): + key = f"{prefix}attn.{attr}" + if key in state_dict: + state_dict[f"{prefix}qkv.{attr}"] = qkv_reassemble(state_dict.pop(key), self.config) + + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + class GptNeoxMLP(nn.Module): def __init__(self, config: Config) -> None: diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 7fb7c12a57..533e2f5ae9 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -2,17 +2,17 @@ import gc import json +import os import re from collections import defaultdict from functools import partial -import os from pathlib import Path from pprint import pprint from typing import Dict, List, Optional, Tuple, Union -from tqdm import tqdm import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor +from tqdm import tqdm from litgpt.config import Config from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config @@ -377,22 +377,43 @@ def copy_weights_phi( pbar.update(progress_per_file) -def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: +# def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: +# """Reassemble from a normal to an interleaved placement in a QKV matrix. +# [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] 
+# """ +# q, k, v = param.split( +# ( +# config.n_head * config.head_size, +# config.n_query_groups * config.head_size, +# config.n_query_groups * config.head_size, +# ) +# ) +# qs = q.split(config.n_head // config.n_query_groups * config.head_size) +# ks = k.split(config.head_size) +# vs = v.split(config.head_size) +# interleaved = [t for group in zip(qs, ks, vs) for t in group] +# return torch.cat(interleaved) + +def qkv_reassemble( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Reassemble from a normal to an interleaved placement in a QKV matrix. - [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] + [Q, K, V, Q, K, V, ...] --> [Q, Q, ..., K, K, ..., V, V, ...] """ - q, k, v = param.split( - ( - config.n_head * config.head_size, - config.n_query_groups * config.head_size, - config.n_query_groups * config.head_size, - ) - ) - qs = q.split(config.n_head // config.n_query_groups * config.head_size) - ks = k.split(config.head_size) - vs = v.split(config.head_size) - interleaved = [t for group in zip(qs, ks, vs) for t in group] - return torch.cat(interleaved) + q_per_kv = config.n_head // config.n_query_groups + qs = [] + ks = [] + vs = [] + for chunk in torch.chunk(param, config.n_query_groups): + split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) + qs.append(split[0]) + ks.append(split[1]) + vs.append(split[2]) + q = torch.cat(qs) + k = torch.cat(ks) + v = torch.cat(vs) + return torch.cat((q, k, v)) + def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]: @@ -418,23 +439,6 @@ def load_param( return param -def qkv_reassemble( - param: Union[torch.Tensor, NotYetLoadedTensor], config: Config -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q_per_kv = config.n_head // config.n_query_groups - qs = [] - ks = [] - vs = [] - for chunk in torch.chunk(param, config.n_query_groups): - split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) - qs.append(split[0]) - ks.append(split[1]) - vs.append(split[2]) - q = torch.cat(qs) - k = torch.cat(ks) - v = torch.cat(vs) - return torch.cat((q, k, v)) - @torch.inference_mode() def convert_hf_checkpoint( diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index 3499a32a9c..65cf1cd194 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -293,6 +293,9 @@ def copy_weights_phi( def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: + """Reassemble from a normal to an interleaved placement in a QKV matrix. + [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] + """ q, k, v = param.split( ( config.n_head * config.head_size, diff --git a/tests/test_adapter.py b/tests/test_adapter.py index a17e140796..0c8b098710 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from contextlib import redirect_stdout +from copy import deepcopy from dataclasses import asdict from io import StringIO from unittest import mock @@ -19,10 +20,11 @@ import litgpt.adapter as gpt_adapter import litgpt.finetune.adapter as module import litgpt.model as gpt -from litgpt.adapter import GPT, Config, adapter_filter +from litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -355,3 +357,25 @@ def test_against_original_gemma_2(model_name, device, dtype): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) + state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index ec43a77eaa..9fea6c1386 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from contextlib import redirect_stdout +from copy import deepcopy from io import StringIO from unittest import mock from unittest.mock import Mock @@ -19,11 +20,12 @@ import litgpt.config as config_module import litgpt.finetune.adapter_v2 as module from litgpt.adapter_v2 import GPT as AdapterV2GPT -from litgpt.adapter_v2 import Config, adapter_filter +from litgpt.adapter_v2 import CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -465,3 +467,24 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp logs = stdout.getvalue() assert "of trainable parameters: 552" in logs assert "of non-trainable parameters: 1,808" in logs + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) + state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_lora.py b/tests/test_lora.py index 58f7398fb6..a763b43e28 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from contextlib import redirect_stdout +from copy import deepcopy from io import StringIO from itertools import product from unittest import mock @@ -23,10 +24,19 @@ from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.lora import GPT as LoRAGPT +from litgpt.lora import ( + CausalSelfAttention, + Config, + LoRALinear, + LoRAQKVLinear, + lora_filter, + mark_only_lora_as_trainable, + merge_lora_weights, +) from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention -from litgpt.lora import Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -869,3 +879,29 @@ def test_zero_pad_cpu_and_mocked_mps(): assert result_cpu.shape == result_mps.shape, "Shape mismatch between CPU and MPS" assert torch.allclose(result_cpu, result_mps), "Tensor values mismatch between CPU and MPS" + + + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + lora_r=8, + lora_alpha=16, + lora_dropout=0.1 + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) + state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_model.py b/tests/test_model.py index 1ef7cfd5d1..9d696c9397 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,9 +2,9 @@ from copy import deepcopy from functools import partial +from unittest import mock import pytest -from unittest import mock import torch from lightning import Fabric from lightning.fabric.utilities.imports import _IS_WINDOWS @@ -30,8 +30,8 @@ from transformers.models.olmo import OlmoConfig, OlmoForCausalLM import litgpt.config as config_module -from litgpt.model import batched_index_copy_ from litgpt import GPT, Config +from litgpt.model import CausalSelfAttention, batched_index_copy_ from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_falcon, copy_weights_gemma_2, @@ -39,6 +39,7 @@ copy_weights_hf_llama, copy_weights_phi, ) +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -1055,3 +1056,24 @@ def test_batched_index_copy_modes(): val_3_mps = val_3 batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps) assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS" + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be 
as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) + state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) From 465a9f7094b1bba0a1510de29698188fcca977c2 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Mon, 18 Nov 2024 12:10:55 +0300 Subject: [PATCH 11/15] attn.attn --> attn.qkv --- litgpt/adapter.py | 2 +- litgpt/adapter_v2.py | 4 +-- litgpt/generate/tp.py | 13 ++++---- litgpt/lora.py | 4 +-- litgpt/model.py | 2 +- litgpt/scripts/convert_hf_checkpoint.py | 31 ++++------------- litgpt/scripts/convert_lit_checkpoint.py | 18 +++++----- tests/test_adapter.py | 8 ++--- tests/test_adapter_v2.py | 36 ++++++++++---------- tests/test_convert_hf_checkpoint.py | 2 +- tests/test_generate_sequentially.py | 34 +++++++++---------- tests/test_lora.py | 42 ++++++++++++------------ 12 files changed, 89 insertions(+), 107 deletions(-) diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 8523cec814..f8e4ac51e4 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -151,7 +151,7 @@ def scaled_dot_product_attention( ak, av = self.adapter_kv_cache else: prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) - aqkv = self.attn(prefix) + aqkv = self.qkv(prefix) q_per_kv = self.config.n_head // self.config.n_query_groups aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) aqkv = aqkv.permute(0, 2, 3, 1, 4) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index bec5ca12b2..513f5b1745 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -189,8 +189,8 @@ def __init__(self, config: Config, block_idx: int) -> None: def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { - "attn.weight": "attn.linear.weight", - "attn.bias": "attn.linear.bias", + "qkv.weight": "qkv.linear.weight", + "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index c76d4f27c9..7b45ffd014 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -3,31 +3,30 @@ import logging import sys import time +import warnings from functools import partial from pathlib import Path from pprint import pprint from typing import Literal, Optional, Union -import warnings import lightning as L -from lightning_utilities.core.imports import RequirementCache import torch import torch._dynamo.config import torch._inductor.config from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities import rank_zero_only +from lightning_utilities.core.imports import RequirementCache import litgpt.generate.base as generate_base -from litgpt.model import GPT from litgpt.config import Config -from litgpt.tokenizer import Tokenizer -from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE +from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style +from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_nvlink_connectivity, check_valid_checkpoint_dir, extend_checkpoint_dir, - 
get_default_supported_precision + get_default_supported_precision, ) @@ -71,7 +70,7 @@ def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMA def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None: - tensor_parallel_linear(fabric, attn.attn, "colwise") + tensor_parallel_linear(fabric, attn.qkv, "colwise") tensor_parallel_linear(fabric, attn.proj, "rowwise") attn.register_forward_hook(partial(all_reduce_output, fabric.world_size)) diff --git a/litgpt/lora.py b/litgpt/lora.py index 5d5ff95c2e..0718b01f67 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -616,8 +616,8 @@ def __init__(self, config: Config, block_idx: int) -> None: def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { - "attn.weight": "attn.linear.weight", - "attn.bias": "attn.linear.bias", + "qkv.weight": "qkv.linear.weight", + "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } diff --git a/litgpt/model.py b/litgpt/model.py index d25f7d8dda..ba8344999a 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -288,7 +288,7 @@ def forward( # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` # instead of individually multiplying the input `x` with the respective weight matrices. - qkv = self.attn(x) # (B, T, 3xC*) + qkv = self.qkv(x) # (B, T, 3xC*) # Define query, key and value sizes. # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`). diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 533e2f5ae9..b24f874680 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -32,8 +32,8 @@ def copy_weights_gpt_neox( "gpt_neox.embed_in.weight": "transformer.wte.weight", "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", - "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", - "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.qkv.bias", + "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, @@ -83,7 +83,7 @@ def copy_weights_falcon( ) -> None: weight_map = { "transformer.word_embeddings.weight": "transformer.wte.weight", - "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", @@ -209,7 +209,7 @@ def copy_weights_hf_llama( k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) - 
state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: @@ -277,7 +277,7 @@ def copy_weights_gemma_2( k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: @@ -325,7 +325,7 @@ def copy_weights_phi( if config.name.startswith("Phi-3"): weight_map.update( { - "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.attn.weight", + "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight", "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", @@ -370,30 +370,13 @@ def copy_weights_phi( k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) qkv = torch.cat((q, k, v)) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] if progress_per_file is not None: pbar.update(progress_per_file) -# def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: -# """Reassemble from a normal to an interleaved placement in a QKV matrix. -# [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] 
-# """ -# q, k, v = param.split( -# ( -# config.n_head * config.head_size, -# config.n_query_groups * config.head_size, -# config.n_query_groups * config.head_size, -# ) -# ) -# qs = q.split(config.n_head // config.n_query_groups * config.head_size) -# ks = k.split(config.head_size) -# vs = v.split(config.head_size) -# interleaved = [t for group in zip(qs, ks, vs) for t in group] -# return torch.cat(interleaved) - def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index 65cf1cd194..bab1ab57d2 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -23,7 +23,7 @@ def copy_weights_falcon( ) -> None: weight_map = { "transformer.wte.weight": "transformer.word_embeddings.weight", - "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", + "transformer.h.{}.attn.qkv.weight": "transformer.h.{}.self_attention.query_key_value.weight", "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", @@ -55,7 +55,7 @@ def copy_weights_falcon( name_template, layer_idx = layer_template(from_name) to_name = weight_map[name_template].format(layer_idx) param = load_param(param, from_name, None) - if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] param = qkv_reassemble(param, config) if saver is not None: @@ -73,8 +73,8 @@ def copy_weights_gpt_neox( "transformer.wte.weight": "gpt_neox.embed_in.weight", "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", - "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", - "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", + "transformer.h.{}.attn.qkv.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", + "transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", @@ -92,7 +92,7 @@ def copy_weights_gpt_neox( name_template, layer_idx = layer_template(from_name) to_name = weight_map[name_template].format(layer_idx) param = load_param(param, from_name, None) - if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] 
param = qkv_reassemble(param, config) if saver is not None: @@ -143,7 +143,7 @@ def copy_weights_llama( continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) - if from_name.endswith(".attn.attn.weight"): + if from_name.endswith(".attn.qkv.weight"): to_names = ( "model.layers.{}.self_attn.q_proj.weight".format(*ids), "model.layers.{}.self_attn.k_proj.weight".format(*ids), @@ -192,7 +192,7 @@ def copy_weights_gemma_2( continue name_template, *ids = layer_template(from_name, num_matches=2) param = load_param(param, from_name, None) - if from_name.endswith(".attn.attn.weight"): + if from_name.endswith(".attn.qkv.weight"): to_names = ( "model.layers.{}.self_attn.q_proj.weight".format(*ids), "model.layers.{}.self_attn.k_proj.weight".format(*ids), @@ -239,7 +239,7 @@ def copy_weights_phi( if config.name.startswith("Phi-3"): weight_map.update( { - "transformer.h.{}.attn.attn.weight": "model.layers.{}.self_attn.qkv_proj.weight", + "transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", @@ -251,7 +251,7 @@ def copy_weights_phi( for from_name, param in lit_weights.items(): name_template, layer_idx = layer_template(from_name) param = load_param(param, from_name, None) - if from_name.endswith((".attn.attn.weight", ".attn.attn.bias")): + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): if config.name.startswith("Phi-3"): to_names = (weight_map[name_template].format(layer_idx),) params = (param,) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 0c8b098710..9deb7be1f7 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -194,7 +194,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca "transformer.h.0.norm_1.weight", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.gating_factor", - "transformer.h.0.attn.attn.bias", + "transformer.h.0.attn.qkv.bias", "transformer.h.0.attn.proj.bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.h.0.norm_2.weight", @@ -204,7 +204,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca "transformer.h.1.norm_1.weight", "transformer.h.1.norm_1.bias", "transformer.h.1.attn.gating_factor", - "transformer.h.1.attn.attn.bias", + "transformer.h.1.attn.qkv.bias", "transformer.h.1.attn.proj.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.h.1.norm_2.weight", @@ -216,11 +216,11 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca }, "torch.uint8": { "lm_head.weight", - "transformer.h.0.attn.attn.weight", + "transformer.h.0.attn.qkv.weight", "transformer.h.0.attn.proj.weight", "transformer.h.0.mlp.fc.weight", "transformer.h.0.mlp.proj.weight", - "transformer.h.1.attn.attn.weight", + "transformer.h.1.attn.qkv.weight", "transformer.h.1.attn.proj.weight", "transformer.h.1.mlp.fc.weight", "transformer.h.1.mlp.proj.weight", diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 9fea6c1386..ca00a5d641 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -35,10 +35,10 @@ def test_config_identical(): base_model = BaseGPT.from_name(name) adapter_model = AdapterV2GPT.from_name(name) - assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_bias") - assert not 
hasattr(base_model.transformer.h[2].attn.attn, "adapter_scale") - assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_bias") - assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_scale") + assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_bias") + assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_scale") + assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_bias") + assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_scale") def test_adapter_v2_filter(tmp_path): @@ -58,8 +58,8 @@ def test_adapter_v2_filter(tmp_path): } for layer in range(3): for param in ( - "attn.attn.adapter_bias", - "attn.attn.adapter_scale", + "attn.qkv.adapter_bias", + "attn.qkv.adapter_scale", "attn.proj.adapter_bias", "attn.proj.adapter_scale", "mlp.fc.adapter_bias", @@ -366,27 +366,27 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "torch.uint8": { "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", - "transformer.h.1.attn.attn.linear.weight", + "transformer.h.1.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", - "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.qkv.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { - "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.1.mlp.proj.adapter_bias", - "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.norm_1.bias", - "transformer.h.0.attn.attn.linear.bias", + "transformer.h.0.attn.qkv.linear.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.mlp.proj.linear.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.h.0.attn.proj.adapter_scale", - "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.1.norm_2.bias", "transformer.h.1.attn.proj.adapter_scale", "transformer.h.0.norm_2.bias", @@ -408,9 +408,9 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "lm_head.adapter_bias", "transformer.h.1.norm_2.weight", "transformer.h.0.attn.adapter_wte.weight", - "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.attn.qkv.adapter_scale", "transformer.h.1.mlp.fc.adapter_scale", - "transformer.h.1.attn.attn.linear.bias", + "transformer.h.1.attn.qkv.linear.bias", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", @@ -437,20 +437,20 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "transformer.ln_f.bias", "lm_head.adapter_scale", "transformer.h.1.norm_2.weight", - "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.0.mlp.proj.adapter_bias", "transformer.h.0.attn.gating_factor", "transformer.h.1.norm_1.bias", "transformer.h.1.mlp.fc.adapter_bias", "transformer.h.1.mlp.proj.adapter_scale", "transformer.h.0.mlp.fc.adapter_scale", - "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.0.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.0.norm_1.weight", "transformer.h.0.attn.proj.adapter_scale", "transformer.h.1.mlp.proj.adapter_bias", - "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.qkv.adapter_bias", 
"transformer.h.0.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.1.attn.gating_factor", @@ -460,7 +460,7 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "transformer.h.0.norm_1.bias", "transformer.h.0.norm_2.bias", "transformer.h.1.norm_1.weight", - "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.attn.qkv.adapter_scale", } } diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 98499c5a6a..38b41f711d 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -63,7 +63,7 @@ def test_llama2_70b_conversion(): # the shapes are correct holder = {k: tuple(t.shape) for k, t in holder.items()} assert holder == { - "transformer.h.0.attn.attn.weight": (10240, 8192), + "transformer.h.0.attn.qkv.weight": (10240, 8192), "transformer.h.0.attn.proj.weight": (8192, 8192), "transformer.h.0.mlp.fc_1.weight": (28672, 8192), "transformer.h.0.mlp.fc_2.weight": (28672, 8192), diff --git a/tests/test_generate_sequentially.py b/tests/test_generate_sequentially.py index 51bc9d2fe1..2d7603eb60 100644 --- a/tests/test_generate_sequentially.py +++ b/tests/test_generate_sequentially.py @@ -12,13 +12,13 @@ import pytest import torch import yaml -from tests.conftest import RunIf from lightning import Fabric from litgpt import Config from litgpt.generate.sequentially import layer_to_device, replace_device, sequential from litgpt.model import GPT, Block from litgpt.scripts.download import download_from_hub +from tests.conftest import RunIf @pytest.mark.parametrize( @@ -117,8 +117,8 @@ def _test_model_1device(accelerator): "cos": device_str, "sin": device_str, "lm_head.weight": device_str, - "transformer.h.0.attn.attn.bias": device_str, - "transformer.h.0.attn.attn.weight": device_str, + "transformer.h.0.attn.qkv.bias": device_str, + "transformer.h.0.attn.qkv.weight": device_str, "transformer.h.0.attn.proj.bias": device_str, "transformer.h.0.attn.proj.weight": device_str, "transformer.h.0.mlp.fc.bias": device_str, @@ -131,8 +131,8 @@ def _test_model_1device(accelerator): "transformer.h.0.norm_2.weight": device_str, "transformer.h.0.attn.kv_cache.k": device_str, "transformer.h.0.attn.kv_cache.v": device_str, - "transformer.h.1.attn.attn.bias": device_str, - "transformer.h.1.attn.attn.weight": device_str, + "transformer.h.1.attn.qkv.bias": device_str, + "transformer.h.1.attn.qkv.weight": device_str, "transformer.h.1.attn.proj.bias": device_str, "transformer.h.1.attn.proj.weight": device_str, "transformer.h.1.mlp.fc.bias": device_str, @@ -187,8 +187,8 @@ def test_model_forward_hooks(): "transformer.wte.weight": "cuda:0", "transformer.h.0.norm_1.weight": "cuda:0", "transformer.h.0.norm_1.bias": "cuda:0", - "transformer.h.0.attn.attn.weight": "cuda:0", - "transformer.h.0.attn.attn.bias": "cuda:0", + "transformer.h.0.attn.qkv.weight": "cuda:0", + "transformer.h.0.attn.qkv.bias": "cuda:0", "transformer.h.0.attn.proj.weight": "cuda:0", "transformer.h.0.attn.proj.bias": "cuda:0", "transformer.h.0.norm_2.weight": "cuda:0", @@ -199,8 +199,8 @@ def test_model_forward_hooks(): "transformer.h.0.mlp.proj.bias": "cuda:0", "transformer.h.1.norm_1.weight": "cuda:0", "transformer.h.1.norm_1.bias": "cuda:0", - "transformer.h.1.attn.attn.weight": "cuda:0", - "transformer.h.1.attn.attn.bias": "cuda:0", + "transformer.h.1.attn.qkv.weight": "cuda:0", + "transformer.h.1.attn.qkv.bias": "cuda:0", "transformer.h.1.attn.proj.weight": "cuda:0", "transformer.h.1.attn.proj.bias": "cuda:0", 
"transformer.h.1.norm_2.weight": "cuda:0", @@ -211,8 +211,8 @@ def test_model_forward_hooks(): "transformer.h.1.mlp.proj.bias": "cuda:0", "transformer.h.2.norm_1.weight": "cuda:0", "transformer.h.2.norm_1.bias": "cuda:0", - "transformer.h.2.attn.attn.weight": "cuda:0", - "transformer.h.2.attn.attn.bias": "cuda:0", + "transformer.h.2.attn.qkv.weight": "cuda:0", + "transformer.h.2.attn.qkv.bias": "cuda:0", "transformer.h.2.attn.proj.weight": "cuda:0", "transformer.h.2.attn.proj.bias": "cuda:0", "transformer.h.2.norm_2.weight": "cuda:0", @@ -223,8 +223,8 @@ def test_model_forward_hooks(): "transformer.h.2.mlp.proj.bias": "cuda:0", "transformer.h.3.norm_1.weight": "cuda:1", "transformer.h.3.norm_1.bias": "cuda:1", - "transformer.h.3.attn.attn.weight": "cuda:1", - "transformer.h.3.attn.attn.bias": "cuda:1", + "transformer.h.3.attn.qkv.weight": "cuda:1", + "transformer.h.3.attn.qkv.bias": "cuda:1", "transformer.h.3.attn.proj.weight": "cuda:1", "transformer.h.3.attn.proj.bias": "cuda:1", "transformer.h.3.norm_2.weight": "cuda:1", @@ -235,8 +235,8 @@ def test_model_forward_hooks(): "transformer.h.3.mlp.proj.bias": "cuda:1", "transformer.h.4.norm_1.weight": "cuda:1", "transformer.h.4.norm_1.bias": "cuda:1", - "transformer.h.4.attn.attn.weight": "cuda:1", - "transformer.h.4.attn.attn.bias": "cuda:1", + "transformer.h.4.attn.qkv.weight": "cuda:1", + "transformer.h.4.attn.qkv.bias": "cuda:1", "transformer.h.4.attn.proj.weight": "cuda:1", "transformer.h.4.attn.proj.bias": "cuda:1", "transformer.h.4.norm_2.weight": "cuda:1", @@ -247,8 +247,8 @@ def test_model_forward_hooks(): "transformer.h.4.mlp.proj.bias": "cuda:1", "transformer.h.5.norm_1.weight": "cuda:1", "transformer.h.5.norm_1.bias": "cuda:1", - "transformer.h.5.attn.attn.weight": "cuda:1", - "transformer.h.5.attn.attn.bias": "cuda:1", + "transformer.h.5.attn.qkv.weight": "cuda:1", + "transformer.h.5.attn.qkv.bias": "cuda:1", "transformer.h.5.attn.proj.weight": "cuda:1", "transformer.h.5.attn.proj.bias": "cuda:1", "transformer.h.5.norm_2.weight": "cuda:1", diff --git a/tests/test_lora.py b/tests/test_lora.py index a763b43e28..0db9ea5285 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -110,7 +110,7 @@ def test_lora_mqa_gqa(): ) assert config.n_query_groups == config.n_head model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) @@ -131,7 +131,7 @@ def test_lora_mqa_gqa(): # MQA config.n_query_groups = 1 model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) @@ -152,7 +152,7 @@ def test_lora_mqa_gqa(): # GQA config.n_query_groups = 2 model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) @@ -179,12 +179,12 @@ def test_lora_filter(tmp_path): saved = torch.load(save_path)["model"] expected = { - "transformer.h.1.attn.attn.lora_B", - "transformer.h.2.attn.attn.lora_B", - "transformer.h.2.attn.attn.lora_A", - "transformer.h.1.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_B", + "transformer.h.2.attn.qkv.lora_B", + "transformer.h.2.attn.qkv.lora_A", + "transformer.h.1.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_A", + 
"transformer.h.0.attn.qkv.lora_B", } assert set(saved) == expected @@ -750,29 +750,29 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.uint8": { - "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", - "transformer.h.1.attn.attn.linear.weight", + "transformer.h.1.attn.qkv.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { - "transformer.h.0.attn.attn.lora_B", + "transformer.h.0.attn.qkv.lora_B", "transformer.h.0.norm_2.weight", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.ln_f.bias", - "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_B", "transformer.h.1.attn.proj.linear.bias", "transformer.h.1.norm_1.weight", - "transformer.h.1.attn.attn.linear.bias", - "transformer.h.1.attn.attn.lora_A", + "transformer.h.1.attn.qkv.linear.bias", + "transformer.h.1.attn.qkv.lora_A", "transformer.h.1.norm_1.bias", "transformer.h.1.norm_2.bias", "transformer.h.0.attn.proj.linear.bias", @@ -781,11 +781,11 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.norm_2.bias", "transformer.ln_f.weight", - "transformer.h.0.attn.attn.lora_A", + "transformer.h.0.attn.qkv.lora_A", "transformer.h.1.norm_2.weight", "transformer.h.1.mlp.proj.linear.bias", "transformer.h.0.norm_1.weight", - "transformer.h.0.attn.attn.linear.bias", + "transformer.h.0.attn.qkv.linear.bias", }, } @@ -797,10 +797,10 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { - "transformer.h.1.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_B", - "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_B", + "transformer.h.1.attn.qkv.lora_B", } } From 5a48af15452bfaa552e2b7cae6731354be003b43 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Mon, 18 Nov 2024 12:13:11 +0300 Subject: [PATCH 12/15] Remove accidentally added files --- .pre-commit-config.yaml | 36 ------------------------- .ruff.toml | 58 ----------------------------------------- 2 files changed, 94 deletions(-) delete mode 100644 .pre-commit-config.yaml delete mode 100644 .ruff.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 2f2a273e3e..0000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,36 +0,0 @@ -repos: - # default hooks - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 - hooks: - - id: end-of-file-fixer - - id: trailing-whitespace - - id: check-yaml - - # Ruff - - repo: local - hooks: - - id: ruff-CLI - name: ruff-CLI - entry: bash -c "ruff check * --select E,W,F,S --extend-select C4,SIM,RET,PT,I001,ANN001,ANN201,ANN205,ANN206,ARG001 --ignore E501,E731,S108,S101,S113,S603,PT007,S310,E402,PT004,C408 --per-file-ignores "*/**/__init__.py":I001,"*/**/__init__.py":F401,"tests/*":ANN,"*/**/unsloth/kernels/*":ALL --line-length 120 --target-version py38 --fix-only" - 
language: system - pass_filenames: false - - # Black - - repo: local - hooks: - - id: black-CLI - name: black-CLI - entry: bash -c "black -l120 -C --preview --target-version py38 config_hub eval extensions litgpt tests --exclude unsloth/kernels" - language: system - pass_filenames: false - - # Markdown - - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.39.0 - hooks: - - id: markdownlint-fix - args: - - --disable - - MD013 # line-length - - MD033 # no-inline-html diff --git a/.ruff.toml b/.ruff.toml deleted file mode 100644 index 7a4a65f2d3..0000000000 --- a/.ruff.toml +++ /dev/null @@ -1,58 +0,0 @@ -# https://beta.ruff.rs/docs/rules/ - -target-version = "py38" - -lint.select = [ - "ANN", # flake8-annotations - "ARG001", # Unused function argument - "C4", # flake8-comprehensions - "E", # pycodestyle Error - "F", # PyFlakes - "I001", # isort: unsorted-imports - "PT", # flake8-pytest-style - "RET", # flake8-return - "S", # flake8-bandit - "SIM", # flake8-simplify - "W", # pycodestyle Warning -] - -lint.ignore = [ - "ANN101", # Missing type annotation for self in method - "ANN102", # Missing type annotation for cls - "ANN202", # Missing return type annotation for private function - "ANN204", # Missing return type annotation for special method `__init__` - "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `value` - "C408", # Unnecessary {obj_type} call (rewrite as a literal) - "E402", # Module level import not at top of file - "E501", # Line too long ({width} > {limit} characters) - "E731", # Do not assign a lambda expression, use a def - "PT004", # Fixture {function} does not return anything, add leading underscore - "PT007", # Wrong values type in @pytest.mark.parametrize expected {values} of {row} - "S101", # assert detected - "S108", # Probable insecure usage of temporary file or directory: "{}" - "S113", # Probable use of requests call without timeout - "S310", # Audit URL open for permitted schemes. Allowing use of file: or custom schemes is often unexpected. 
- "S603", # subprocess call: check for execution of untrusted input -] - -line-length = 120 - -[lint.per-file-ignores] -"*/**/unsloth/kernels/*" = [ - "ALL", -] -"__init__.py" = [ - "I001", # isort: unsorted-imports - "F401", # imported but unused -] -"conftest.py" = [ - "ANN001", # Missing type annotation for self in method - "ANN201", # Missing return type annotation for public function {name} - "D102", # Missing docstring in public method - "S101", # assert detected -] -"test_*.py" = [ - "ANN", # Type annotations - "D102", # Missing docstring in public method - "S101", # assert detected -] From 311c2c579ec134261985e5653ae82bf8e92dc342 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Tue, 19 Nov 2024 13:38:58 +0300 Subject: [PATCH 13/15] Cleaner version of load_state_dict for legacy checkpoints --- litgpt/adapter_v2.py | 9 +++++---- litgpt/lora.py | 10 +++++----- litgpt/model.py | 8 +++++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 513f5b1745..3ce60c9471 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -187,7 +187,7 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: - """For compatibility with base checkpoints.""" + """For compatibility with base and/or legacy checkpoints.""" mapping = { "qkv.weight": "qkv.linear.weight", "qkv.bias": "qkv.linear.bias", @@ -200,9 +200,10 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa state_dict[key] = state_dict[key].permute(0, 2, 1, 3) for attr in ("weight", "bias"): - key = f"{prefix}attn.linear.{attr}" - if key in state_dict: - state_dict[f"{prefix}qkv.linear.{attr}"] = qkv_reassemble(state_dict.pop(key), self.config) + legacy_key = f"{prefix}attn.linear.{attr}" + current_key = f"{prefix}qkv.linear.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/lora.py b/litgpt/lora.py index 0718b01f67..e519d5445d 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -614,20 +614,20 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: - """For compatibility with base checkpoints.""" + """For compatibility with base and/or legacy checkpoints.""" mapping = { "qkv.weight": "qkv.linear.weight", "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } - state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) for attr in ("weight", "bias"): - key = f"{prefix}attn.linear.{attr}" - if key in state_dict: - state_dict[f"{prefix}qkv.linear.{attr}"] = qkv_reassemble(state_dict.pop(key), self.config) + legacy_key = f"{prefix}attn.linear.{attr}" + current_key = f"{prefix}qkv.linear.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/model.py b/litgpt/model.py index ba8344999a..54b34cd478 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -406,10 +406,12 @@ def build_kv_cache( def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with legacy checkpoints.""" + 
for attr in ("weight", "bias"): - key = f"{prefix}attn.{attr}" - if key in state_dict: - state_dict[f"{prefix}qkv.{attr}"] = qkv_reassemble(state_dict.pop(key), self.config) + legacy_key = f"{prefix}attn.{attr}" + current_key = f"{prefix}qkv.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) From 5520aef2d37748082dcbaa4b32c8324be1288b7d Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 26 Dec 2024 18:55:01 +0300 Subject: [PATCH 14/15] Add note that SDPA is disabled for non None mask or softcapping --- litgpt/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litgpt/model.py b/litgpt/model.py index 54b34cd478..f3c426192d 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -349,6 +349,7 @@ def forward( mask += sliding_window_bias # Efficient attention using Flash Attention CUDA kernels. + # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) y = self.scaled_dot_product_attention(q, k, v, mask) From b7d82aa9791c933f9a12c596f6e7966b7d53e67b Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov Date: Thu, 26 Dec 2024 19:39:05 +0300 Subject: [PATCH 15/15] Align the code with non-interleaved placement of QKV --- litgpt/adapter_v2.py | 2 +- litgpt/model.py | 5 +-- litgpt/scripts/convert_hf_checkpoint.py | 40 +++++++++------------ litgpt/scripts/convert_lit_checkpoint.py | 45 ++++++++++++------------ 4 files changed, 44 insertions(+), 48 deletions(-) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 6885f628aa..9b975260f0 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -164,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) + self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) diff --git a/litgpt/model.py b/litgpt/model.py index 5fbd0a8c24..cbdf2a4bdd 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -347,8 +347,9 @@ def forward( # NOTE: flash attention requires it in training mode. # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. 
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - k = k.expand(*q.shape) # (B, nh_q, T, hs) - v = v.expand(*q.shape) # (B, nh_q, T, hs) + q_per_kv = self.config.n_head // self.config.n_query_groups + k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) if self.apply_sliding_window_attention: """ diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 2c0dbb6aad..fbcfa871a6 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -408,20 +408,17 @@ def copy_weights_qwen_2_5( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, defaultdict(dict)) - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -436,22 +433,19 @@ def copy_weights_qwen_2_5( for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) + def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -483,7 +477,7 @@ def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]: def load_param( - param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose=False + param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool =False ) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` diff --git 
a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index 5bb08ea4f6..f276e3ae31 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -310,34 +310,35 @@ def copy_weights_qwen_2_5( "lm_head.weight": "lm_head.weight", } - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith((".attn.attn.weight", ".attn.attn.bias")): - from_name, l_idx = layer_template(name, 2) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - - weight_type = name.split(".")[-1] # weight or bias - q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}" - k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}" - v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}" - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + weight_type = from_name.split(".")[-1] # weight or bias + to_names = ( + "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, l_idx = layer_template(name, 2) - to_name = weight_map[from_name] - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param + def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
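
The excerpt ends inside the docstring of the convert-side qkv_reassemble, so its body is not shown here. As an illustration only, the transformation that docstring describes (grouped [Q, Q, ..., K, K, ..., V, V, ...] rows turned back into the legacy interleaved order) can be sketched as below. It mirrors the per-group splitting logic removed from copy_weights_qwen_2_5 earlier in this patch; the function name and the plain n_head/n_query_groups/head_size arguments are placeholders for this sketch, whereas the real helper takes the parameter and a Config, as its signature above shows.

import torch


def interleave_qkv_sketch(param: torch.Tensor, n_head: int, n_query_groups: int, head_size: int) -> torch.Tensor:
    # Split the fused, non-interleaved parameter into its Q, K and V row blocks.
    q, k, v = param.split((n_head * head_size, n_query_groups * head_size, n_query_groups * head_size))
    q_per_kv = n_head // n_query_groups
    # One chunk per query group: q_per_kv query heads, one key head, one value head.
    qs = q.split(q_per_kv * head_size)
    ks = k.split(head_size)
    vs = v.split(head_size)
    # Interleave the per-group chunks and fuse them back together along the output dimension.
    return torch.cat([chunk for group in zip(qs, ks, vs) for chunk in group])

For example, with n_head=4, n_query_groups=2 and head_size=1, rows [q0, q1, q2, q3, k0, k1, v0, v1] come out as [q0, q1, k0, v0, q2, q3, k1, v1], which is the interleaved order that legacy checkpoints and the pre-patch conversion code expect.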