From 614c6675e9ae287acde4f134850e8db4b9653068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Diogo=20Ven=C3=A2ncio?= Date: Fri, 23 Aug 2024 17:52:56 +0100 Subject: [PATCH] Add sparse marlin 2:4 gemm op (#733) feat: add sparse marlin 2:4 kernel --- test/test_ops.py | 117 +- torchao/csrc/cuda/sparse_marlin/base.h | 51 + .../cuda/sparse_marlin/marlin_kernel_nm.cu | 1126 +++++++++++++++++ torchao/csrc/cuda/sparse_marlin/mem.h | 136 ++ torchao/csrc/cuda/sparse_marlin/mma.h | 191 +++ torchao/csrc/sparse_marlin.cpp | 8 + torchao/ops.py | 102 ++ torchao/sparsity/marlin/README.md | 6 + torchao/sparsity/marlin/__init__.py | 351 +++++ torchao/sparsity/marlin/utils.py | 417 ++++++ torchao/sparsity/utils.py | 35 + torchao/utils.py | 3 + 12 files changed, 2542 insertions(+), 1 deletion(-) create mode 100644 torchao/csrc/cuda/sparse_marlin/base.h create mode 100644 torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu create mode 100644 torchao/csrc/cuda/sparse_marlin/mem.h create mode 100644 torchao/csrc/cuda/sparse_marlin/mma.h create mode 100644 torchao/csrc/sparse_marlin.cpp create mode 100644 torchao/sparsity/marlin/README.md create mode 100644 torchao/sparsity/marlin/__init__.py create mode 100644 torchao/sparsity/marlin/utils.py diff --git a/test/test_ops.py b/test/test_ops.py index eecb4a287b..171089237b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,8 +10,9 @@ run_tests, ) from torch.testing._internal.optests import opcheck -from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff from torchao.prototype.quant_llm import from_scaled_tc_fpx +from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24 import pytest if is_fbcode(): @@ -302,5 +303,119 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size test_utils=test_utils, ) + +MARLIN_24_K_CHUNKS = [128] +MARLIN_24_N_CHUNKS = [512] +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] +MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_TEST_PARAMS = list(itertools.product( + MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS, + MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS +)) + +def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + ) + +@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) +def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda") + b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") + + # Inject 2:4 sparsity + w_24, _ = inject_24(b_weight, size_k, size_n) + + # Symmetric quantize + w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size) + + # Obtains reference output + output_ref = torch.matmul(a_input, w_24_ref) + + # Packs to marlin 2:4 + marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size) + workspace_24 = marlin_24_workspace(size_n) + + fn_inputs = ( + a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, + num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1], + ) + output = torchao.ops.marlin_24_gemm(*fn_inputs) + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + assert max_diff < 0.04 + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.marlin_24_gemm, + fn_inputs, + test_utils=test_utils, + ) + + if __name__ == "__main__": run_tests() diff --git a/torchao/csrc/cuda/sparse_marlin/base.h b/torchao/csrc/cuda/sparse_marlin/base.h new file mode 100644 index 0000000000..bf81fb5d8a --- /dev/null +++ b/torchao/csrc/cuda/sparse_marlin/base.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace torchao { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. 
+template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +template +struct ShapeBase { + static constexpr int M = M_, N = N_, K = K_; +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragM = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu new file mode 100644 index 0000000000..29c17d1bdd --- /dev/null +++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu @@ -0,0 +1,1126 @@ +/* + * Notice: This file was modified by Neuralmagic inc to include 8-bit support + * + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is a modified version of the original Marlin kernel from the file: +// https://github.com/neuralmagic/nm-vllm/blob/9daca33a6fdc429802f448e1ea71630c996c9740/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu + +#include + +#include +#include +#include +#include +#include + +#include + +#include "base.h" +#include "mem.h" +#include "mma.h" + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace torchao { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. 
+static constexpr int THREADS = 256; +static constexpr int STAGES = 4; + +static constexpr int min_thread_n = 128; + +static constexpr int tile_size = 16; +static constexpr int max_par = 64; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_24( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_24( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + // number of thread_k_blocks in k-dim + int k_tiles = prob_k / 32 / thread_k_blocks; + // number of thread_n_blocks in n-dim + int n_tiles = prob_n / 16 / thread_n_blocks; + // iters needed to cover all slices + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. 
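+  // (Illustrative example, assuming group_blocks = 4 and thread_k_blocks = 2,
+  //  one of the configurations instantiated below: the ratio is 2, so
+  //  iters = 7 is rounded up to 2 * ceildiv(7, 2) = 8, and every stripe then
+  //  covers a whole number of scale groups.)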
+ if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + // number of threadblock tiles in the current slice + int slice_iters; + // total number of active threadblocks in the current slice + int slice_count = 0; + // index of threadblock in current slice; numbered bottom to top + int slice_idx; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 32 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads //RLC: 2 * #warps k-dim + constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + constexpr int pack_factor = 32 / num_bits; + + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. 
The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4* meta_ptr[m_sh_iters]; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + int4* sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2][b_thread_vecs]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
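+  // (Illustrative timing, assuming stages = 4: start_pipes() below prefetches
+  //  stages 0..2 with this helper; each main-loop step then issues the fetch
+  //  for stage (pipe + stages - 1) % stages while computing on stage `pipe`,
+  //  and wait_for_stage() blocks until at most stages - 2 = 2 copies remain
+  //  in flight, i.e. until the oldest prefetched tile is safe to read.)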
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + int4* sh_meta_stage = sh_m + m_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + + // Load meta with ldsm4 + int4* sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
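+    // (Illustrative, assuming thread_m_blocks = 4: each of the four j
+    //  iterations below dequantizes one pair of B sub-fragments and then
+    //  issues four sparse m16n8k32 MMAs, one per 16-row block of A, so the
+    //  dequantization work of iteration j can overlap with the tensor core
+    //  work of iteration j - 1.)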
+ #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
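+    // (Illustrative, assuming thread_n_blocks = 8: active_threads below is
+    //  32 * 8 / 4 = 64, so two warps handle the reduction. Blocks in a slice
+    //  take turns: each reads the fp16 partial already in C (except the
+    //  first), adds it into its fp32 frag_c accumulators, and writes an fp16
+    //  partial back (except the last, which keeps the sum for
+    //  write_result()).)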
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j2 = 0; j2 < 2; j2++) { + #pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j2 = 0; j2 < 2; j2++) { + #pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
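+  // (Illustrative, for the helper below: each write() call packs eight fp32
+  //  accumulators into one 16-byte int4 of fp16 values via to_half4, and for
+  //  column-wise 4-bit quantization the scales are applied here, at write-out
+  //  time, instead of inside the main loop.)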
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, + float c4, float c5, float c6, float c7, FragS& s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2* tmp = (half2*)&res; + // for per-column quantization we finally apply the scale here + if constexpr (group_blocks == -1 && num_bits == 4) { + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4*)sh)[idx] = *((int4*)&res[0]); + }; + + // RLC: only warp 0 and 1 baseline example + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], + frag_s[0][2]); + write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], + frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], + frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], + frag_c[i][3][0][3], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], + frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], + frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], + frag_c[i][3][1][2], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], + frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], + frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], + frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); + + c_sh_wr += 8 * c_sh_stride_2; + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. 
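+    // (Illustrative schedule, assuming stages = 4:
+    //    pipe consumed by matmul():  0  1  2  3  0  1 ...
+    //    stage fetched that step:    3  0  1  2  3  0 ...
+    //  both sequences have period `stages`, so each pass over the unrolled
+    //  loop ends exactly where the next one begins, at pipe = 0.)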
+ #pragma unroll + for (int pipe = 0; pipe < stages;) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + matmul(pipe); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + + pipe++; + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + } + } + thread_block_reduce(); + + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + } + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], + &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], + &frag_c[i][0][0][2], &frag_c[i][1][0][2], + &frag_c[i][2][0][2], &frag_c[i][3][0][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], + &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], + &frag_c[i][0][0][3], &frag_c[i][1][0][3], + &frag_c[i][2][0][3], &frag_c[i][3][0][3], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], + &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], + &frag_c[i][0][1][2], &frag_c[i][1][1][2], + &frag_c[i][2][1][2], &frag_c[i][3][1][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], + &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], + &frag_c[i][0][1][3], &frag_c[i][1][1][3], + &frag_c[i][2][1][3], &frag_c[i][3][1][3], + frag_s[0][2]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < 
m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#endif + +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ + } + +void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, + void* s, int prob_m, int prob_n, int prob_k, + void* workspace, int num_bits, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16) { + int tot_n = prob_n; + int tot_n_blocks = ceildiv(tot_n, 16); + int pad = 16 * tot_n_blocks - tot_n; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + TORCH_CHECK(sms > 0); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (thread_k == -1 || thread_m == -1) { + if (prob_n <= 16) { + // For small batchizes, better partitioningif is slightly more important + // than better compute utilization + thread_k = 128; + thread_m = 128; + } else if (prob_n <= 256) { + thread_k = 64; + thread_m = 256; + } else { + thread_k = 32; + thread_m = 512; + } + } + + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_m_blocks = thread_m / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, + " is not divisible by thread_m = ", thread_m); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, + " is not divisible by group_blocks = ", group_blocks); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + const int4* meta_ptr = (const int4*)meta; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + constexpr int max_m_blocks = 4; + + int* locks = (int*)workspace; + for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { + int thread_n_blocks = tot_n_blocks - i; + prob_n = tot_n - 16 * i; + int par = 1; + if (thread_n_blocks > max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); + if (par > max_par) par = max_par; + prob_n = (max_m_blocks * 16) * par; + i += max_m_blocks * (par - 1); + thread_n_blocks = max_m_blocks; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. 
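+    // (Illustrative mapping, assuming num_bits = 4, groupsize = 64 and a
+    //  batch of at most 16 rows: the small-batch branch above picks
+    //  thread_k = 128 and thread_m = 128, so thread_k_blocks = 128 / 32 = 4,
+    //  thread_m_blocks = 128 / 16 = 8, group_blocks = 64 / 16 = 4 and
+    //  thread_n_blocks = 1, which dispatches to CALL_IF_2_4(4, 8, 1, 4, 4)
+    //  below.)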
+ + // the false is start of the CALL_IF macros + if (false) { + } // BMxBNxBK, group + // 4-bit + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 16, 2, 2, 4) + CALL_IF_2_4(4, 16, 3, 2, -1) + CALL_IF_2_4(4, 16, 3, 2, 4) + CALL_IF_2_4(4, 16, 4, 2, -1) + CALL_IF_2_4(4, 16, 4, 2, 4) + + CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 32, 2, 1, 4) + CALL_IF_2_4(4, 32, 3, 1, -1) + CALL_IF_2_4(4, 32, 3, 1, 4) + CALL_IF_2_4(4, 32, 4, 1, -1) + CALL_IF_2_4(4, 32, 4, 1, 4) + + // 8-bit + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 16, 2, 2, 4) + CALL_IF_2_4(8, 16, 3, 2, -1) + CALL_IF_2_4(8, 16, 3, 2, 4) + CALL_IF_2_4(8, 16, 4, 2, -1) + CALL_IF_2_4(8, 16, 4, 2, 4) + + CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 32, 2, 1, 4) + CALL_IF_2_4(8, 32, 3, 1, -1) + CALL_IF_2_4(8, 32, 3, 1, 4) + CALL_IF_2_4(8, 32, 4, 1, -1) + CALL_IF_2_4(8, 32, 4, 1, 4) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; + } +} + +torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. 
Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % torchao::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(torchao::tile_size)); + TORCH_CHECK((size_k / torchao::tile_size / 2) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(torchao::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + b_q_weight.size(1) % torchao::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(torchao::tile_size)); + + int actual_size_n = (b_q_weight.size(1) / torchao::tile_size) * pack_factor; + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); + + // Verify meta + TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, + "b_meta.size(0) = ", b_meta.size(0), + " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); + TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), + " is not size_n * 2 = ", size_n * 2); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify b_meta device and strides + TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); + TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + int thread_k = -1; + int thread_m = -1; + int sms = -1; + int max_par = torchao::max_par; + + int groupsize = -1; + if (b_scales.size(0) > 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + groupsize = size_k / b_scales.size(0); + groupsize /= 2; // Because of 24 + } + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 64, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK(size_n % torchao::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(torchao::min_thread_n)); + int min_workspace_size = + (size_n / torchao::min_thread_n) * torchao::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + torchao::marlin_cuda_2_4( + a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), + num_bits, 
groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_m, sms, max_par); + + return c; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::marlin_24_gemm", &marlin_24_gemm); +} + +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h new file mode 100644 index 0000000000..59d5af38e7 --- /dev/null +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace torchao { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred_zfill(void* smem_ptr, + const void* glob_ptr, + bool pred = true, + const bool zfill = false) { + const int BYTES = 16; + int src_in_bytes = (zfill ? 0 : BYTES); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); +} + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. 
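+// (Illustrative register accounting for the .x4 form used below: ldmatrix.x4
+//  loads four 8x8 tiles of 16-bit elements per warp, i.e. a full 16x16 fp16
+//  tile, leaving each of the 32 threads holding the four packed 32-bit
+//  registers that FragA (a Vec of four half2) expects.)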
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_m); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h new file mode 100644 index 0000000000..dde6938d83 --- /dev/null +++ b/torchao/csrc/cuda/sparse_marlin/mma.h @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" +#include + +namespace torchao { + +// On CUDA earlier than 12.5, the ordered_metadata version of this instruction +// is not supported. 
On later versions of CUDA the version without ordered +// metadata results in the following warning: +// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction +// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially +// | reduced performance on some future architectures +#if defined CUDA_VERSION && CUDA_VERSION >= 12050 + #define MMA_SP_INST \ + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " +#else + #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " +#endif + +// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, + const FragA& frag_b, FragC& frag_c, FragM& frag_m, + const int psel) { + const uint32_t* a0 = reinterpret_cast(&a_frag0); + const uint32_t* a1 = reinterpret_cast(&a_frag1); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* e = reinterpret_cast(&frag_m); + + float* c = reinterpret_cast(&frag_c); + if (psel == 0) { + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } else { + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile(MMA_SP_INST + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, + float c3) { + uint2 r; + asm("{\n\t" + ".reg .f16 a, b, c, d; \n\t" + "cvt.rn.f16.f32 a, %2; \n\t" + "cvt.rn.f16.f32 b, %3; \n\t" + "cvt.rn.f16.f32 c, %4; \n\t" + "cvt.rn.f16.f32 d, %5; \n\t" + "mov.b32 %0, {a, b}; \n\t" + "mov.b32 %1, {c, d}; \n\t" + "}" + : "=r"(r.x), "=r"(r.y) + : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); + return r; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. 
We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
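+// (Illustrative: __half2half2 below broadcasts the selected fp16 scale into
+//  both lanes of a half2, so each __hmul2 rescales two dequantized weights at
+//  once.)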
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, + FragS& s0, float* c4, float* c5, float* c6, + float* c7, FragS& s1) { + *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); + *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); + *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); + *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); + + *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); + *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); + *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); + *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); +} + +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/sparse_marlin.cpp b/torchao/csrc/sparse_marlin.cpp new file mode 100644 index 0000000000..70350dda9d --- /dev/null +++ b/torchao/csrc/sparse_marlin.cpp @@ -0,0 +1,8 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor"); +} \ No newline at end of file diff --git a/torchao/ops.py b/torchao/ops.py index cb337aabbe..5bb8271638 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -159,3 +159,105 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) + + +def marlin_24_gemm( + x: Tensor, + weight_marlin: Tensor, + meta: Tensor, + s: Tensor, + workspace: Tensor, + bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> Tensor: + """ + Sparse Marlin 2:4 matrix multiplication. Reference: https://github.com/IST-DASLab/Sparse-Marlin/tree/main + Args: + x: input matrix of shape `(n, k/2)` in column-major layout. + weight_marlin: weight matrix of original shape `(m, k)` in Marlin format; see `Layer.pack()`. + meta: metadata information for 2:4 sparsity. + s: scales of shape `(n / groupsize / 2, m)`. + workspace: tensor with at least `m / 128 * max_par` entries that are all zero. + bits: number of bits for quantization. + size_m: number of rows in input matrix. + size_n: number of columns in weight matrix. + size_k: number of columns in input matrix. + Returns: + output matrix of shape `(n, m)` in column-major layout. + """ + return torch.ops.torchao.marlin_24_gemm.default( + x, weight_marlin, meta, s, workspace, bits, size_m, size_n, size_k + ) + + +@register_custom_op(f"torchao::marlin_24_gemm") +def _( + x: Tensor, + weight_marlin: Tensor, + meta: Tensor, + s: Tensor, + workspace: Tensor, + bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> Tensor: + TILE_SIZE = 16 + MIN_THREAD_N = 128 + MAX_PARALLELISM = 64 + + # Verify num_bits + torch._check(bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. 
Got = {bits}") + pack_factor = 32 // bits + + # Verify M + torch._check(size_m == x.size(0), lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}") + + # Verify K + torch._check(size_k == x.size(1), lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}") + torch._check(size_k % TILE_SIZE == 0, lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}") + torch._check((size_k // TILE_SIZE // 2) == weight_marlin.size(0), lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}") + + # Verify N + torch._check(s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}") + torch._check(weight_marlin.size(1) % TILE_SIZE == 0, lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}") + + actual_size_n = (weight_marlin.size(1) // TILE_SIZE) * pack_factor + torch._check(size_n == actual_size_n, lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}") + + # Verify meta + torch._check(meta.size(0) == size_k // 8 // 2 // 2, lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}") + torch._check(meta.size(1) == size_n * 2, lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}") + + # Verify A device and strides + torch._check(x.is_cuda, lambda: "x is not on GPU") + torch._check(x.is_contiguous(), lambda: "x is not contiguous") + + # Verify B device and strides + torch._check(weight_marlin.is_cuda, lambda: "weight_marlin is not on GPU") + torch._check(weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous") + + # Verify meta device and strides + torch._check(meta.is_cuda, lambda: "meta is not on GPU") + torch._check(meta.is_contiguous(), lambda: "meta is not contiguous") + + # Verify scales device and strides + torch._check(s.is_cuda, lambda: "s is not on GPU") + torch._check(s.is_contiguous(), lambda: "s is not contiguous") + + # Verify groupsize + groupsize = -1 + if s.size(0) > 1: + torch._check(size_k % s.size(0) == 0, lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}") + groupsize = size_k // s.size(0) + groupsize //= 2 # Because of 24 + torch._check(groupsize == -1 or groupsize == 64, lambda: f"Unexpected groupsize = {groupsize}") + + # Verify workspace size + torch._check(size_n % MIN_THREAD_N == 0, lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}") + min_workspace_size = (size_n // MIN_THREAD_N) * MAX_PARALLELISM + torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}") + + return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device) diff --git a/torchao/sparsity/marlin/README.md b/torchao/sparsity/marlin/README.md new file mode 100644 index 0000000000..94d062365a --- /dev/null +++ b/torchao/sparsity/marlin/README.md @@ -0,0 +1,6 @@ +# Sparse Marlin + +Sparse Marlin implementation adapted from the two below sources: + +* [Sparse-Marlin](https://github.com/IST-DASLab/Sparse-Marlin/tree/main) +* [nm-vllm](https://github.com/neuralmagic/nm-vllm/tree/main) \ No newline at end of file diff --git a/torchao/sparsity/marlin/__init__.py b/torchao/sparsity/marlin/__init__.py new file mode 100644 index 0000000000..41a83be3d3 --- /dev/null +++ b/torchao/sparsity/marlin/__init__.py @@ -0,0 +1,351 @@ +import torch +import numpy as np +from typing import Tuple, Dict, List + +import 
torchao.sparsity.marlin.utils as utils +from torchao.sparsity.marlin.utils import const +from torchao.sparsity.utils import mask_creator + + +__all__ = [ + "inject_24", + "marlin_24_workspace", + "pack_to_marlin_24", + "unpack_from_marlin_24", +] + + +def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Injects 2:4 sparsity into a weight tensor. The sparsity is applied in a 2:4 ratio, where for every + group of 4 weights, 2 will be pruned based on their value. The mask will be created based on the + ranked weight values. + + Args: + w (torch.Tensor): The weight tensor to inject sparsity into. + size_k (int): The number of input features. + size_n (int): The number of output features. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The pruned weight tensor and the mask tensor. + """ + assert w.shape == (size_k, size_n) + mask = mask_creator(w.t()).t().cuda().bool() + return (mask * w).contiguous(), mask.contiguous() + + +def marlin_24_workspace( + out_features: int, + min_thread_n: int = const.MIN_THREAD_N, + max_parallel: int = const.MAX_PARALLEL + ) -> torch.Tensor: + """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks + during the execution of the kernel. + + Args: + out_features (int): The number of output features. + min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_24_MIN_THREAD_N`. + max_parallel (int, optional): The maximum number of parallel threads. Defaults to `MARLIN_24_MAX_PARALLEL`. + Returns: + torch.Tensor: The workspace tensor fully initialized with zeros. + """ + assert (out_features % min_thread_n == 0), f"out_features = {out_features}, min_thread_n = {min_thread_n}" + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def pack_to_marlin_24( + q_w_24: torch.Tensor, + scales: torch.Tensor, + num_bits: int, + group_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Packs the quantized weights and scales into the marlin 2:4 format. + + Args: + q_w_24 (torch.Tensor): The quantized weight tensor with 2:4 sparsity applied. + scales (torch.Tensor): The scale tensor. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size that was applied during quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The packed quantized weights, the packed scales, and the meta tensor. + """ + in_features, out_features = q_w_24.shape + + # Compress quantized weight + q_w_24_comp, meta = _compress_quantized_24_weight( + q_w_24, in_features, out_features, num_bits + ) + + in_features_comp = in_features // 2 + + # Reformat to marlin + marlin_24_q_w_comp = _to_marlin_weights( + q_w_24_comp, in_features_comp, out_features, num_bits + ) + + marlin_24_s = _to_marlin_scales( + scales, in_features, out_features, group_size, num_bits + ) + + return marlin_24_q_w_comp, marlin_24_s, meta + + +def unpack_from_marlin_24( + q_w_24_comp: torch.Tensor, + scales: torch.Tensor, + meta: torch.Tensor, + original_shape: torch.Size, + group_size: int, + num_bits: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Unpacks the quantized weights and scales from the marlin 2:4 format. + Args: + q_w_24_comp (torch.Tensor): The packed quantized weights. + scales (torch.Tensor): The packed scales. + meta (torch.Tensor): The meta tensor. + original_shape (torch.Size): The original shape of the weight tensor. 
+ group_size (int): The group size that was applied during quantization. + num_bits (int): The number of bits used for quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The unpacked quantized weights and scales. + """ + in_features, out_features = original_shape + + # Unpacks the scales + unpacked_scales = _from_marlin_scale( + scales, *original_shape, group_size, num_bits + ) + + in_features_comp = in_features // 2 + + # Unpacks the weights + unpacked_q_w_24_comp = _from_marlin_weights( + q_w_24_comp, in_features_comp, out_features, num_bits + ) + + # Decompress quantized weight + unpacked_q_w_24 = _decompress_quantized_24_weight( + unpacked_q_w_24_comp, meta, in_features_comp, out_features, num_bits + ) + + return unpacked_q_w_24, unpacked_scales + + +def _compress_quantized_24_weight( + q_24: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0 + before compressing them. + + Args: + q_24 (torch.Tensor): The quantized weight tensor. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The compressed quantized weight tensor and the meta tensor. + """ + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = utils.sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def _decompress_quantized_24_weight( + q_24_comp: torch.Tensor, + meta: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int + ) -> torch.Tensor: + """Decompresses the quantized weights from a 2:4 sparse format and restores the original shape. + + Args: + q_24_comp (torch.Tensor): The compressed quantized weight tensor in 2:4 sparse format. + meta (torch.Tensor): The meta tensor. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + Returns: + torch.Tensor: The decompressed quantized weight tensor. + """ + assert q_24_comp.shape == (size_k, size_n) + + # Resize meta back to its original shape + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp_comp = q_24_comp - zp + + # Decompress + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + q_24_no_zp = utils.sparse_semi_structured_to_dense_cutlass(q_24_no_zp_comp, meta) + q_24_no_zp = q_24_no_zp.t().contiguous() + + # Restore zp + q_24 = q_24_no_zp + zp + + return q_24 + + +def _to_marlin_weights( + q_w: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + ) -> torch.Tensor: + """Converts a quantized and 2:4 sparse format weight tensor to the marlin 2:4 format. + + Args: + q_w (torch.Tensor): The quantized weight tensor in 2:4 sparse format. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. 
+ Returns: + torch.Tensor: The weight tensor in the marlin 2:4 format. + """ + # Permute + q_w = utils.marlin_permute_weights(q_w, size_k, size_n, marlin_24_perm[num_bits]) + + # Pack + pack_factor = utils.get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + return q_packed + + +def _from_marlin_weights( + q_packed: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int + ) -> torch.Tensor: + """Converts a weight tensor in the marlin 2:4 format to a regular quantized 2:4 sparse format. + + Args: + q_packed (torch.Tensor): The weight tensor in the marlin 2:4 format. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + Returns: + torch.Tensor: The weight tensor in the quantized 2:4 sparse format. + """ + reverse_perm = reverse_marlin_24_perm[num_bits] + + pack_factor = utils.get_pack_factor(num_bits) + orig_device = q_packed.device + + # Unpack + q_packed = q_packed.cpu().numpy().astype(np.uint32) + q_w_unpacked = np.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1) + + q_w_unpacked = torch.from_numpy(q_w_unpacked.astype(np.int32)).to(orig_device) + + q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, reverse_perm) + return q_w_comp + + +def _to_marlin_scales( + scales: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, + num_bits: int + ) -> torch.Tensor: + """Converts a scale tensor to the format necessary for marlin. + Args: + scales (torch.Tensor): The scale tensor. + size_k (int): The number of input features. + size_n (int): The number of output features. + group_size (int): The group size that was applied during quantization. + num_bits (int): The number of bits used for quantization. + + Returns: + torch.Tensor: The scale tensor in the marlin format. + """ + if group_size < size_k and group_size != -1: + perms = marlin_24_scale_perm[num_bits] + scales = scales.reshape((-1, len(perms)))[:, perms] + else: + perms = marlin_24_scale_perm_single[num_bits] + scales = scales.reshape((-1, len(perms)))[:, perms] + scales = scales.reshape((-1, size_n)).contiguous() + return scales + + +def _from_marlin_scale( + scales: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, + num_bits: int + ) -> torch.Tensor: + """Converts a scale tensor from the marlin format to their original format. + + Args: + scales (torch.Tensor): The scale tensor in the marlin format. + size_k (int): The number of input features. + size_n (int): The number of output features. + group_size (int): The group size that was applied during quantization. + num_bits (int): The number of bits used for quantization. 
+ Returns: + torch.Tensor: The scale tensor in their original format + """ + if group_size < size_k and group_size != -1: + reverse_perms = reverse_marlin_24_scale_perm[num_bits] + scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms] + return scales.reshape((size_k // group_size, size_n)) + else: + reverse_perms = reverse_marlin_24_scale_perm_single[num_bits] + scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms] + return scales.reshape((1, -1)) + + +# Contains the permutations for marlin 2:4 quantization +marlin_24_perm: Dict[int, torch.Tensor] = {} +marlin_24_scale_perm: Dict[int, List[int]] = {} +marlin_24_scale_perm_single: Dict[int, List[int]] = {} + +# Contains the reverse permutations for marlin 2:4 quantization +reverse_marlin_24_perm: Dict[int, torch.Tensor] = {} +reverse_marlin_24_scale_perm: Dict[int, List[int]] = {} +reverse_marlin_24_scale_perm_single: Dict[int, List[int]] = {} + +for num_bits in const.SUPPORTED_NUM_BITS: + perm_24, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits) + + marlin_24_perm[num_bits] = perm_24 + marlin_24_scale_perm[num_bits] = scale_perm_24 + marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 + + reverse_marlin_24_perm[num_bits] = perm_24.argsort() + reverse_marlin_24_scale_perm[num_bits] = torch.tensor(scale_perm_24).argsort() + reverse_marlin_24_scale_perm_single[num_bits] = torch.tensor(scale_perm_single_24).argsort() \ No newline at end of file diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py new file mode 100644 index 0000000000..08b8f1efce --- /dev/null +++ b/torchao/sparsity/marlin/utils.py @@ -0,0 +1,417 @@ +import torch +import numpy as np +from typing import List, Tuple +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class Marlin24Constants: + TILE: int = 16 + MIN_THREAD_N: int = 128 + MAX_PARALLEL: int = 64 + + # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTLayout + SUPPORTED_NUM_BITS: List[int] = field(default_factory=lambda: [4, 8]) + SUPPORTED_GROUP_SIZES: List[int] = field(default_factory=lambda: [-1, 32, 64, 128]) +const = Marlin24Constants() + + +def get_pack_factor(num_bits: int) -> int: + """Compute the packing factor for a given number of bits. + + Args: + num_bits (int): Number of bits to pack. + Returns: + int: The packing factor. + """ + + assert num_bits in const.SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def marlin_permute_weights( + q_w: torch.Tensor, + size_k: int, + size_n: int, + perm: torch.Tensor, + tile: int = const.TILE + ) -> torch.Tensor: + """Permute weights to 16x64 Marlin tiles. + + Args: + q_w (torch.Tensor): Quantized weights. + size_k (int): Number of input features. + size_n (int): Number of output features. + perm (torch.Tensor): The computed permutation tensor to be applied. + tile (int, optional): Tile size. Defaults to `TILE`. + Returns: + torch.Tensor: Weight tensor permuted to Marlin tiles. 
+    """
+
+    assert q_w.shape == (size_k, size_n)
+    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
+
+    # Permute weights to 16x64 marlin tiles
+    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+    q_w = q_w.permute((0, 2, 1, 3))
+    q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+    return q_w
+
+
+def reverse_marlin_permute_weights(
+        q_w_unpacked: torch.Tensor,
+        size_k: int,
+        size_n: int,
+        reverse_perm: torch.Tensor,
+        tile: int = const.TILE,
+    ) -> torch.Tensor:
+    """Reverse permute weights from 16x64 Marlin tiles.
+    Args:
+        q_w_unpacked (torch.Tensor): Unpacked quantized weights.
+        size_k (int): Number of input features.
+        size_n (int): Number of output features.
+        reverse_perm (torch.Tensor): The computed reverse permutation tensor to be applied.
+        tile (int, optional): Tile size. Defaults to `TILE`.
+    Returns:
+        torch.Tensor: Weight tensor reverse permuted from Marlin tiles.
+    """
+
+    assert (q_w_unpacked.shape[0], size_n) == (size_k // tile, q_w_unpacked.shape[1] // tile)
+    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
+
+    # Reverse permute weights to original shape
+    q_w_comp = q_w_unpacked.reshape((-1, reverse_perm.numel()))[:, reverse_perm].reshape(q_w_unpacked.shape)
+    q_w_comp = q_w_comp.reshape((size_k // tile, size_n // tile, tile, tile))
+    q_w_comp = q_w_comp.permute((0, 2, 1, 3))
+    q_w_comp = q_w_comp.reshape((size_k, size_n))
+
+    return q_w_comp
+
+
+
+def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
+    """Precompute permutations for Marlin24 weight and scale shuffling.
+
+    Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible
+    with the tensor-core format that is described here:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+
+    As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core
+    (without the need to use ldmatrix instructions).
+
+    Args:
+        num_bits (int): Number of bits to pack.
+    Returns:
+        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list and
+            scale permutation list for single group.
+ """ + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
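The comment above introduces `sparse_semi_structured_from_dense_cutlass`, defined next, and its inverse `sparse_semi_structured_to_dense_cutlass` further below. As a hedged illustration (not part of the patch), here is a small CPU round-trip sketch; it assumes the two helpers are importable from `torchao.sparsity.marlin.utils` as added in this patch, that `mask_creator` is available from `torchao.sparsity.utils`, and that the shapes respect the divisibility checks enforced below (for fp16, `m % 32 == 0` and `k % 16 == 0`).

```python
import torch
from torchao.sparsity.marlin.utils import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)
from torchao.sparsity.utils import mask_creator

m, k = 32, 16                           # fp16 path: m % 32 == 0, k % 16 == 0
w = torch.randn(m, k)
dense = (w * mask_creator(w)).half()    # enforce 2:4 sparsity per group of 4

sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
assert sparse.shape == (m, k // 2)      # two kept values per group of four
reconstructed = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(reconstructed, dense)
```

For an input that is exactly 2:4 sparse, the compression and decompression should round-trip exactly, which is what the `pack_to_marlin_24` / `unpack_from_marlin_24` paths in `__init__.py` rely on.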
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
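+    # The boolean expressions below implement the truth table above:
+    # idxs0 (bits 1:0 of each 4-bit code) is the position of the first kept
+    # value in the quadruple and idxs1 (bits 3:2) the position of the second,
+    # with the documented fallbacks when a quadruple has fewer or more than
+    # two non-zero elements.
+    # Example: [False, True, False, True] -> idxs0 = 1, idxs1 = 3 (code 0b1101).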
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of sparse matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) \ No newline at end of file diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index c88e2f8bdf..0669c3cd70 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -6,6 +6,7 @@ "create_block_sparse_tensor", "create_semi_structured_tensor", "PerChannelNormObserver", + "mask_creator", ] def create_block_sparse_tensor(M, N, blocksize, sparsity, dtype): @@ -86,3 +87,37 @@ def calculate_qparams(self): raise NotImplementedError( "PerChannelNormObserver is designed to store activations only. " ) + + +def mask_creator( + tensor: torch.Tensor, + N: int = 2, + M: int = 4, + ) -> torch.Tensor: + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+    :param tensor: The input tensor to create a mask for
+    :param N: The number of weights in a group to keep
+    :param M: The size of a weight group
+    :return: A mask tensor with the same shape as the input tensor
+    """
+    mask = None
+    # for i, tensor in enumerate(tensors):
+    if tensor.numel() % M != 0:
+        raise ValueError(
+            f"Tensor of size {tensor.shape} can't be evenly divided into "
+            f"groups of {M}")
+
+    num_groups = tensor.numel() // M
+
+    # N:M sparsity for linear layers
+    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
+    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]
+
+    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
+    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
+
+    return mask
diff --git a/torchao/utils.py b/torchao/utils.py
index 9239fc999f..329d4790f8 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -131,6 +131,9 @@ def wrapper(*args, **kwargs):
         return wrapper
     return decorator
 
+def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref))
 
 def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded
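For reference, a small illustrative example of the new `mask_creator` helper (values chosen arbitrarily); with the defaults it keeps the two largest-magnitude weights in every group of four:

```python
import torch
from torchao.sparsity.utils import mask_creator

w = torch.tensor([[0.9, -0.1, 0.4, 0.05, 2.0, 0.3, -0.2, 1.1]])
mask = mask_creator(w)   # 2:4 by default: keep 2 of every 4, ranked by |weight|
print(mask)              # tensor([[1., 0., 1., 0., 1., 0., 0., 1.]])
w_24 = w * mask          # 2:4-sparse weights
```

The same helper backs `inject_24` in `torchao.sparsity.marlin`, which applies this mask to a weight matrix before quantization and packing.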