From 01fd1863b752920d6468c6c95bdb5baae9dcdd2d Mon Sep 17 00:00:00 2001 From: chenqianfzh <51831990+chenqianfzh@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:09:12 -0700 Subject: [PATCH] [Feature][kernel] tensor parallelism with bitsandbytes quantization (#8434) Signed-off-by: Amit Garg --- tests/quantization/test_bitsandbytes.py | 26 ++++++++++--- vllm/config.py | 6 --- vllm/model_executor/layers/linear.py | 21 ++++++++--- vllm/model_executor/model_loader/loader.py | 44 +++++++++++++++++++++- 4 files changed, 80 insertions(+), 17 deletions(-) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 87200b1dcc534..36167cf95f589 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name) +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='Test requires at least 2 GPUs.') +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +@fork_new_process_for_each_test +def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, + vllm_runner, + example_prompts[:1], + model_name, + hf_model_kwargs, + vllm_tp_size=2) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): @@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner, vllm_runner, prompts, model_name, - hf_model_kwargs=None): + hf_model_kwargs=None, + vllm_tp_size=1): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference - - #Run with vLLM runner with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', + tensor_parallel_size=vllm_tp_size, enforce_eager=True, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() @@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner, hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() diff --git a/vllm/config.py b/vllm/config.py index a0991597d0673..6c24d15640e99 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -393,12 +393,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - if self.quantization == "bitsandbytes" and ( - parallel_config.tensor_parallel_size > 1 - or parallel_config.pipeline_parallel_size > 1): - raise ValueError( - "BitAndBytes quantization with TP or PP is not supported yet.") - # Remove the constraint after the bitsandbytes issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index cea768469aeb8..568892778abe2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -530,8 +530,11 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -899,8 +902,13 @@ def weight_loader(self, else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) # Special case for GGUF is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - if input_dim is not None: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ac869e56ce198..fd9533ab156a5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -22,6 +22,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -689,6 +691,8 @@ def save_model( class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" + # TODO: these module names are for Llama only, + # change so that it works with other models as well default_target_modules = [ "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj" @@ -911,13 +915,44 @@ def _parse_quant_state(param_name: str, def _unquantized_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if any(target_module in weight_name for target_module in self.target_modules): weight_name = weight_name.replace(".weight", ".qweight") + + # weight partitions of different modules occur at + # different dimensions + # TODO: these module names are for Llama only, + # change so that it works with other models as well + if 'down_proj' in weight_name or 'o_proj' in weight_name: + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] + # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + with set_default_torch_dtype(torch.float32): processed_weight, quant_state = quantize_4bit( loaded_weight, @@ -958,6 +993,13 @@ def _load_weights(self, model_config: ModelConfig, f"BitsAndBytes loader does not support {quant_method} " "quantization") + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." + "Please try with PP.") + load_8bit = False if pre_quant: load_8bit = quant_config.get('load_in_8bit', False)