diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 7e8ac41795..784c3c5d87 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -40,6 +40,7 @@ logger = logging.getLogger(__name__)
 
 from torchao.float8.inference import Float8MMConfig
 
+aten = torch.ops.aten
 
 
 ###############################
@@ -682,11 +683,6 @@ class MarlinSparseAQTLayout(AQTLayout):
         group_size (int): the group size used to pack the tensor
         num_bits (int): the number of bits used to quantize the tensor
     """
-
-    implements = classmethod(_implements)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-
     @staticmethod
     def __new__(
         cls,
@@ -729,6 +725,19 @@ def __init__(
         self.group_size = group_size
         self.num_bits = num_bits
 
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
     def __tensor_flatten__(self):
         return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]
 
@@ -826,12 +835,6 @@ def _apply_fn_to_data(self, fn):
         return self
 
 
-# Marlin Sparse op dispatch registration
-@MarlinSparseAQTLayout.implements(aten.detach.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
-
-
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
     """
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 72b4b72aaf..4cbece2df2 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -454,7 +454,9 @@ def apply_int4_weight_only_quant(weight):
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
 
-        # Sparse Marlin only supports symmetric quantization
+        # Sparse Marlin only supports symmetric quantization.
+        # NOTE: If we start having lots of layouts that require different configurations,
+        # we should consider moving this logic somewhere else.
         if isinstance(layout_type, MarlinSparseLayoutType):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True
diff --git a/torchao/sparsity/marlin/__init__.py b/torchao/sparsity/marlin/__init__.py
index 4652e1687f..3cb45a271e 100644
--- a/torchao/sparsity/marlin/__init__.py
+++ b/torchao/sparsity/marlin/__init__.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import Tuple, Dict, List
 
 import torchao.sparsity.marlin.utils as utils
diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py
index f2d7d7efe9..9d404cb82c 100644
--- a/torchao/sparsity/marlin/utils.py
+++ b/torchao/sparsity/marlin/utils.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import List, Tuple
 from dataclasses import dataclass, field
 
@@ -97,17 +96,13 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
     """Precompute permutations for Marlin24 weight and scale shuffling
 
     Marlin works on [16*2,64] tiles.
     The goal of the permutations is to reorder the weight data so that it is compatible
-    with the tensor-core format that is described here:
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
-
-    As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core
-    (without the need to use ldmatrix instructions)
+    with the tensor-core format.
 
     Args:
         num_bits (int): Number of bits to pack.
 
     Returns:
-        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list and
-            scale permutation list for single group.
+        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list, and
+            scale permutation list for a single group.
     """
     perm_list: List[int] = []
     for i in range(32):
@@ -125,23 +120,28 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
                          4 * block)
         for j in range(4):
             perm_list.extend([p + 1 * j for p in perm1])
-    perm = np.array(perm_list)
+
+    # Convert to torch tensor
+    perm = torch.tensor(perm_list, dtype=torch.int32)
 
     if num_bits == 4:
-        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+        interleave = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], dtype=torch.int32)
     elif num_bits == 8:
-        interleave = np.array([0, 2, 1, 3])
+        interleave = torch.tensor([0, 2, 1, 3], dtype=torch.int32)
     else:
         raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
 
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
+    # Reshape and apply interleave
+    perm = perm.view(-1, len(interleave))[:, interleave].reshape(-1)
+
     scale_perm: List[int] = []
     for i in range(8):
         scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
+
     scale_perm_single: List[int] = []
     for i in range(8):
         scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
+
     return perm, scale_perm, scale_perm_single
diff --git a/wip_test_llama2.py b/wip_test_llama2.py
deleted file mode 100644
index 75d5cb7f5d..0000000000
--- a/wip_test_llama2.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# This script shows how to accelerate an off-the-shelf 2:4 sparse checkpoint
-# using pytorch's `to_sparse_semi_structured`
-
-# Also shows how to use marlin
-
-# It takes advantage of the model checkpoints offered by neuralmagic:
-# https://huggingface.co/nm-testing/SparseLlama-3-8B-pruned_50.2of4-FP8
-
-import os
-import torch
-from torchao.sparsity import sparsify_, semi_sparse_weight
-
-from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from torchao.utils import benchmark_model, profiler_runner
-from torchao.quantization import int4_weight_only, quantize_
-from torchao.dtypes import MarlinSparseLayoutType
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
-model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
-warmup = 5
-num_runs = 25
-
-torch.set_float32_matmul_precision('high')
-
-
-# Even though we need to pad the matmul shapes from (1, hidden) @ (hidden, output)
-# to (8, hidden) @ (hidden, output) we are still able to achieve speedups on
-# the mlp.up and mlp.gate linear layers of the FFN.
-def is_mlp_up_or_mlp_gate(mod, name):
-    return isinstance(mod, torch.nn.Linear) and ('mlp.gate' in name or 'mlp.up' in name)
-
-def run_benchmark(compression_config="baseline", dtype=torch.float16):
-    print (f"\n Running: {compression_config} benchmark with dtype={dtype}\n")
-
-    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).cuda()
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    prompt = "Why dogs are so cute?"
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-
-    # Specify the max length (including both the prompt and the response)
-    # When calling `generate` with `cache_implementation="static" later, this is also used to create a `StaticCache` object
-    # with sequence length = `max_length`. The longer the more you will re-use it
-    model.generation_config.max_length = 128
-    model.generation_config.pad_token_id = tokenizer.eos_token_id
-    model.generation_config.cache_implementation = "static"
-
-    if compression_config == "24_sparse":
-        sparsify_(model, semi_sparse_weight(), filter_fn=is_mlp_up_or_mlp_gate)
-    elif compression_config == "int4_wo":
-        assert dtype == torch.bfloat16, "int4 quantization only works with bf16"
-        quantize_(model, int4_weight_only())
-    elif compression_config == "sparse_marlin":
-        assert dtype == torch.float16, "sparse_marlin only works with fp16"
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
-    elif compression_config == "baseline":
-        pass
-    else:
-        raise ValueError(f"Unknown compression config: {compression_config}")
-
-    # `torch.compile(model, ...)` is not recommended as you compile callbacks
-    # and full generate. We recommend compiling only the forward for now.
-    # "reduce-overhead" will use cudagraphs.
-    torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit = None
-
-    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
-
-    # WARMUP
-    benchmark_model(lambda: model.generate(**inputs), warmup, device_type="cuda")
-    # res is in ms so multiply by 1000 to get tok/s
-    res = benchmark_model(lambda: model.generate(**inputs), num_runs, device_type="cuda")
-    tokens_per_second = 1000 * (121 / res)
-    print(f"Average time: {res:.3f}ms | Tokens/second: {tokens_per_second:.3f}")
-
-    # sanity check we get same output as non-compiled model
-    outputs = model.generate(**inputs)
-    response = tokenizer.batch_decode(outputs)[0]
-    print(response)
-
-    del model
-
-## baseline
-# run_benchmark(compression_config="baseline", dtype=torch.bfloat16)
-
-# # ## int4_wo
-# run_benchmark(compression_config="int4_wo", dtype=torch.bfloat16)
-
-# ## sparse marlin
-run_benchmark(compression_config="sparse_marlin", dtype=torch.float16)
-
-## sparse
-# run_benchmark(compression_config="24_sparse", dtype=torch.bfloat16)
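
Note (reviewer sketch, not part of the patch): the numpy removal in get_perms_24 assumes that
torch's view -> advanced indexing -> reshape(-1) produces the same flattened order as numpy's
reshape -> fancy indexing -> ravel. A minimal standalone check of that assumption, using a
placeholder perm_list instead of the real Marlin permutation, could look like this:

# Sketch only: compares the removed numpy interleave path against the new
# pure-torch path in get_perms_24, on a stand-in permutation list.
import numpy as np
import torch

perm_list = list(range(64))            # placeholder, not the real Marlin permutation
interleave = [0, 2, 4, 6, 1, 3, 5, 7]  # the num_bits == 4 pattern from the diff

# Removed path: numpy reshape -> fancy index -> ravel -> torch.from_numpy
old = torch.from_numpy(
    np.array(perm_list).reshape((-1, len(interleave)))[:, np.array(interleave)].ravel()
)

# New path: torch.tensor -> view -> index -> reshape
new = (
    torch.tensor(perm_list, dtype=torch.int32)
    .view(-1, len(interleave))[:, torch.tensor(interleave)]
    .reshape(-1)
)

assert torch.equal(old.to(torch.int32), new)
print("numpy and torch interleave paths agree")

Dropping numpy here keeps the permutation precomputation entirely in torch, so
torchao.sparsity.marlin no longer needs numpy at import time.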