single req correct (support gqa)
Ying1123 committed Sep 11, 2024
1 parent 054e7af commit ff5f51d
Showing 13 changed files with 597 additions and 173 deletions.
146 changes: 78 additions & 68 deletions python/sglang/srt/lora/lora.py
@@ -50,6 +50,7 @@ def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
         self.segment_gemm = segment_gemm
         self.lora_rank = lora_rank
         self.scaling = scaling
+        self.set_lora = False

     def forward(self, x: torch.Tensor):
         return self.base_layer.forward(x)
@@ -83,7 +84,7 @@ def forward(self, input_: torch.Tensor):
             self.base_layer, input_, bias
         )

-        if hasattr(self, "A_buffer"):
+        if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_)

         if self.base_layer.gather_output:
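
Note on the guard change above: an explicit `set_lora` flag replaces the `hasattr(self, "A_buffer")` probe, which presumably makes enabling LoRA for a batch a reversible state flip rather than an attribute-existence check. A minimal sketch of the pattern with a toy stand-in class (class, names, and shapes are hypothetical, not this repo's API):

import torch

class ToyLoRALinear:
    """Stand-in illustrating the explicit set_lora flag."""
    def __init__(self, weight: torch.Tensor):
        self.weight = weight
        self.set_lora = False                     # off by default

    def set_lora_info(self, A: torch.Tensor, B: torch.Tensor):
        self.set_lora = True                      # enabled for this batch
        self.A, self.B = A, B

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x @ self.weight.t()
        if self.set_lora:                         # was: hasattr(self, "A_buffer")
            out = out + (x @ self.A.t()) @ self.B.t()
        return out

layer = ToyLoRALinear(torch.randn(16, 32))
x = torch.randn(4, 32)
y_base = layer.forward(x)                         # flag off: base path only
layer.set_lora_info(torch.randn(8, 32), torch.randn(16, 8))
y_lora = layer.forward(x)                         # flag on: LoRA delta added
layer.set_lora = False                            # reversible without delattr
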
@@ -101,6 +102,7 @@ def __init__(
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

     def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+        self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
@@ -121,54 +123,70 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
         lora_output = torch.empty_like(base_output)
         output_dim = lora_output.shape[-1] // 2
         for i in range(2):
-            left = self.lora_rank * i
-            right = left + self.lora_rank
-            lora_output[:, output_dim * i : output_dim * (i + 1)] = (
-                self.segment_gemm.run(
-                    x=lora_a_output[:, left:right].contiguous(),
-                    weights=self.B_buffer[:, :, left:right].contiguous(),
-                    batch_size=self.bs,
-                    weight_column_major=True,
-                    seg_lens=self.seq_lens,
-                    weight_indices=self.weight_indices,
-                )
+            left = output_dim * i
+            right = left + output_dim
+            lora_output[:, left:right] = self.segment_gemm.run(
+                x=lora_a_output[
+                    :, self.lora_rank * i : self.lora_rank * (i + 1)
+                ].contiguous(),
+                weights=self.B_buffer[:, left:right, :].contiguous(),
+                batch_size=self.bs,
+                weight_column_major=True,
+                seg_lens=self.seq_lens,
+                weight_indices=self.weight_indices,
+            )
         return base_output + lora_output * self.scaling
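
The slicing change above tracks a layout change for `B_buffer`: the stacked gate/up projections now sit along the output axis, so the per-projection weight is the output-sized slice `B_buffer[:, left:right, :]` rather than a `lora_rank`-sized slice addressed through the last axis. A shape-only sketch of the new layout, with a plain matmul standing in for `segment_gemm.run` (all sizes hypothetical):

import torch

# Hypothetical sizes: 4 adapters, rank 8, per-projection output width 32.
num_loras, rank, output_dim, tokens = 4, 8, 32, 10
# B_buffer stacks gate and up along the OUTPUT axis: [loras, 2*out, rank]
B_buffer = torch.randn(num_loras, 2 * output_dim, rank)
# lora_a_output holds two rank-r slices side by side: [gate | up]
lora_a_output = torch.randn(tokens, 2 * rank)

lora_output = torch.empty(tokens, 2 * output_dim)
for i in range(2):
    left, right = output_dim * i, output_dim * (i + 1)
    a_i = lora_a_output[:, rank * i : rank * (i + 1)]  # i-th rank slice
    b_i = B_buffer[0, left:right, :]                   # i-th output slice
    lora_output[:, left:right] = a_i @ b_i.t()         # column-major weights

With this layout, rank-sized bounds on the last axis (the old addressing) would no longer line up, which is why both the slice widths and the sliced axis change in this hunk.
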


-class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
+class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
     ) -> None:
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
+    def set_lora_info(
+        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+    ):
+        self.set_lora = True
+        self.A_buffer_qkv = A_buffer_qkv
+        self.B_buffer_q = B_buffer_q
+        self.B_buffer_kv = B_buffer_kv
         self.bs = bs
         self.seq_lens = seq_lens
         self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_a_output = self.segment_gemm.run(
             x=x,
-            weights=self.A_buffer,
+            weights=self.A_buffer_qkv,
             batch_size=self.bs,
             weight_column_major=True,
             seg_lens=self.seq_lens,
             weight_indices=self.weight_indices,
         )
-        # FIXME parallelize qkv
-        assert lora_a_output.shape[-1] == self.lora_rank * 3
         lora_output = torch.empty_like(base_output)
-        output_dim = lora_output.shape[-1] // 3
-        for i in range(3):
-            left = self.lora_rank * i
-            right = left + self.lora_rank
-            lora_output[:, output_dim * i : output_dim * (i + 1)] = (
+        # q
+        output_dim_q = self.B_buffer_q.shape[-2]
+        lora_output[:, :output_dim_q] = self.segment_gemm.run(
+            x=lora_a_output[:, : self.lora_rank].contiguous(),
+            weights=self.B_buffer_q,
+            batch_size=self.bs,
+            weight_column_major=True,
+            seg_lens=self.seq_lens,
+            weight_indices=self.weight_indices,
+        )
+        # kv
+        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
+        for i in range(2):
+            left = output_dim_kv * i
+            right = left + output_dim_kv
+            lora_output[:, output_dim_q + left : output_dim_q + right] = (
                 self.segment_gemm.run(
-                    x=lora_a_output[:, left:right].contiguous(),
-                    weights=self.B_buffer[:, :, left:right].contiguous(),
+                    x=lora_a_output[
+                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
+                    ].contiguous(),
+                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                     batch_size=self.bs,
                     weight_column_major=True,
                     seg_lens=self.seq_lens,
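
This split is the GQA fix in the commit title: with grouped-query attention the q projection is wider than k and v, so the old uniform `output_dim = lora_output.shape[-1] // 3` sliced the output at the wrong offsets. Keeping `B_buffer_q` and `B_buffer_kv` separate lets each projection take its own width from its own buffer. A shape sketch with plain matmuls in place of `segment_gemm.run` (head counts hypothetical, e.g. a Llama-3-style 32 q / 8 kv head config):

import torch

# Hypothetical GQA sizes: 32 query heads, 8 kv heads, head_dim 128.
rank, head_dim, num_q_heads, num_kv_heads, tokens = 8, 128, 32, 8, 10
q_dim, kv_dim = num_q_heads * head_dim, num_kv_heads * head_dim   # 4096, 1024

# lora_a_output carries three rank-r slices: [q | k | v].
lora_a_output = torch.randn(tokens, 3 * rank)
B_q = torch.randn(q_dim, rank)        # q projection for one adapter
B_kv = torch.randn(2 * kv_dim, rank)  # k and v stacked along the output axis

lora_output = torch.empty(tokens, q_dim + 2 * kv_dim)
# q: its own width, read off B_q itself
lora_output[:, :q_dim] = lora_a_output[:, :rank] @ B_q.t()
# k and v: each kv_dim wide, offset past the q columns
for i in range(2):
    left, right = kv_dim * i, kv_dim * (i + 1)
    a_i = lora_a_output[:, rank * (i + 1) : rank * (i + 2)]
    lora_output[:, q_dim + left : q_dim + right] = a_i @ B_kv[left:right, :].t()

# The old uniform split would have used (q_dim + 2 * kv_dim) // 3 == 2048 as
# every projection's width, misaligning q (4096) as well as k and v (1024).
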
@@ -185,6 +203,7 @@ def __init__(
         super().__init__(base_layer, segment_gemm, lora_rank, scaling)

     def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+        self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
         self.bs = bs
@@ -212,7 +231,6 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor

     def forward(self, input_):
         # duplicate the logic in RowParallelLinear
-        # Set up backprop all-reduce.
         if self.base_layer.input_is_parallel:
             input_parallel = input_
         else:
@@ -225,7 +243,7 @@ def forward(self, input_):
             self.base_layer, input_parallel
         )

-        if hasattr(self, "A_buffer"):
+        if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)

         if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
@@ -252,9 +270,10 @@ def get_lora_layer(
     supported_layer_types = {
         # the order matters
         VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
-        QKVParallelLinear: QKVParallelLinearWithLora,
-        RowParallelLinear: RowParallelLinearWithLoRA,
+        QKVParallelLinear: QKVParallelLinearWithLoRA,
         MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
         ColumnParallelLinear: ColumnParallelLinearWithLoRA,
+        RowParallelLinear: RowParallelLinearWithLoRA,
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
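
The `# the order matters` comment is load-bearing: `isinstance` also matches base classes, and in vLLM's layer hierarchy `QKVParallelLinear` and `MergedColumnParallelLinear` both subclass `ColumnParallelLinear`, so the subclasses must be listed before the base class or they would be wrapped with the generic column-parallel LoRA layer. A minimal repro of the pitfall with stand-in classes:

# Why iteration order matters with isinstance (stand-in classes, not vLLM's).
class ColumnParallelLinear: ...
class MergedColumnParallelLinear(ColumnParallelLinear): ...
class QKVParallelLinear(ColumnParallelLinear): ...

def match(layer, mapping):
    # dicts preserve insertion order (Python 3.7+), so order encodes priority
    for src, dst in mapping.items():
        if isinstance(layer, src):
            return dst
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")

bad = {ColumnParallelLinear: "generic", QKVParallelLinear: "qkv"}
good = {QKVParallelLinear: "qkv", ColumnParallelLinear: "generic"}
layer = QKVParallelLinear()
assert match(layer, bad) == "generic"   # base class shadows the subclass
assert match(layer, good) == "qkv"      # subclasses checked first
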
@@ -263,19 +282,6 @@ def get_lora_layer(
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")


-def params_mapping(module_name):
-    params_mapping = {
-        "q_proj": "qkv_proj",
-        "k_proj": "qkv_proj",
-        "v_proj": "qkv_proj",
-        "gate_proj": "gate_up_proj",
-        "up_proj": "gate_up_proj",
-    }
-    if module_name in params_mapping:
-        return params_mapping[module_name]
-    return module_name
-
-
 def get_mapped_params(module_names):
     ret = set()
     for module_name in module_names:
@@ -323,6 +329,7 @@ def __init__(self, uid, config, base_hf_config, load_config):
     def get_stacked_multiply(self, module_name):
         stacked_rank = {
             "qkv_proj": 3,
+            "kv_proj": 2,
             "gate_up_proj": 2,
         }
         return stacked_rank[module_name] if module_name in stacked_rank else 1
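
The multiplier records how many per-module slices share one stacked buffer: three rank-r slices for `qkv_proj` in lora_A, two output slices for the new `kv_proj` in lora_B, and two for `gate_up_proj`. A small sketch of how such a multiplier would size a buffer (illustrative only, not the repo's actual allocation code):

# Illustrative sizing of a stacked lora_A buffer (hypothetical pool code).
stacked_rank = {"qkv_proj": 3, "kv_proj": 2, "gate_up_proj": 2}

def get_stacked_multiply(module_name: str) -> int:
    return stacked_rank.get(module_name, 1)

rank = 8
a_rows = rank * get_stacked_multiply("qkv_proj")  # 24: q, k, v back to back
o_rows = rank * get_stacked_multiply("o_proj")    # 8: unstacked module
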
@@ -354,40 +361,43 @@ def initialize_weights(self):
             else:
                 self.weights[name] = loaded_weight.cpu()

-        # stack qkv_proj and gate_up_proj
+        # stack kv_proj and gate_up_proj
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
             for weight_name in weight_names:
-                if "q_proj" in weight_name:
-                    k_name = weight_name.replace("q_proj", "k_proj")
-                    v_name = weight_name.replace("q_proj", "v_proj")
-                    qkv_name = weight_name.replace("q_proj", "qkv_proj")
-                    layer.weights[qkv_name] = torch.cat(
-                        (
-                            layer.weights[weight_name],
-                            layer.weights[k_name],
-                            layer.weights[v_name],
-                        ),
-                        (
-                            0
-                            if layer.weights[weight_name].shape[0] == self.config.r
-                            else 1
-                        ),
-                    )
-                    layer.weights.pop(weight_name)
-                    layer.weights.pop(k_name)
-                    layer.weights.pop(v_name)
+                if "k_proj" in weight_name:
+                    q_name = weight_name.replace("k_proj", "q_proj")
+                    v_name = weight_name.replace("k_proj", "v_proj")
+                    kv_name = weight_name.replace("k_proj", "kv_proj")
+                    qkv_name = weight_name.replace("k_proj", "qkv_proj")
+                    if "lora_A" in weight_name:
+                        layer.weights[qkv_name] = torch.cat(
+                            (
+                                layer.weights[q_name],
+                                layer.weights[weight_name],
+                                layer.weights[v_name],
+                            ),
+                            0,
+                        )
+                        layer.weights.pop(q_name)
+                        layer.weights.pop(weight_name)
+                        layer.weights.pop(v_name)
+                    else:
+                        layer.weights[kv_name] = torch.cat(
+                            (
+                                layer.weights[weight_name],
+                                layer.weights[v_name],
+                            ),
+                            0,
+                        )
+                        layer.weights.pop(weight_name)
+                        layer.weights.pop(v_name)
                 elif "gate_proj" in weight_name:
                     up_name = weight_name.replace("gate_proj", "up_proj")
                     gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                     layer.weights[gate_up_name] = torch.cat(
-                        (layer.weights[weight_name], layer.weights[up_name]),
-                        (
-                            0
-                            if layer.weights[weight_name].shape[0] == self.config.r
-                            else 1
-                        ),
+                        (layer.weights[weight_name], layer.weights[up_name]), 0
                     )
                     layer.weights.pop(weight_name)
                     layer.weights.pop(up_name)
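
The asymmetry in the rewritten stacking is the point of the hunk: every lora_A matrix has shape [r, hidden], so the q/k/v A-weights can always be concatenated into one `qkv_proj` tensor, while lora_B matrices have shape [out, r] and under GQA q's `out` differs from k/v's, so only the k and v B-weights are merged into `kv_proj` and q's B stays standalone (consumed as `B_buffer_q` above). Both concatenations now run along dim 0, which also removes the old `shape[0] == self.config.r` dimension guess. A sketch with hypothetical shapes:

import torch

# Hypothetical restacking of per-module LoRA weights (GQA shapes).
r, hidden, q_dim, kv_dim = 8, 4096, 4096, 1024
weights = {
    "q_proj.lora_A": torch.randn(r, hidden),
    "k_proj.lora_A": torch.randn(r, hidden),
    "v_proj.lora_A": torch.randn(r, hidden),
    "q_proj.lora_B": torch.randn(q_dim, r),
    "k_proj.lora_B": torch.randn(kv_dim, r),
    "v_proj.lora_B": torch.randn(kv_dim, r),
}

# lora_A: identical [r, hidden] shapes -> stack q, k, v into qkv_proj.
weights["qkv_proj.lora_A"] = torch.cat(
    (weights.pop("q_proj.lora_A"),
     weights.pop("k_proj.lora_A"),
     weights.pop("v_proj.lora_A")), 0)           # [3r, hidden]

# lora_B: q_dim != kv_dim under GQA -> stack only k and v into kv_proj.
weights["kv_proj.lora_B"] = torch.cat(
    (weights.pop("k_proj.lora_B"),
     weights.pop("v_proj.lora_B")), 0)           # [2 * kv_dim, r]
# q_proj.lora_B is kept as-is and applied through B_buffer_q.
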
