
[Kernel] Support running GPTQ 8-bit models in Marlin #4533

Merged
merged 9 commits into vllm-project:main from marlin_8bit on May 2, 2024

Conversation

alexm-redhat
Collaborator

This PR adds 8-bit weight support to the Marlin GPU kernel and the Marlin on-the-fly repack. As a result, all GPTQ 8-bit models (with any group size or act_order) can run via Marlin (for Ampere GPUs and up).
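
As a rough usage sketch (the model name below is a placeholder, not a checkpoint referenced by this PR), an 8-bit GPTQ model is loaded the same way as before and should be routed through the Marlin path on supported GPUs:

from vllm import LLM, SamplingParams

# Placeholder model name for any 8-bit GPTQ checkpoint; on Ampere+ GPUs
# vLLM is expected to pick the Marlin-backed GPTQ kernels added here.
llm = LLM(model="some-org/Llama-3-8B-GPTQ-8bit")
outputs = llm.generate(["Hello, Marlin!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)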

@@ -273,7 +273,9 @@ def generate_greedy_logprobs(
return all_logprobs

def __del__(self):
del self.model
if self.model is not None:
Collaborator

why needed? @alexm-nm

Collaborator Author

I was getting occasional warnings about self.model already being None when it is deleted. The guard isn't needed for correctness, so we can drop it if it causes issues in CI.
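
For illustration, a minimal sketch of the guarded cleanup being discussed (HfRunner is a hypothetical stand-in for the test runner class, not the exact vLLM code):

class HfRunner:  # hypothetical stand-in for the test runner class
    def __init__(self, model):
        self.model = model

    def __del__(self):
        # Guard against self.model already being None at teardown, which
        # can trigger the occasional warnings mentioned above.
        if self.model is not None:
            del self.model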

Collaborator

I don't think we should touch this file.

Collaborator Author

Removed

@@ -114,11 +115,20 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return res;
}

template <int start_byte, int mask>
Collaborator

It wouldn't hurt to add a comment on what this does :)

Collaborator

@alexm-nm

Collaborator Author

Good idea, added

FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order

// if (blockIdx.x == 0 && threadIdx.x == 0) {
Collaborator

This looks like it is left over from debugging -- remove it before merging the PR?

Collaborator

@alexm-nm

Collaborator Author

Removed

@@ -62,7 +65,14 @@ def _get_perms():
return perm, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()
_perm = {}
Collaborator

I wonder why you need these global variables -- if you want to avoid recomputing things, it is most likely cleaner / better to use a function with a @functools.lru_cache annotation :)

Collaborator Author

I just realized that we don't need most of this code anymore since we use the auto-repack GPU kernel. Refactored it to use only the scale shuffles.
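
For reference, a minimal sketch of the lru_cache pattern suggested above (the permutation values here are stand-ins for illustration, not Marlin's actual tables):

import functools

import numpy


@functools.lru_cache(maxsize=None)
def get_perms(num_bits: int):
    # Compute the permutation table lazily per num_bits and let lru_cache
    # memoize the result, instead of filling module-level dicts like
    # `_perm = {}`.
    pack_factor = 32 // num_bits
    # Stand-in permutation purely for illustration.
    return numpy.arange(pack_factor * 32).reshape(pack_factor, 32).T.ravel()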

elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
Collaborator

ValueError? Raising generic exceptions is not a good idea :)

Collaborator Author

Removed
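
For reference, the shape of the suggested fix (a hypothetical validation helper; the surrounding code was removed in the refactor anyway):

def check_num_bits(num_bits: int) -> None:
    # Raise a specific ValueError instead of a bare Exception so callers
    # can catch the failure precisely.
    if num_bits not in (4, 8):
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))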


TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity

// printf("tb_k = %d, tb_n = %d, pipe_size = %f, scales_cache_size = %d,
Collaborator

ditto, remove debug code

Collaborator Author

Good catch, removed

int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
// printf("scales_cache_size = %d\n", scales_cache_size);
Collaborator

same here and below

Collaborator Author

Ditto

int num_threads = th_config.num_threads;
thread_k = th_config.thread_k;
thread_n = th_config.thread_n;
// printf("exec_config: max_m_blocks = %d, thread_k = %d, thread_n = %d\n",
Collaborator

debug code

Collaborator Author

Removed

@alexm-redhat
Collaborator Author

alexm-redhat commented May 2, 2024

Here are vLLM server TTFT and TPOT benchmark results for 8-bit GPTQ Llama-3-8B and Yi-34B on an A100 GPU (GCP instance), based on benchmark_serving.py:

Llama-3-8B 8-bit GPTQ:
[TTFT/TPOT benchmark chart]

Yi-34B 8-bit GPTQ:
[TTFT/TPOT benchmark chart]

Original PDFs:
vLLM Server - Llama-3-8B 8-bit.pdf
vLLM Server - Yi-34B 8-bit.pdf

@pcmoritz
Collaborator

pcmoritz commented May 2, 2024

Thanks for fixing the comments, this looks much nicer :)

@@ -63,10 +70,11 @@ def test_models(
gptq_marlin_model = vllm_runner(model_name=model_name,
revision=revision,
dtype=dtype,
quantization="marlin",
quantization="gptq",
Collaborator

@alexm-nm this test should have marlin for quantization, not gptq.

Also, is enforce_eager=True required?

Collaborator Author

Oh, that must be a leftover from debugging. Good catch, will fix it in 30 min.

Collaborator Author

Fixed, tests pass

max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
disable_custom_all_reduce=True,
Collaborator

To make this cleaner, can we remove disable_custom_all_reduce?

Collaborator Author

Will try

Collaborator Author

Works, removed it.
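
Putting both points together, the call in the test ends up looking roughly like this (a sketch assembled from the diff above; the merged test may differ in details such as enforce_eager):

gptq_marlin_model = vllm_runner(model_name=model_name,
                                revision=revision,
                                dtype=dtype,
                                quantization="marlin",
                                max_model_len=MAX_MODEL_LEN,
                                tensor_parallel_size=1)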

robertgshaw2-redhat enabled auto-merge (squash) on May 2, 2024 at 14:07
@robertgshaw2-redhat
Collaborator

@alexm-nm the model test failed

auto-merge was automatically disabled May 2, 2024 15:39

Head branch was pushed to by a user without write access

robertgshaw2-redhat merged commit 7038e8b into vllm-project:main on May 2, 2024
48 checks passed
robertgshaw2-redhat deleted the marlin_8bit branch on May 2, 2024 at 16:56
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024