
[Kernel] [Triton] [AMD] Add Triton implementation of awq_dequantize #6850

Closed
10 changes: 10 additions & 0 deletions vllm/_custom_ops.py
@@ -9,6 +9,11 @@

logger = init_logger(__name__)

# NOTE (7/26/24): Use the Triton implementation of the AWQ functions.
# Currently only awq_dequantize is implemented; awq_gemm is coming soon.

use_awq_triton = False
Member:
How does this get set? Should this be an environment variable, or based on the available hardware?
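One possible wiring, if an environment variable is chosen (a minimal sketch; the VLLM_USE_TRITON_AWQ name is hypothetical and not part of this PR):

import os

# Hypothetical opt-in gate: set VLLM_USE_TRITON_AWQ=1 to route AWQ ops
# through the Triton kernels (variable name is illustrative only).
use_awq_triton = os.environ.get("VLLM_USE_TRITON_AWQ", "0") == "1"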


try:
import vllm._C
except ImportError as e:
@@ -181,6 +186,11 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
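# Expected shapes (AWQ 4-bit, pack factor 8): qweight is [R, C // 8] int32,
# scales is [R // group_size, C] float16, zeros is [R // group_size, C // 8]
# int32, and the result is [R, C] float16. The Triton path ignores
# split_k_iters, thx and thy.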
if use_awq_triton:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton)
return awq_dequantize_triton(qweight, scales, zeros, split_k_iters,
thx, thy)
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
thx, thy)

272 changes: 272 additions & 0 deletions vllm/model_executor/layers/quantization/awq_triton.py
@@ -0,0 +1,272 @@
import torch
import triton
import triton.language as tl


@triton.jit
def awq_dequantize_kernel(
qweight_ptr, # quantized matrix
scales_ptr, # scales, per group
zeros_ptr, # zeros, per group
split_k_iters, # Not used
thx, # Not used
thy, # Not used
group_size, # Should always be 128
result_ptr, # Output matrix
num_cols, # input num cols in qweight
num_rows, # input num rows in qweight
reverse_awq_order_ptr,
BLOCK_SIZE_X: tl.constexpr,
BLOCK_SIZE_Y: tl.constexpr):
# Setup the pids.
pid_x = tl.program_id(axis=0)
pid_y = tl.program_id(axis=1)
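# Each program instance covers a BLOCK_SIZE_Y x BLOCK_SIZE_X tile of the
# packed qweight and writes a BLOCK_SIZE_Y x (BLOCK_SIZE_X * 8) tile of
# the output, since each int32 packs eight 4-bit values.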

# Compute offsets and masks for qweight_ptr.
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

masks_y = offsets_y < num_rows
masks_x = offsets_x < num_cols

masks = masks_y[:, None] & masks_x[None, :]

# Compute offsets and masks for result output ptr.
result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
result_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
tl.arange(0, BLOCK_SIZE_X * 8))
result_offsets = (8 * num_cols * result_offsets_y[:, None] +
result_offsets_x[None, :])

result_masks_y = result_offsets_y < num_rows
result_masks_x = result_offsets_x < num_cols * 8
result_masks = result_masks_y[:, None] & result_masks_x[None, :]

# Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks)

# Load the AWQ reverse order offsets.
reverse_awq_order_offsets = tl.arange(0, 8)
reverse_awq_order_tensor = tl.load(reverse_awq_order_ptr +
reverse_awq_order_offsets)

# Use this to compute a set of shifts that can be used to unpack and
# reorder the values in iweights and zeros.
shifts = reverse_awq_order_tensor * 4
shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Unpack and reorder: shift out the correct 4-bit value and mask.
iweights = (iweights >> shifts) & 0xF
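# Example: for a packed int32 0x76543210 (nibble k holding value k), the
# shifts [0, 16, 4, 20, 8, 24, 12, 28] extract [0, 4, 1, 5, 2, 6, 3, 7],
# undoing the AWQ interleaved packing in a single shift-and-mask.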

# Compute zero offsets and masks.
zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
tl.arange(0, BLOCK_SIZE_Y) // group_size)
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

zero_masks_y = zero_offsets_y < num_rows // group_size
zero_masks_x = zero_offsets_x < num_cols
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

# Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)

# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros = (zeros >> shifts) & 0xF

# Compute scale offsets and masks.
scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
tl.arange(0, BLOCK_SIZE_Y) // group_size)
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
tl.arange(0, BLOCK_SIZE_X * 8))
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
scale_offsets_x[None, :])
scale_masks_y = scale_offsets_y < num_rows // group_size
scale_masks_x = scale_offsets_x < num_cols * 8
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

# Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks)

# Dequantize.
iweights = (iweights - zeros) * scales
iweights = iweights.to(tl.float16)

# Finally, store.
tl.store(result_ptr + result_offsets, iweights, result_masks)


# Example input:
# qweight.size=torch.Size([3584, 576]),
# qweight.dtype = torch.int32,
# scales.size=torch.Size([28, 4608]),
# scales.dtype=torch.float16,
# zeros.size=torch.Size([28, 576]),
# zeros.dtype=torch.int32
# split_k_iters=0
# thx=0
# thy=0
def awq_dequantize_triton(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
split_k_iters: int, # Not used
thx: int, # Not used
thy: int # Not used
) -> torch.Tensor:
# Result tensor:
# number of rows = same as input tensor
# number of cols = 8 x input tensor num cols
result = torch.empty(qweight.shape[0],
qweight.shape[1] * 8,
device=qweight.device,
dtype=torch.float16)

block_size_x = 32
block_size_y = 32

reverse_awq_order = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
dtype=torch.uint8,
device=qweight.device)

Y = qweight.shape[0] # num rows
X = qweight.shape[1] # num cols
group_size = 128
grid = lambda META: (
triton.cdiv(X, META['BLOCK_SIZE_X']),
triton.cdiv(Y, META['BLOCK_SIZE_Y']),
)
awq_dequantize_kernel[grid](qweight,
scales,
zeros,
split_k_iters,
thx,
thy,
group_size,
result,
X,
Y,
reverse_awq_order,
BLOCK_SIZE_X=block_size_x,
BLOCK_SIZE_Y=block_size_y)

return result


def reverse_awq_order(t: torch.Tensor):
bits = 4
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
reverse_order_tensor = torch.arange(
t.shape[-1],
dtype=torch.int32,
device=t.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)

t = t[:, reverse_order_tensor] & 0xF
return t
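# Illustrative check (assumed usage, not part of the kernel path): within
# each group of 8 columns, reverse_awq_order applies the permutation
# [0, 4, 1, 5, 2, 6, 3, 7]:
#
#   >>> t = torch.arange(8, dtype=torch.int32).repeat(2, 1)
#   >>> reverse_awq_order(t)
#   tensor([[0, 4, 1, 5, 2, 6, 3, 7],
#           [0, 4, 1, 5, 2, 6, 3, 7]], dtype=torch.int32)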


# qweightss [R , C // 8], int32

Member:
Suggested change:
-# qweightss [R , C // 8], int32
+# qweights [R , C // 8], int32

# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
print(f"awq_dequantize_torch:qweight.shape = {qweight.shape}"
f", qzeros.shape={qzeros.shape}")
bits = 4
group_size = 128
shifts = torch.arange(0, 32, bits, device=qzeros.device)

iweights = torch.bitwise_right_shift(qweight[:, :, None],
shifts[None, None, :]).to(torch.int8)

iweights = iweights.view(iweights.shape[0], -1)

# iweights = reverse_awq_order(iweights)
# return (iweights & 0xF).to(torch.float16)

zeros = torch.bitwise_right_shift(qzeros[:, :, None],
shifts[None, None, :]).to(torch.int8)

zeros = zeros.view(qzeros.shape[0], -1)

zeros = reverse_awq_order(zeros)
iweights = reverse_awq_order(iweights)

iweights = torch.bitwise_and(iweights, (2**bits) - 1)
zeros = torch.bitwise_and(zeros, (2**bits) - 1)

scales = scales.repeat_interleave(group_size, dim=0)
zeros = zeros.repeat_interleave(group_size, dim=0)
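# repeat_interleave expands the per-group scales and zeros from
# [R // group_size, C] to per-row [R, C] so that the dequantization below
# broadcasts element-wise.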
print(f"awq_dequantize_torch:iweights.shape = {iweights.shape},"
f"zeros.shape={zeros.shape}, "
f"scales.shape={scales.shape}")

return (iweights - zeros) * scales


def main():
use_triton = True
use_torch = True

qweight_rows = 3584
qweight_cols = 576
group_size = 128
small_test_size = True
if small_test_size:
qweight_rows = 256
qweight_cols = 128
print(f"qweight_rows = {qweight_rows}, qweight_cols = {qweight_cols}")
qweight_dtype = torch.int32
scales_rows = qweight_rows // group_size
scales_cols = qweight_cols * 8
scales_dtype = torch.float16
zeros_rows = scales_rows
zeros_cols = qweight_cols
zeros_dtype = torch.int32
split_k_iters = 0
thx = 0
thy = 0
device = 'cuda'
torch.manual_seed(0)

qweight = torch.randint(0,
10000000, (qweight_rows, qweight_cols),
dtype=qweight_dtype,
device=device)
scales = torch.rand(scales_rows,
scales_cols,
dtype=scales_dtype,
device=device)
zeros = torch.randint(0,
10000000, (zeros_rows, zeros_cols),
dtype=zeros_dtype,
device=device)
print(f"zeros.shape = {zeros.shape}")
print(f"qweight = {qweight}")
if use_triton:
iweights_triton = awq_dequantize_triton(qweight, scales, zeros,
split_k_iters, thx, thy)
print(f"Triton result:iweights_triton = {iweights_triton}")
print(f"Any infs in triton result? -->"
f"{torch.any(torch.isinf(iweights_triton))}")

if use_torch:
iweights_torch = awq_dequantize_torch(qweight, scales, zeros,
split_k_iters, thx, thy)
print(f"Torch result:iweights_torch = {iweights_torch}")

if use_torch and use_triton:
diff = iweights_torch - iweights_triton
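# torch.sqrt(diff * diff) equals |diff|, so this sums the absolute error.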
error = torch.sum(torch.sqrt(diff * diff))
print(f"error = {error}")


if __name__ == '__main__':
main()
Comment on lines +271 to +272

Member:
Could you convert this to a separate test file?

Contributor Author:
Thank you for reviewing, but I've moved it over to this one: #7386. Instead of two PRs, I'm going to have one.