Commit

chore: review
Diogo-V committed Sep 5, 2024
1 parent ef17ed6 commit ea70c74
Showing 3 changed files with 61 additions and 55 deletions.
93 changes: 39 additions & 54 deletions test/sparsity/test_marlin.py
@@ -4,6 +4,7 @@

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
@@ -12,20 +13,22 @@
unpack_from_marlin_24,
inject_24
)
from torchao.quantization.utils import (
get_group_qparams_symmetric,
groupwise_affine_quantize_tensor_from_qparams,
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
ZeroPointDomain,
MappingType,
)


class SparseMarlin24(TestCase):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_eager(self):
def setUp(self):
super().setUp()
torch.manual_seed(0)

input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
model = (
self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
self.model = (
nn.Sequential(
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
@@ -37,48 +40,38 @@ def test_quant_sparse_marlin_layout_eager(self):
.cuda()
)

apply_fake_sparsity(model)
model_copy = copy.deepcopy(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_eager(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(input.bfloat16()).half()
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = model(input)
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_compile(self):
torch.manual_seed(0)

input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
model = (
nn.Sequential(
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
nn.ReLU(),
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
)
.half()
.cuda()
)

apply_fake_sparsity(model)
model_copy = copy.deepcopy(model)
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
dense_result = model_copy(input.bfloat16()).half()
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
model.forward = torch.compile(model.forward, fullgraph=True)
sparse_result = model(input)
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(self.model)

self.model.forward = torch.compile(self.model.forward, fullgraph=True)
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@@ -87,34 +80,26 @@ def test_pack_unpack_equivalence(self):
num_bits = 4
group_size = 128
shape = (11008, 4096)
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
zero_point_dtype = torch.bfloat16
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
scale_dtype = None

w = torch.rand(shape, dtype=torch.float16, device="cuda")
size_k, size_n = w.shape

# Inject 2:4 sparsity mask
w_24, _ = inject_24(w, *w.shape)

# Quantize weights
w_24 = w_24.reshape((-1, group_size, size_n))
w_24 = w_24.permute(1, 0, 2)
w_24 = w_24.reshape((group_size, -1))

# Compute scale for each group
scales = torch.max(torch.abs(w_24), 0, keepdim=True)[0]
scales *= 2 / max_q_val # 2 => symmetric

# Quantize
w_q_24 = torch.round(w_24 / scales).int()
w_q_24 += half_q_val
w_q_24 = torch.clamp(w_q_24, 0, max_q_val)

# Shape back to original shape
w_q_24 = w_q_24.reshape((group_size, -1, size_n))
w_q_24 = w_q_24.permute(1, 0, 2)
w_q_24 = w_q_24.reshape((size_k, size_n)).contiguous()
scales = scales.reshape((-1, size_n)).contiguous()
scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
scales = scales.reshape(-1, w_q_24.shape[1])

# Test pack/unpack equivalence
q_w_comp, packed_scales, meta = pack_to_marlin_24(
17 changes: 17 additions & 0 deletions torchao/sparsity/README.md
@@ -58,6 +58,23 @@ For more information about accelerating BERT with semi-structured sparsity, please
| F1 (%) | 86.93 | 86.49 | -0.44 |
| Time (bs=16) | 19.35 | 15.74 | 1.23x |

# Implemented APIs

## Quantization + Sparsity

### Sparse Marlin 2:4

Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regressive Linear (Marlin) dense kernel to support 4-bit quantized weights and 2:4 sparsity, improving performance in matrix multiplication and accumulation. Full documentation can be found [here](https://github.com/IST-DASLab/Sparse-Marlin).

```py
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType

# Your FP16 model
model = model.cuda().half()

quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
```
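After quantization, the model is used like any other PyTorch module. A minimal usage sketch follows (the input shape and the `torch.compile` call are illustrative assumptions, not requirements of the API above):

```py
import torch

# Hypothetical input; use the shape and dtype your model expects (FP16 on CUDA here)
x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Optional: compile the quantized + sparse model for additional speedups
model = torch.compile(model, fullgraph=True)

with torch.no_grad():
    out = model(x)
```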

# Design

6 changes: 5 additions & 1 deletion torchao/sparsity/marlin/utils.py
@@ -96,7 +96,11 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
"""Precompute permutations for Marlin24 weight and scale shuffling
Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible
with the tensor-core format.
with the tensor-core format that is described here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for the tensor cores
(without the need to use ldmatrix instructions).
Args:
num_bits (int): Number of bits to pack.
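For intuition, the reordering described in the docstring above boils down to gathering weight values with a precomputed index permutation. A minimal sketch (illustrative only, not the torchao implementation; `perm` is a random stand-in for the real tensor-core permutation returned by `get_perms_24`):

```py
import torch

w_q = torch.randint(0, 16, (16 * 2, 64), dtype=torch.int32)  # one Marlin [16*2, 64] tile of 4-bit values
perm = torch.randperm(w_q.numel())                            # stand-in for the precomputed permutation
w_reordered = w_q.reshape(-1)[perm].reshape(w_q.shape)        # gather values into the reordered layout
```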
