From 28a7cf39873fc90a080859f27102444d8d4bdf7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Diogo=20Ven=C3=A2ncio?= Date: Fri, 6 Sep 2024 15:14:24 +0100 Subject: [PATCH] Add sparse marlin AQT layout (#621) * feat: starting layout implementation fix: namespace of common modules chore: remove not needed test file fix: op name being registered chore: can compile the cuda kernel fix: segmentation fault chore: wip - paste test code just to check if everything passes feat: wip - adding layout. unpack not working fix: circular import feat: wip - can almost revert feat: can unpack. just needs cleanup chore: improve layout code chore: wip - mm needs work feat: wip - something seems wrong fix: e2e test feat: wip - add group param fix: unpack weights feat: marlin is implemented and correct chore: rebase chore: remove old import feat: use int4 instead of dequantizing chore: remove unused fn feat: add checks and validation feat: add new kernel and refactor code (#1) * feat: wip - adding new kernel * feat: wip - continue working on the unpack * feat: wip - working on unpacking * feat: remove old op * feat: more code changes * chore: remove old code * feat: more code * chore: more code changes * chore: more code changes * feat: add more documentation * fix: dataclass * feat: add more docs * feat: remove assert chore: block 8 bits chore: update comment feat: refactor dispatch chore: add validation on group size chore: wip - working on fixing unpack feat: add small readme with sources feat: add checks feat: tests pass & can execute llama2 * compile kind of working * fix: batching and layout outputs correct results * fix: torch.compile * wip * feat: wip * chore: cleanup * chore: review * chore: review v2 * update benchmarks + README --------- Co-authored-by: Jesse Cai --- test/sparsity/test_marlin.py | 115 +++++++++ test/sparsity/test_sparse_api.py | 30 ++- test/test_ops.py | 24 +- torchao/_models/llama/benchmark_results.txt | 1 + torchao/_models/llama/benchmarks.sh | 4 +- torchao/_models/llama/generate.py | 3 + torchao/_models/sam/benchmark.sh | 2 + torchao/_models/sam/eval_combo.py | 34 ++- torchao/_models/sam/results.csv | 1 + .../cuda/sparse_marlin/marlin_kernel_nm.cu | 2 +- torchao/csrc/sparse_marlin.cpp | 2 +- torchao/dtypes/__init__.py | 2 + torchao/dtypes/affine_quantized_tensor.py | 242 +++++++++++++++++- torchao/quantization/README.md | 37 +-- torchao/quantization/quant_api.py | 12 +- torchao/sparsity/README.md | 17 ++ torchao/sparsity/marlin/README.md | 2 +- torchao/sparsity/marlin/__init__.py | 68 ++--- torchao/sparsity/marlin/utils.py | 41 ++- torchao/sparsity/utils.py | 1 - 20 files changed, 538 insertions(+), 102 deletions(-) create mode 100644 test/sparsity/test_marlin.py diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py new file mode 100644 index 0000000000..c12f32ef6a --- /dev/null +++ b/test/sparsity/test_marlin.py @@ -0,0 +1,115 @@ +import torch +import copy +import pytest + +from torch import nn +from torch.testing._internal.common_utils import TestCase, run_tests +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.dtypes import MarlinSparseLayoutType +from torchao.sparsity.sparse_api import apply_fake_sparsity +from torchao.quantization.quant_api import int4_weight_only, quantize_ +from torchao.sparsity.marlin import ( + pack_to_marlin_24, + unpack_from_marlin_24, + inject_24 +) +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + quantize_affine, + ZeroPointDomain, + MappingType, +) + + +class SparseMarlin24(TestCase): + + 
def setUp(self):
+        super().setUp()
+        torch.manual_seed(0)
+
+        self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
+        self.model = (
+            nn.Sequential(
+                nn.Linear(4096, 21504),
+                nn.Linear(21504, 4096),
+                nn.ReLU(),
+                nn.Linear(4096, 21504),
+                nn.Linear(21504, 4096),
+            )
+            .half()
+            .cuda()
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    def test_quant_sparse_marlin_layout_eager(self):
+        apply_fake_sparsity(self.model)
+        model_copy = copy.deepcopy(self.model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), int4_weight_only())
+        dense_result = model_copy(self.input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        sparse_result = self.model(self.input)
+
+        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    def test_quant_sparse_marlin_layout_compile(self):
+        apply_fake_sparsity(self.model)
+        model_copy = copy.deepcopy(self.model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), int4_weight_only())
+        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
+        dense_result = model_copy(self.input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        self.model.forward = torch.compile(self.model.forward, fullgraph=True)
+        sparse_result = self.model(self.input)
+
+        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    def test_pack_unpack_equivalence(self):
+        num_bits = 4
+        group_size = 128
+        shape = (11008, 4096)
+        block_size = (1, group_size)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        zero_point_dtype = torch.bfloat16
+        mapping_type = MappingType.SYMMETRIC
+        preserve_zero = True
+        zero_point_domain = ZeroPointDomain.INT
+        scale_dtype = None
+
+        w = torch.rand(shape, dtype=torch.float16, device="cuda")
+
+        # Inject 2:4 sparsity mask
+        w_24, _ = inject_24(w, *w.shape)
+
+        # Quantize weights
+        scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+        w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
+        scales = scales.reshape(-1, w_q_24.shape[1])
+
+        # Test pack/unpack equivalence
+        q_w_comp, packed_scales, meta = pack_to_marlin_24(
+            w_q_24, scales, num_bits, group_size
+        )
+        unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
+            q_w_comp, packed_scales, meta, shape, group_size, num_bits
+        )
+
+        assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
+        assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 824dd08f63..2732c8c9ff 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -11,12 +11,11 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     semi_sparse_weight,
 )
+from torchao.dtypes import MarlinSparseLayoutType
 from torchao.quantization.quant_api import (
-
_replace_with_custom_fn_if_matches_filter, - _get_subclass_inserter, - _is_linear, int8_dynamic_activation_int8_weight, quantize_, + int4_weight_only, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 from torch.testing._internal.common_utils import TestCase @@ -73,5 +72,30 @@ def test_quant_semi_sparse(self): assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_sparse_marlin(self): + input = torch.rand((256, 256)).half().cuda() + model = ( + nn.Sequential( + nn.Linear(256, 1024), + nn.Linear(1024, 256), + ) + .half() + .cuda() + ) + + apply_fake_sparsity(model) + model_copy = copy.deepcopy(model) + + # Quantized + quantize_(model_copy.bfloat16(), int4_weight_only()) + dense_result = model_copy(input.bfloat16()).half() + + # Sparse + quantized + quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + sparse_result = model(input) + + assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + if __name__ == "__main__": unittest.main() diff --git a/test/test_ops.py b/test/test_ops.py index eb22f40ad5..e62766756c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -304,6 +304,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size ) +MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64] MARLIN_24_K_CHUNKS = [128] MARLIN_24_N_CHUNKS = [512] MNK_FACTORS = [ @@ -318,8 +319,8 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] MARLIN_TEST_PARAMS = list(itertools.product( - MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS, - MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS + MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, + MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS )) def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int): @@ -374,15 +375,15 @@ def reshape_w(w): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) -def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors): +@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str) +def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor size_k = k_chunk * k_factor size_n = n_chunk * n_factor - a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda") + a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda") b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda") # Inject 2:4 sparsity @@ -391,19 +392,24 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors): # Symmetric quantize w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size) + # Reshape input into 2D tensor + input_2d = a_input.view(-1, a_input.shape[-1]) + a_input_in, a_input_out = input_2d.shape + # Obtains reference output - output_ref = torch.matmul(a_input, w_24_ref) + output_ref = torch.matmul(input_2d, w_24_ref) + output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],)) # Packs to marlin 2:4 marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size) workspace_24 = 
marlin_24_workspace(size_n) fn_inputs = ( - a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, - num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1], + input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, + num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out, ) output = torchao.ops.marlin_24_gemm(*fn_inputs) - torch.cuda.synchronize() + output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],)) max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index 3bea35cc49..0868c964db 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -38,3 +38,4 @@ kv cache quantization: 20240826171015, tok/s= 1.95, mem/s= 29.21 GB/s, peak_mem=59.27 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072 20240826172121, tok/s= 1.73, mem/s= 26.02 GB/s, peak_mem=52.62 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization 20240826173230, tok/s= 1.73, mem/s= 25.95 GB/s, peak_mem=34.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization --linear_causal_mask +20240906054415, tok/s=226.02, mem/s= 689.20 GB/s, peak_mem= 5.32 GB, model_size= 3.05 GB quant: marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index c86406735a..7d2bda1d6e 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -30,13 +30,15 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt - +# sparse marlin (NOTE: float16) +python generate.py --checkpoint_path 
$CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt # auto-round w/ quant_lm_head python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 + export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index e452bf9708..2d0b2a035f 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -226,6 +226,9 @@ def main( groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, int4_weight_only(group_size=groupsize)) + if "marlin" in quantization: + from torchao.dtypes import MarlinSparseLayoutType + quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) if "autoround" in quantization: from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_ from transformers import AutoTokenizer diff --git a/torchao/_models/sam/benchmark.sh b/torchao/_models/sam/benchmark.sh index c52ce33151..cfcf52792d 100755 --- a/torchao/_models/sam/benchmark.sh +++ b/torchao/_models/sam/benchmark.sh @@ -8,3 +8,5 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse) python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse +# int8 dynamic quant attn + int4 wo + sparse marlin lin 1 + 2:4 sparse lin2 +python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half float16 --device cuda --compress int4_weight_only_sparse diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index 46d3af8248..4fbd13d16a 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -283,6 +283,16 @@ def run( for block in predictor.model.image_encoder.blocks: block.attn.use_rel_pos = use_rel_pos + # Helper filter functions + def attn_only(mod, name): + return isinstance(mod, torch.nn.Linear) and 'attn' in name + def mlp_lin1_only(mod, 
name):
+        return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+    def mlp_lin2_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+    def mlp_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+
     if compress == "int8_dynamic_quant":
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
         if not TORCH_VERSION_AT_LEAST_2_5:
@@ -296,15 +306,6 @@ def mlp_only(mod, name):
         apply_fake_sparsity(predictor.model.image_encoder)
         sparsify_(predictor.model.image_encoder, semi_sparse_weight())
     elif compress == "int8_dynamic_quant_sparse":
-        def attn_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'attn' in name
-        def mlp_lin1_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'lin1' in name
-        def mlp_lin2_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'lin2' in name
-        def mlp_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-
         # apply sparsify first to set qparams
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
@@ -320,7 +321,20 @@ def mlp_only(mod, name):
                   mlp_lin2_only)
         if not TORCH_VERSION_AT_LEAST_2_5:
             predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
-
+    elif compress == "int4_weight_only_sparse":
+        # apply sparsify first to set qparams
+        apply_fake_sparsity(predictor.model.image_encoder,
+                            filter_fn=mlp_only)
+        from torchao.dtypes import MarlinSparseLayoutType
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_weight(),
+                  attn_only)
+        quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
+        sparsify_(predictor.model.image_encoder,
+                  semi_sparse_weight(),
+                  mlp_lin2_only)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
diff --git a/torchao/_models/sam/results.csv b/torchao/_models/sam/results.csv
index 0be02c7f37..5ae8a6dbd0 100644
--- a/torchao/_models/sam/results.csv
+++ b/torchao/_models/sam/results.csv
@@ -4,3 +4,4 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
 cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,17068,21,23.96093702681232,41.73459489004953,0.5485481164943489,max-autotune,torch.float16,int4_weight_only_sparse,False,True,True,32,154,4928,None,None
diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
index 29c17d1bdd..4e9e757ff0 100644
--- a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
+++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -1123,4 +1123,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::marlin_24_gemm", &marlin_24_gemm);
 }
 
-} // namespace torchao
\ No newline at end of file
+} // namespace torchao
diff --git a/torchao/csrc/sparse_marlin.cpp b/torchao/csrc/sparse_marlin.cpp
index 70350dda9d..b11da8a74e 100644
--- a/torchao/csrc/sparse_marlin.cpp
+++
b/torchao/csrc/sparse_marlin.cpp @@ -5,4 +5,4 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor"); -} \ No newline at end of file +} diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 1ffadcf432..e27bf6497a 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -15,6 +15,7 @@ TensorCoreTiledLayoutType, Float8LayoutType, Float8AQTLayout, + MarlinSparseLayoutType, ) __all__ = [ @@ -33,4 +34,5 @@ "TensorCoreTiledLayoutType", "Float8LayoutType", "Float8AQTLayout", + "MarlinSparseLayoutType", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index f4e5446ba5..784c3c5d87 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,5 +1,6 @@ import torch -from typing import Dict, Callable, Any, Tuple, Optional, Union +from typing import Tuple, Optional, Union +import torchao.ops from collections import defaultdict import functools import math @@ -41,6 +42,7 @@ from torchao.float8.inference import Float8MMConfig aten = torch.ops.aten + ############################### # Base Layout Tensor Subclass # ############################### @@ -489,6 +491,27 @@ class Float8LayoutType(LayoutType): mm_config: Optional[Float8MMConfig] = None +@dataclass(frozen=True) +class MarlinSparseLayoutType(LayoutType): + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. + - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format + - 2º: tensor is injected with 2:4 sparsity + - 3º: transposes it again because the quantization process will compute the scales for dim=-1 + + Args: + input (torch.Tensor): the input tensor to preprocess + + Returns: + torch.Tensor: the preprocessed tensor + """ + from torchao.sparsity.marlin import inject_24 # avoid circular import + input_t = input.t() + w_24, _ = inject_24(input_t, *input_t.shape) + return w_24.t() + + @register_layout_cls(PlainLayoutType) class PlainAQTLayout(AQTLayout): """ @@ -642,6 +665,176 @@ def from_plain( return cls(int_data_compressed, scale, zero_point, layout_type) +@register_layout_cls(MarlinSparseLayoutType) +class MarlinSparseAQTLayout(AQTLayout): + """ + Layout storage class for sparse_marlin_24 layout for affine quantized tensor. + + Can be used with 4 bits and 8 bits quantization. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Sparse marlin documentation and information: + https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file + + fields: + original_shape (torch.Size): the original shape of the tensor. 
used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + meta: torch.Tensor, + layout_type: LayoutType, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + meta: torch.Tensor, + layout_type: LayoutType, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self.meta = meta + self.layout_type = layout_type + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + zero_point = tensor_data_dict["zero_point"] + meta = tensor_data_dict["meta"] + layout_type, original_shape, group_size, num_bits = tensor_attributes + return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits) + + def get_plain(self): + from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import + int_data_expanded, scales_expanded = unpack_from_marlin_24( + self.int_data, + self.scale, + self.meta, + self.original_shape, + self.group_size, + self.num_bits, + ) + int_data_expanded_t = int_data_expanded.t() + scales_expanded_t = scales_expanded.t() + return int_data_expanded_t, scales_expanded_t, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import + assert isinstance(layout_type, MarlinSparseLayoutType) + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w_24 = int_data.t() + scale_t = scale.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' 
+            )
+
+        if q_w_24.dtype != torch.int32:
+            raise ValueError("Only `torch.int32` weights are supported.")
+
+        in_features, out_features = q_w_24.shape
+        if in_features % 128 != 0 or out_features % 256 != 0:
+            raise ValueError(
+                "`in_features` must be divisible by 128 and `out_features` by 256."
+            )
+
+        # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
+        # will require a bit more work to get our current quantization flow to work with it.
+        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
+        num_bits = 4 if torch.max(q_w_24) < 16 else -1
+        if num_bits not in [4]:
+            raise ValueError(
+                f"Only {[4]} bits are supported, got {num_bits}."
+            )
+
+        group_size = in_features // scale_t.shape[0]
+        if group_size == 0:
+            group_size = in_features
+        assert group_size <= in_features, "Group size must be less than or equal to in_features."
+
+        if group_size not in const.SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
+            )
+
+        # Compress quantized weight to marlin 2:4 format
+        marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size)
+
+        return cls(
+            marlin_24_q_w_comp, marlin_24_s, zero_point,
+            meta, layout_type, q_w_24.shape,
+            group_size, num_bits
+        )
+
+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
+    def _apply_fn_to_data(self, fn):
+        self.int_data = fn(self.int_data)
+        self.scale = fn(self.scale)
+        self.zero_point = fn(self.zero_point)
+        self.meta = fn(self.meta)
+        return self
+
+
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
     """
@@ -758,7 +951,7 @@ def __repr__(self):
                f"scale={scale},\n"
                f"transposed={self.transposed}, "
                f"layout_type={layout_type})")
-
+
 
 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
@@ -941,6 +1134,7 @@ def _aqt_is_uint4(aqt):
         aqt.quant_max is None or aqt.quant_max == 15
     )
 
+
 implements = AffineQuantizedTensor.implements
 
 # following are a list of (dispatch_condition, implementation) functions that takes the following args:
@@ -1010,7 +1204,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
+        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16,
     ).t()
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
@@ -1219,6 +1413,47 @@ def _linear_fp_act_fp8_weight_impl(
     ).reshape(out_shape)
 
+
+def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
+    return (
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        _aqt_is_uint4(weight_tensor) and
+        input_tensor.dtype == torch.float16 and
+        len(weight_tensor.shape) == 2 and
+        weight_tensor.zero_point_domain == ZeroPointDomain.INT and
+        isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
+    )
+
+def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
+    from torchao.sparsity.marlin import marlin_24_workspace, const
+    assert isinstance(weight_tensor, AffineQuantizedTensor)
+
+    sparse_w_int4 = weight_tensor.layout_tensor.int_data
+    scale = weight_tensor.layout_tensor.scale
+    meta =
weight_tensor.layout_tensor.meta + original_shape = weight_tensor.layout_tensor.original_shape + num_bits = weight_tensor.layout_tensor.num_bits + + # Folds batch dimension into the first dimension + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) + + size_m = input_2d.shape[0] + size_n = scale.shape[1] + size_k = input_2d.shape[1] + workspace_24 = marlin_24_workspace(original_shape[1]) + + out = torchao.ops.marlin_24_gemm( + input_2d, sparse_w_int4, meta, scale, + workspace_24, num_bits, size_m, size_n, size_k + ) + + # Unfold the batch dimension + out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out + + def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), @@ -1227,6 +1462,7 @@ def _register_aqt_quantized_linear_dispatches(): (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), + (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 33bef5dd4f..092b5daf75 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -5,24 +5,25 @@ Typically quantization algorithms will have different schemes for how the activa Benchmarks are run on a machine with a single A100 GPU using the script in _models/llama which generates text in a latency optimized way (batchsize=1), evaluation was done Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B. 
-| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | -| ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | -| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | -| | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 | -| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | -| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | -| | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 | -| | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 | -| | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 | -| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 | -| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | -| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | -| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | -| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | -| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | -| | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 | -| | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 | -| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | +| | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 | +| | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | +| | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | +| | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 | +| | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 | +| | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 | +| | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 | +| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | +| | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | +| | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | +| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-sparse-marlin | N/A | 226.02 | 689.20 | 5.32 | 3.05 | +| | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 | +| | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 | +| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 239c369d2c..8b48c66c29 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -29,7 +29,8 @@ PlainLayoutType, AffineQuantizedTensor, SemiSparseLayoutType, - Float8LayoutType + Float8LayoutType, + MarlinSparseLayoutType, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, @@ -536,6 +537,15 @@ def apply_int4_weight_only_quant(weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT + + # Sparse Marlin only supports symmetric quantization. + # NOTE: If we start having lots of layouts that require different configurations, + # we should consider moving this logic somewhere else. 
+ if isinstance(layout_type, MarlinSparseLayoutType): + mapping_type = MappingType.SYMMETRIC + preserve_zero = True + zero_point_domain = ZeroPointDomain.INT + return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type, use_hqq=use_hqq) return _get_linear_subclass_inserter(apply_int4_weight_only_quant) diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index b8e04be1b0..aa69eb90a2 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -58,6 +58,23 @@ For more information about accelerting BERT with semi-sturcutred sparsity, pleas | F1 (%) | 86.93 | 86.49 | -0.44 | | Time (bs=16) | 19.35 | 15.74 | 1.23x | +# Implemented APIs + +## Quantization + Sparsity + +### Sparse Marlin 2:4 + +Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regressive Linear (Marlin) dense kernel to support 4-bit quantized weights and 2:4 sparsity, improving performance in matrix multiplication and accumulation. Full documentation can be found [here](https://github.com/IST-DASLab/Sparse-Marlin). + +```py +from torchao.quantization.quant_api import quantize_, int4_weight_only +from torchao.dtypes import MarlinSparseLayoutType + +# Your FP16 model +model = model.cuda().half() + +quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) +``` # Design diff --git a/torchao/sparsity/marlin/README.md b/torchao/sparsity/marlin/README.md index 94d062365a..6ed1759915 100644 --- a/torchao/sparsity/marlin/README.md +++ b/torchao/sparsity/marlin/README.md @@ -3,4 +3,4 @@ Sparse Marlin implementation adapted from the two below sources: * [Sparse-Marlin](https://github.com/IST-DASLab/Sparse-Marlin/tree/main) -* [nm-vllm](https://github.com/neuralmagic/nm-vllm/tree/main) \ No newline at end of file +* [nm-vllm](https://github.com/neuralmagic/nm-vllm/tree/main) diff --git a/torchao/sparsity/marlin/__init__.py b/torchao/sparsity/marlin/__init__.py index 41a83be3d3..3cb45a271e 100644 --- a/torchao/sparsity/marlin/__init__.py +++ b/torchao/sparsity/marlin/__init__.py @@ -1,5 +1,4 @@ import torch -import numpy as np from typing import Tuple, Dict, List import torchao.sparsity.marlin.utils as utils @@ -201,6 +200,9 @@ def _decompress_quantized_24_weight( q_24_no_zp = utils.sparse_semi_structured_to_dense_cutlass(q_24_no_zp_comp, meta) q_24_no_zp = q_24_no_zp.t().contiguous() + # Revert meta resize + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + # Restore zp q_24 = q_24_no_zp + zp @@ -224,19 +226,21 @@ def _to_marlin_weights( torch.Tensor: The weight tensor in the marlin 2:4 format. """ # Permute - q_w = utils.marlin_permute_weights(q_w, size_k, size_n, marlin_24_perm[num_bits]) + perm_24, _, _ = utils.get_perms_24(num_bits) + q_w = utils.marlin_permute_weights(q_w, size_k, size_n, perm_24) # Pack pack_factor = utils.get_pack_factor(num_bits) orig_device = q_w.device - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32 + # does not support rshift_cpu. 
+ q_w = q_w.cpu().to(torch.int64) + q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64, device=q_w.device) for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i + q_packed |= q_w[:, i::pack_factor] << (num_bits * i) - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + q_packed = q_packed.to(orig_device, dtype=torch.int32) return q_packed @@ -256,20 +260,22 @@ def _from_marlin_weights( Returns: torch.Tensor: The weight tensor in the quantized 2:4 sparse format. """ - reverse_perm = reverse_marlin_24_perm[num_bits] + perm_24, _, _ = utils.get_reverse_perms_24(num_bits) pack_factor = utils.get_pack_factor(num_bits) orig_device = q_packed.device - # Unpack - q_packed = q_packed.cpu().numpy().astype(np.uint32) - q_w_unpacked = np.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=np.uint32) + # Unpack from marlin format. + # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32 + # does not support rshift_cpu. + q_packed = q_packed.cpu().to(torch.int64) + q_w_unpacked = torch.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=torch.int64, device=q_packed.device) for i in range(pack_factor): q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1) - q_w_unpacked = torch.from_numpy(q_w_unpacked.astype(np.int32)).to(orig_device) + q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32) - q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, reverse_perm) + q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, perm_24) return q_w_comp @@ -291,12 +297,11 @@ def _to_marlin_scales( Returns: torch.Tensor: The scale tensor in the marlin format. 
""" + _, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits) if group_size < size_k and group_size != -1: - perms = marlin_24_scale_perm[num_bits] - scales = scales.reshape((-1, len(perms)))[:, perms] + scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24] else: - perms = marlin_24_scale_perm_single[num_bits] - scales = scales.reshape((-1, len(perms)))[:, perms] + scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24] scales = scales.reshape((-1, size_n)).contiguous() return scales @@ -319,33 +324,10 @@ def _from_marlin_scale( Returns: torch.Tensor: The scale tensor in their original format """ + _, scale_perm_24, scale_perm_single_24 = utils.get_reverse_perms_24(num_bits) if group_size < size_k and group_size != -1: - reverse_perms = reverse_marlin_24_scale_perm[num_bits] - scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms] + scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24] return scales.reshape((size_k // group_size, size_n)) else: - reverse_perms = reverse_marlin_24_scale_perm_single[num_bits] - scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms] + scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24] return scales.reshape((1, -1)) - - -# Contains the permutations for marlin 2:4 quantization -marlin_24_perm: Dict[int, torch.Tensor] = {} -marlin_24_scale_perm: Dict[int, List[int]] = {} -marlin_24_scale_perm_single: Dict[int, List[int]] = {} - -# Contains the reverse permutations for marlin 2:4 quantization -reverse_marlin_24_perm: Dict[int, torch.Tensor] = {} -reverse_marlin_24_scale_perm: Dict[int, List[int]] = {} -reverse_marlin_24_scale_perm_single: Dict[int, List[int]] = {} - -for num_bits in const.SUPPORTED_NUM_BITS: - perm_24, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits) - - marlin_24_perm[num_bits] = perm_24 - marlin_24_scale_perm[num_bits] = scale_perm_24 - marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 - - reverse_marlin_24_perm[num_bits] = perm_24.argsort() - reverse_marlin_24_scale_perm[num_bits] = torch.tensor(scale_perm_24).argsort() - reverse_marlin_24_scale_perm_single[num_bits] = torch.tensor(scale_perm_single_24).argsort() \ No newline at end of file diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py index 08b8f1efce..4ebdf432e3 100644 --- a/torchao/sparsity/marlin/utils.py +++ b/torchao/sparsity/marlin/utils.py @@ -1,5 +1,4 @@ import torch -import numpy as np from typing import List, Tuple from dataclasses import dataclass, field @@ -93,7 +92,6 @@ def reverse_marlin_permute_weights( return q_w_comp - def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]: """Precompute permutations for Marlin24 weight and scale shuffling @@ -107,8 +105,8 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]: Args: num_bits (int): Number of bits to pack. Returns: - Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list and - scale permutation list for single group. + Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list, and + scale permutation list for a single group. 
""" perm_list: List[int] = [] for i in range(32): @@ -126,23 +124,46 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]: 4 * block) for j in range(4): perm_list.extend([p + 1 * j for p in perm1]) - perm = np.array(perm_list) + + # Convert to torch tensor + perm = torch.tensor(perm_list, dtype=torch.int32) if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + interleave = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], dtype=torch.int32) elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) + interleave = torch.tensor([0, 2, 1, 3], dtype=torch.int32) else: raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) + # Reshape and apply interleave + perm = perm.view(-1, len(interleave))[:, interleave].reshape(-1) + scale_perm: List[int] = [] for i in range(8): scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] for i in range(8): scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + + return perm, scale_perm, scale_perm_single + + +def get_reverse_perms_24(num_bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Reverse permutation for Marlin24 weight and scale shuffling from `get_perms_24`. + + Args: + num_bits (int): Number of bits to pack. + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The reversed weight permutation tensor, scale permutation list and + scale permutation list for single group. + """ + perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) + + perm = perm_24.argsort() + scale_perm = torch.tensor(scale_perm_24).argsort() + scale_perm_single = torch.tensor(scale_perm_single_24).argsort() + return perm, scale_perm, scale_perm_single @@ -414,4 +435,4 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): dense.view(torch.half).scatter_(0, dense_offsets, sparse.view(torch.half).view(-1)) - return dense.view(m, 2 * k) \ No newline at end of file + return dense.view(m, 2 * k) diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 0669c3cd70..c383ff87ff 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -37,7 +37,6 @@ def create_semi_structured_tensor( sparse_weight = torch.rand(r, c).to(dtype).cuda() * mask return sparse_weight - # Observers class PerChannelNormObserver(UniformQuantizationObserverBase): """