Add layout option to woq int4 api #670

Merged 3 commits on Aug 14, 2024
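For context, this PR replaces the `inner_k_tiles` keyword of `int4_weight_only` with a `layout_type` argument. A minimal sketch of the call-site change, assuming the `quantize_` entry point shown in the diff below; the module and shapes here are hypothetical, and int4 tinygemm packing generally expects bfloat16 weights on CUDA:

import torch
from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import quantize_, int4_weight_only

# Hypothetical model to quantize (bfloat16 on CUDA assumed for the tinygemm kernel).
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).to("cuda")

# Before this PR, inner_k_tiles was a separate keyword argument:
#   quantize_(model, int4_weight_only(group_size=128, inner_k_tiles=8))

# After this PR, the packing choice travels inside a layout object:
quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))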
9 changes: 6 additions & 3 deletions test/integration/test_integration.py
@@ -19,6 +19,7 @@
from torchao.quantization.dynamic_quant import (
DynamicallyPerAxisQuantizedLinear,
)
+from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
@@ -852,18 +853,20 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
for groupsize in [64, 32]:
for inner_k_tiles in [4, 2]:
kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}

def api(mod):
+kwargs_copy = kwargs.copy()
if TORCH_VERSION_AFTER_2_4:
-kwargs_copy = kwargs.copy()
kwargs_copy["group_size"] = groupsize
del kwargs_copy["groupsize"]
quantize_(mod, int4_weight_only(**kwargs_copy))
if not TORCH_VERSION_AFTER_2_5:
unwrap_tensor_subclass(mod)
else:
-change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+kwargs_copy["inner_k_tiles"] = inner_k_tiles
+del kwargs_copy["layout_type"]
+change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

self._test_lin_weight_subclass_api_impl(
api,
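The updated test keeps the pre-2.4 fallback working by translating the new-style kwargs back into the legacy ones before calling `change_linear_weights_to_int4_woqtensors`. A standalone sketch of that translation; the helper name is hypothetical, and it assumes `TensorCoreTiledLayoutType` exposes the `inner_k_tiles` value it was constructed with:

from torchao.dtypes import TensorCoreTiledLayoutType

def to_legacy_int4_kwargs(kwargs):
    # New-style: {"groupsize": ..., "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=...)}
    legacy = kwargs.copy()
    layout_type = legacy.pop("layout_type")
    legacy["inner_k_tiles"] = layout_type.inner_k_tiles  # assumed dataclass field
    return legacy

# Example:
# to_legacy_int4_kwargs({"groupsize": 64, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=4)})
# -> {"groupsize": 64, "inner_k_tiles": 4}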
37 changes: 10 additions & 27 deletions torchao/quantization/quant_api.py
@@ -21,7 +21,14 @@
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes import PlainLayoutType
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
+from torchao.dtypes import (
+to_affine_quantized,
Contributor:
thanks, if there is no circular dep you can remove the import from other functions as well, e.g. int8_weight_only

Contributor Author:
done

+TensorCoreTiledLayoutType,
+PlainLayoutType,
+AffineQuantizedTensor,
+SemiSparseLayoutType
+)
from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
@@ -182,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(


def _is_linear(mod, *args):
-# avoid circular dep
-from torchao.dtypes import AffineQuantizedTensor
-
# adding weight tensor subclass isinstance check to make sure the weight is only quantized once
# when it is shared by multiple linear modules
return (
@@ -328,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
)

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-# avoid circular dep
-from torchao.dtypes import to_affine_quantized
-
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int8
return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -339,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
if weight.shape[-1] % group_size != 0:
return weight

-# avoid circular dep
-from torchao.dtypes import to_affine_quantized
-
# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
@@ -373,7 +371,7 @@ def insert_subclass(lin):
return insert_subclass


-def int4_weight_only(group_size=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
"""
Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
"tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -389,16 +387,12 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
Args:
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, choices are [256, 128, 64, 32]
-`inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
+`layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
"""
def apply_int4_weight_only_quant(weight):
if weight.shape[-1] % group_size != 0:
return weight

-# avoid circular dep
-from torchao.dtypes import to_affine_quantized
-from torchao.dtypes import TensorCoreTiledLayoutType
-
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int32
@@ -408,7 +402,6 @@ def apply_int4_weight_only_quant(weight):
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
-layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
@@ -419,9 +412,6 @@ def int8_weight_only():
Applies int8 weight-only symmetric per-channel quantization to linear layers.
"""
def apply_int8wo_quant(weight):
-# avoid circular dep
-from torchao.dtypes import to_affine_quantized
-
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
@@ -432,8 +422,6 @@ def apply_int8wo_quant(weight):
return _get_linear_subclass_inserter(apply_int8wo_quant)

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-# avoid circular dep
-from torchao.dtypes import to_affine_quantized
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = 1e-5
@@ -453,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
if in_features <= 16:
return weight

-# avoid circular dep
-from torchao.dtypes import to_affine_quantized
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
@@ -479,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
quantization + 2:4 sparsity to linear layers.
"""
-from torchao.dtypes import SemiSparseLayoutType
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


@@ -495,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
quantize_affine,
dequantize_affine,
)
-from torchao.dtypes.uintx.Uintx import UintxLayoutType
-from torchao.dtypes import to_affine_quantized
from torchao.quantization.quant_api import _get_linear_subclass_inserter
def apply_uintx_weight_only_quant(weight):
