Add sparse marlin 2:4 gemm op (#733)
feat: add sparse marlin 2:4 kernel
Diogo-V authored Aug 23, 2024
1 parent aacaf9b commit 614c667
Showing 12 changed files with 2,542 additions and 1 deletion.
117 changes: 116 additions & 1 deletion test/test_ops.py
@@ -10,8 +10,9 @@
run_tests,
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.prototype.quant_llm import from_scaled_tc_fpx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
import pytest

if is_fbcode():
@@ -302,5 +303,119 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
test_utils=test_utils,
)


MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
(1, 7, 5),
(13, 17, 67),
(26, 37, 13),
(67, 13, 11),
]
MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
orig_device = w.device
size_k, size_n = w.shape

assert w.is_floating_point(), "w must be float"

if group_size == -1:
group_size = size_k
assert group_size <= size_k

max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2

# Reshape to [groupsize, -1]
if group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))

# Compute scale for each group
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
s *= 2 / max_q_val # 2 => symmetric

# Quantize
q_w = torch.round(w / s).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)

# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s

# Restore original shapes
if group_size < size_k:

def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w

q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)

s = s.reshape((-1, size_n)).contiguous()

return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s.to(device=orig_device),
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity
w_24, _ = inject_24(b_weight, size_k, size_n)

# Symmetric quantize
w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

# Obtains reference output
output_ref = torch.matmul(a_input, w_24_ref)

# Packs to marlin 2:4
marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04

# Performs opcheck
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
opcheck(
torch.ops.torchao.marlin_24_gemm,
fn_inputs,
test_utils=test_utils,
)


if __name__ == "__main__":
run_tests()
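
For context, the flow the new test exercises is: prune a dense fp16 weight to a 2:4 sparse pattern, quantize it symmetrically, pack the compressed weight, sparsity metadata, and scales into the Marlin layout, and call the fused GEMM op. The following is a minimal usage sketch mirroring that test, not part of the commit itself; the shapes, num_bits=4, and group_size=128 are illustrative choices, and _symmetric_quantize_with_ref refers to the reference helper defined in test_ops.py above.

import torch
import torchao
from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, marlin_24_workspace

size_m, size_k, size_n = 16, 128, 512
a = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
w = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

# Prune the dense weight to a 2:4 sparse pattern (2 of every 4 consecutive elements zeroed).
w_24, _ = inject_24(w, size_k, size_n)

# Symmetric 4-bit quantization with group size 128 (reference helper from the test above).
w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits=4, group_size=128)

# Pack the compressed weight, 2:4 metadata, and scales into the Marlin 2:4 layout.
q_w_comp, marlin_scale, meta = pack_to_marlin_24(q_w_24, scale, 4, 128)
workspace = marlin_24_workspace(size_n)

# Fused sparse GEMM; the result should be close to torch.matmul(a, w_24_ref).
out = torchao.ops.marlin_24_gemm(
    a, q_w_comp, meta, marlin_scale, workspace, 4, size_m, size_n, size_k
)
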
51 changes: 51 additions & 0 deletions torchao/csrc/cuda/sparse_marlin/base.h
@@ -0,0 +1,51 @@
/*
* Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace torchao {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales

} // namespace torchao