[Kernel] Support running GPTQ 8-bit models in Marlin #4533
Conversation
tests/conftest.py
@@ -273,7 +273,9 @@ def generate_greedy_logprobs(
        return all_logprobs

    def __del__(self):
        del self.model
        if self.model is not None:
why needed? @alexm-nm
I got occasional warnings when self.model was already None during del self.model. It's not needed for correctness, so it can be dropped if it causes issues in CI.
I don't think we should touch this file.
removed
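For context, a minimal sketch of the guarded cleanup that was being discussed (the class and attribute names are illustrative, not the actual conftest code):

```python
class RunnerSketch:
    """Illustrative stand-in for the test runner discussed above."""

    def __init__(self, model):
        self.model = model

    def __del__(self):
        # Guard against the attribute already being None or missing, which is
        # what produced the occasional warnings mentioned in the thread.
        if getattr(self, "model", None) is not None:
            del self.model
```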
@@ -114,11 +115,20 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
  return res;
}

template <int start_byte, int mask>
It wouldn't hurt to add a comment on what this does :)
@alexm-nm
Good idea, added
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];          // No act-order
  FragS act_frag_s[2][4][4];   // For act-order

  // if (blockIdx.x == 0 && threadIdx.x == 0) {
This looks like it is left over from debugging -- remove it before merging the PR?
@alexm-nm
Removed
@@ -62,7 +65,14 @@ def _get_perms():
    return perm, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()
_perm = {}
I wonder why you need these global variables -- if you want to avoid recomputing things, it is most likely cleaner / better to use a function with a @functools.lru_cache annotation :)
I just realized that we don't need most of this code anymore since we use the auto-repack GPU kernel. Refactored to use only scale shuffles.
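For reference, a minimal sketch of the reviewer's suggestion (the function name and the permutation values are placeholders, not the actual marlin utility code):

```python
import functools


@functools.lru_cache(maxsize=None)
def get_scale_perms():
    # Compute the scale permutation tables once and cache the result,
    # instead of keeping them in module-level globals.
    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]          # placeholder values
    scale_perm_single = [2 * i + j for i in range(4) for j in range(4)]   # placeholder values
    return scale_perm, scale_perm_single
```

The cached result is computed on the first call and reused afterwards, giving the same "compute once" behavior as the globals without the module-import side effect.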
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
ValueError? Raising generic exceptions is not a good idea :)
Removed
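A minimal sketch of the suggested fix (only the 8-bit branch comes from the snippet above; the 4-bit interleave values here are illustrative):

```python
import numpy


def get_interleave(num_bits: int) -> numpy.ndarray:
    # Raise a specific exception type instead of the generic Exception.
    if num_bits == 4:
        return numpy.array([0, 2, 4, 6, 1, 3, 5, 7])  # illustrative values
    elif num_bits == 8:
        return numpy.array([0, 2, 1, 3])
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")
```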
  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  // printf("tb_k = %d, tb_n = %d, pipe_size = %f, scales_cache_size = %d,
ditto, remove debug code
Good catch, removed
  int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  // printf("scales_cache_size = %d\n", scales_cache_size);
same here and below
Ditto
  int num_threads = th_config.num_threads;
  thread_k = th_config.thread_k;
  thread_n = th_config.thread_n;
  // printf("exec_config: max_m_blocks = %d, thread_k = %d, thread_n = %d\n",
debug code
Removed
Here are vLLM server TTFT and TPOT benchmark results for 8-bit GPTQ Llama-3-8B and Yi-34B on an A100 GPU on a GCP instance (based on benchmark_serving.py). Original PDFs:
Thanks for fixing the comments, this looks much nicer :)
tests/models/test_gptq_marlin.py
@@ -63,10 +70,11 @@ def test_models(
    gptq_marlin_model = vllm_runner(model_name=model_name,
                                    revision=revision,
                                    dtype=dtype,
                                    quantization="marlin",
                                    quantization="gptq",
@alexm-nm this test should have marlin for quantization. Also, is enforce_eager=True required?
Oh, must be a leftover from debugging. Good catch, will fix it in 30 min.
Fixed, tests pass
tests/models/test_gptq_marlin.py
                                    max_model_len=MAX_MODEL_LEN,
                                    tensor_parallel_size=1,
                                    disable_custom_all_reduce=True)
                                    disable_custom_all_reduce=True,
To make this cleaner, can we remove disable_custom_all_reduce?
Will try
works
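Putting the two threads above together, the cleaned-up test instantiation ends up looking roughly like this (a fragment from the test body, sketched only; parameter names follow the diff snippets above and the final test file may differ):

```python
# quantization switched back to "marlin"; the enforce_eager and
# disable_custom_all_reduce flags discussed above are dropped.
gptq_marlin_model = vllm_runner(model_name=model_name,
                                revision=revision,
                                dtype=dtype,
                                quantization="marlin",
                                max_model_len=MAX_MODEL_LEN,
                                tensor_parallel_size=1)
```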
@alexm-nm the model test failed
This PR adds 8-bit weight support to the Marlin GPU kernel and the Marlin on-the-fly repack. As a result, all GPTQ 8-bit models (with any group size or act_order) can run via Marlin (for Ampere GPUs and up).
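As a usage sketch of what this enables (the model path and sampling settings below are illustrative placeholders, not something benchmarked in this PR):

```python
from vllm import LLM, SamplingParams

# With this change, an 8-bit GPTQ checkpoint is detected and executed through
# the Marlin kernel on Ampere (or newer) GPUs; no extra quantization flag
# should be needed.
llm = LLM(model="your-org/llama-3-8b-gptq-8bit")  # placeholder: any 8-bit GPTQ checkpoint
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```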