
Commit ef17ed6

chore: cleanup

Diogo-V committed Sep 5, 2024
1 parent cf5e286 commit ef17ed6
Showing 5 changed files with 30 additions and 118 deletions.
25 changes: 14 additions & 11 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -40,6 +40,7 @@
 logger = logging.getLogger(__name__)
 
 from torchao.float8.inference import Float8MMConfig
+aten = torch.ops.aten
 
 
 ###############################
@@ -682,11 +683,6 @@ class MarlinSparseAQTLayout(AQTLayout):
         group_size (int): the group size used to pack the tensor
         num_bits (int): the number of bits used to quantize the tensor
     """
-
-    implements = classmethod(_implements)
-    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
-    __torch_function__ = classmethod(_dispatch__torch_function__)
-
     @staticmethod
     def __new__(
         cls,
@@ -729,6 +725,19 @@ def __init__(
         self.group_size = group_size
         self.num_bits = num_bits
 
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
     def __tensor_flatten__(self):
         return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]
 
@@ -826,12 +835,6 @@ def _apply_fn_to_data(self, fn):
         return self
 
 
-# Marlin Sparse op dispatch registration
-@MarlinSparseAQTLayout.implements(aten.detach.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
-
-
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
     """
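For context, a minimal runnable sketch of the wrapper-subclass dispatch pattern this commit moves to: a hypothetical toy class (not the real MarlinSparseAQTLayout) whose explicit __torch_dispatch__ override intercepts aten.detach.default directly on the class, replacing the per-op @implements registration removed above. The return_and_correct_aliasing helper comes from torch.utils._python_dispatch.

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten


class ToyLayoutTensor(torch.Tensor):
    """Hypothetical wrapper subclass; illustrates the pattern only."""

    @staticmethod
    def __new__(cls, elem):
        # Wrapper subclass carrying the same shape/dtype metadata as `elem`.
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    def _apply_fn_to_data(self, fn):
        self.elem = fn(self.elem)
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        raise NotImplementedError(f"ToyLayoutTensor dispatch: {func} is not supported")


t = ToyLayoutTensor(torch.randn(4))
t.detach()  # routes through __torch_dispatch__ and hits the aten.detach.default branch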
4 changes: 3 additions & 1 deletion torchao/quantization/quant_api.py
@@ -454,7 +454,9 @@ def apply_int4_weight_only_quant(weight):
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
 
-        # Sparse Marlin only supports symmetric quantization
+        # Sparse Marlin only supports symmetric quantization.
+        # NOTE: If we start having lots of layouts that require different configurations,
+        # we should consider moving this logic somewhere else.
         if isinstance(layout_type, MarlinSparseLayoutType):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True
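As a usage sketch for the branch above (assumptions flagged: the quantize_/int4_weight_only entry points and the layout_type keyword follow this repo's quant_api at the time of the commit, and the MarlinSparseLayoutType import path is illustrative), selecting the sparse Marlin layout is what flips quantization to the symmetric settings shown:

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType  # import path assumed

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).half().cuda()
# Passing the Marlin sparse layout makes apply_int4_weight_only_quant take the
# MappingType.SYMMETRIC / preserve_zero branch above.
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))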
1 change: 0 additions & 1 deletion torchao/sparsity/marlin/__init__.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import Tuple, Dict, List
 
 import torchao.sparsity.marlin.utils as utils
26 changes: 13 additions & 13 deletions torchao/sparsity/marlin/utils.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from typing import List, Tuple
 from dataclasses import dataclass, field
 
@@ -97,17 +96,13 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
     """Precompute permutations for Marlin24 weight and scale shuffling
     Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible
-    with the tensor-core format that is described here:
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
-    As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core
-    (without the need to use ldmatrix instructions)
+    with the tensor-core format.
     Args:
         num_bits (int): Number of bits to pack.
     Returns:
-        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list and
-        scale permutation list for single group.
+        Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list, and
+        scale permutation list for a single group.
     """
     perm_list: List[int] = []
     for i in range(32):
@@ -125,23 +120,28 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
                     4 * block)
         for j in range(4):
             perm_list.extend([p + 1 * j for p in perm1])
-    perm = np.array(perm_list)
+
+    # Convert to torch tensor
+    perm = torch.tensor(perm_list, dtype=torch.int32)
 
     if num_bits == 4:
-        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+        interleave = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], dtype=torch.int32)
     elif num_bits == 8:
-        interleave = np.array([0, 2, 1, 3])
+        interleave = torch.tensor([0, 2, 1, 3], dtype=torch.int32)
     else:
         raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
 
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
+    # Reshape and apply interleave
+    perm = perm.view(-1, len(interleave))[:, interleave].reshape(-1)
 
     scale_perm: List[int] = []
     for i in range(8):
         scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
 
     scale_perm_single: List[int] = []
     for i in range(8):
         scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
 
     return perm, scale_perm, scale_perm_single
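A small sanity check, grounded only in the loops shown above: both scale permutation lists enumerate 8 x 8 = 64 entries, and after this change perm stays a torch.int32 tensor end to end, with no numpy round-trip.

perm, scale_perm, scale_perm_single = get_perms_24(num_bits=4)
assert perm.dtype == torch.int32              # built via torch.tensor, not np.array
assert len(scale_perm) == 64                  # 8 outer iterations x 8 entries each
assert len(scale_perm_single) == 64
assert sorted(scale_perm) == list(range(64))  # a permutation of 0..63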


92 changes: 0 additions & 92 deletions wip_test_llama2.py

This file was deleted.
