[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
comaniac authored Apr 26, 2024
1 parent 603ad84 commit a62aaf1
Showing 45 changed files with 759 additions and 713 deletions.
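
For context: this commit replaces the layer-level `linear_method` attribute with a more general `quant_method`, so that layers other than plain linear ones can reuse the same quantization hook. Below is a minimal sketch of the generalized interface, inferred from the call sites in the hunks that follow — the `create_weights` method and the exact signatures are assumptions, since only `apply` appears in this excerpt:

from abc import ABC, abstractmethod

import torch


class QuantizeMethodBase(ABC):
    # Sketch: generalized per-layer quantization hook.

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *args, **kwargs):
        # Assumed: registers the (possibly quantized) parameters on `layer`.
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        # Runs the layer's forward computation on the created weights,
        # matching the `quant_method.apply(layer, x, bias)` call sites below.
        raise NotImplementedError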
2 changes: 1 addition & 1 deletion tests/quantization/test_fp8.py
@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1
-    assert isinstance(fc1.linear_method, Fp8LinearMethod)
+    assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
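
The test now asserts on the renamed attribute. For illustration, a hypothetical helper doing the same check outside the test (the import path is assumed, not shown in this diff):

from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod


def uses_fp8(layer) -> bool:
    # After this commit, quantized layers expose `quant_method` instead of
    # `linear_method`; absence of the attribute means an unquantized layer.
    return isinstance(getattr(layer, "quant_method", None), Fp8LinearMethod)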
4 changes: 2 additions & 2 deletions tests/tensorizer_loader/test_tensorizer.py
@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock()

result = load_with_tensorizer(tensorizer_config,
-                                  linear_method=mock_linear_method)
+                                  quant_method=mock_linear_method)

mock_agent.assert_called_once_with(tensorizer_config,
-                                       linear_method=mock_linear_method)
+                                       quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
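
Outside the mocked test, the renamed keyword would be passed the same way; a sketch of the assumed call shape:

# Assumed usage: the loader receives the layer's quantization hook under
# the generalized keyword name.
result = load_with_tensorizer(tensorizer_config,
                              quant_method=layer.quant_method)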

30 changes: 13 additions & 17 deletions vllm/lora/layers.py
@@ -389,10 +389,9 @@ def set_mapping(
self.indices = base_indices
self.indices_len = indices_len

-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
@@ -416,7 +415,7 @@ def forward(self, input_):
if not self.base_layer.skip_bias_add else None)

# Matrix multiply.
-        output_parallel = self.apply_weights(input_, bias)
+        output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,10 +522,9 @@ def set_lora(
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)

-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -765,10 +763,9 @@ def set_lora(
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)

-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -862,9 +859,8 @@ def set_mapping(
self.indices = base_indices
self.indices_len = indices_len

-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x)
+    def apply(self, x: torch.Tensor) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
@@ -897,7 +893,7 @@ def forward(self, input_):
input_parallel = splitted_input[tp_rank].contiguous()

# Matrix multiply.
-        output_parallel = self.apply_weights(input_parallel)
+        output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
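
Every hunk in this file follows the same two-step pattern: delegate the base matrix multiply to the layer's `quant_method`, then add the low-rank LoRA update on top. A condensed, self-contained sketch of that pattern (the class and the toy `_apply_lora` are illustrative stand-ins, not vLLM code):

from typing import Optional

import torch


def _apply_lora(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
                output: torch.Tensor) -> None:
    # Toy stand-in for vLLM's batched LoRA kernels: output += (x @ A) @ B.
    output += (x @ lora_a) @ lora_b


class LoRALinearSketch:

    def __init__(self, base_layer, lora_a: torch.Tensor,
                 lora_b: torch.Tensor):
        self.base_layer = base_layer  # exposes `quant_method` after this commit
        self.lora_a = lora_a
        self.lora_b = lora_b

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Step 1: base (possibly quantized) matmul via the generalized hook.
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        # Step 2: add the LoRA delta in place.
        _apply_lora(x, self.lora_a, self.lora_b, output)
        return output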