compile kind of working
jcaip authored and Diogo-V committed Aug 26, 2024
1 parent a3f32f9 commit 9bcc422
Showing 4 changed files with 172 additions and 42 deletions.
45 changes: 41 additions & 4 deletions test/sparsity/test_marlin.py
@@ -21,8 +21,9 @@
class SparseMarlin24(TestCase):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_e2e(self):
input = torch.randn((16, 4096), dtype=torch.float16, device="cuda")
def test_quant_sparse_marlin_layout_eager(self):
# this batch input fails
input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
model = (
nn.Sequential(
nn.Linear(4096, 11008), # Llama2 shapes
@@ -35,20 +36,57 @@ def test_quant_sparse_marlin_layout_e2e(self):
.cuda()
)

apply_fake_sparsity(model)
# Baseline
ref_result = model(input)
model_copy = copy.deepcopy(model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(input.bfloat16()).half()

# Sparse + quantized
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = model(input)

error_dense = torch.mean(torch.abs(ref_result - dense_result) ** 2)
error_sparse = torch.mean(torch.abs(ref_result - sparse_result) ** 2)
assert torch.allclose(error_dense, error_sparse, atol=1e-2), "Mean error is not close"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_compile(self):
input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
model = (
nn.Sequential(
nn.Linear(4096, 11008), # Llama2 shapes
# nn.Linear(11008, 4096),
# nn.ReLU(),
# nn.Linear(4096, 11008),
# nn.Linear(11008, 4096),
)
.half()
.cuda()
)

# Baseline
apply_fake_sparsity(model)
ref_result = model(input)

model_copy = copy.deepcopy(model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
dense_result = model_copy(input.bfloat16()).half()

# Sparse + quantized
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
model.forward = torch.compile(model.forward, fullgraph=True)
sparse_result = model(input)

print(dense_result)
print(sparse_result)
torch.allclose(sparse_result, dense_result)

error_dense = torch.mean(torch.abs(ref_result - dense_result) ** 2)
error_sparse = torch.mean(torch.abs(ref_result - sparse_result) ** 2)
assert torch.allclose(error_dense, error_sparse, atol=1e-2), "Mean error is not close"
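# A minimal sketch of the error-comparison pattern used in the two tests above;
# `mean_squared_error` is a hypothetical helper name, not a torchao API.
def mean_squared_error(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Mean of squared elementwise differences, computed in the reference dtype.
    return torch.mean(torch.abs(ref - out.to(ref.dtype)) ** 2)

# Usage: the sparse + quantized error should stay close to the dense quantized error.
# assert torch.allclose(mean_squared_error(ref_result, dense_result),
#                       mean_squared_error(ref_result, sparse_result), atol=1e-2)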
Expand All @@ -70,7 +108,6 @@ def test_pack_unpack_equivalence(self):
)

scales = scales.reshape(-1, w_q_24.shape[1])

# Test pack/unpack equivalence
q_w_comp, packed_scales, meta = pack_to_marlin_24(
w_q_24, scales, num_bits, group_size
67 changes: 48 additions & 19 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -58,10 +58,13 @@ def from_plain(
):
pass

@torch._dynamo.disable
def __repr__(self):
int_data, scale, zero_point = self.get_plain()
layout_type = self.get_layout_type()
return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
# This is a hack: torch.compile tries to trace the __repr__ function, which then calls the `dequantize` function, causing an error.
# By removing the call to dequantize, the error goes away.
# int_data, scale, zero_point = self.get_plain()
# layout_type = self.get_layout_type()
return f"{self.__class__.__name__}" #(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
@@ -152,10 +155,13 @@ def __init__(
self.quant_max = quant_max
self.zero_point_domain = zero_point_domain

@torch._dynamo.disable
def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
f"{self.__class__.__name__}"
# Same hack here
#(data={self.dequantize()}, shape={self.shape}, "
#f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)

def dequantize(self, output_dtype=None):
@@ -552,6 +558,8 @@ class MarlinSparseAQTLayout(AQTLayout):
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

@staticmethod
@torch._dynamo.disable
def __new__(
cls,
int_data: torch.Tensor,
@@ -573,6 +581,7 @@ def __new__(
shape = int_data.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

@torch._dynamo.disable
def __init__(
self,
int_data: torch.Tensor,
@@ -593,8 +602,24 @@ def __init__(
self.group_size = group_size
self.num_bits = num_bits

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data = tensor_data_dict["int_data"]
scale = tensor_data_dict["scale"]
zero_point = tensor_data_dict["zero_point"]
meta = tensor_data_dict["meta"]
layout_type, original_shape, group_size, num_bits = tensor_attributes
return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits)
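# A minimal sketch of the flatten/unflatten round trip that torch.compile relies on for
# tensor subclasses; `layout_tensor` is an assumed, pre-built MarlinSparseAQTLayout instance.
#   names, attrs = layout_tensor.__tensor_flatten__()
#   inner = {name: getattr(layout_tensor, name) for name in names}
#   rebuilt = MarlinSparseAQTLayout.__tensor_unflatten__(inner, attrs, layout_tensor.shape, None)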

@torch._dynamo.disable
def get_plain(self):
from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import
unpack_from_marlin_24 = torch._dynamo.disable(unpack_from_marlin_24)
int_data_expanded, scales_expanded = unpack_from_marlin_24(
self.int_data,
self.scale,
@@ -606,6 +631,7 @@ def get_plain(self):
return int_data_expanded, scales_expanded, self.zero_point

@classmethod
@torch._dynamo.disable
def from_plain(
cls,
int_data: torch.Tensor,
@@ -674,7 +700,7 @@ def _apply_fn_to_data(self, fn):
@MarlinSparseAQTLayout.implements(aten.detach.default)
def block_sparse_detach(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))


@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
@@ -920,7 +946,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16,
).t()
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
@@ -1037,6 +1063,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):

def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
return (
isinstance(weight_tensor, AffineQuantizedTensor) and
_aqt_is_uint4(weight_tensor) and
input_tensor.dtype == torch.float16 and
len(weight_tensor.shape) == 2 and
@@ -1046,11 +1073,13 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,

def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
from torchao.sparsity.marlin import marlin_24_workspace, const
assert isinstance(weight_tensor, AffineQuantizedTensor)

sparse_w_int4 = weight_tensor.layout_tensor.int_data
scale = weight_tensor.layout_tensor.scale
meta = weight_tensor.layout_tensor.meta
original_shape = weight_tensor.layout_tensor.original_shape
print("original_shape", original_shape)
num_bits = weight_tensor.layout_tensor.num_bits

# Saves batch size for reshaping back to original shape after the matmul
@@ -1059,13 +1088,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
batch_size = -1
if input_tensor.dim() == 3:
batch_size = input_tensor.size(0)
input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1])

size_m = input_tensor.shape[0]
size_n = original_shape[1]
size_k = input_tensor.shape[1]
workspace_24 = marlin_24_workspace(original_shape[1])

print(size_m, size_n, size_k)

# Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
if size_k % const.TILE != 0:
pad_size = find_multiple(size_k, const.TILE)
@@ -1076,11 +1107,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
input_tensor, sparse_w_int4, meta, scale,
workspace_24, num_bits, size_m, size_n, size_k
)
torch.cuda.synchronize()

# Reshape back to original shape
if batch_size != -1:
out = out.reshape(batch_size, -1, out.shape[-1])
out = out.view(batch_size, -1, out.shape[-1])

if bias is not None:
out += bias.to(out.dtype)
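# A minimal sketch of the 3-D -> 2-D activation round trip performed above, using the
# Llama2 test shapes from test_marlin.py as assumed, illustrative values:
#   x = torch.randn(32, 16, 4096, dtype=torch.float16, device="cuda")
#   x2d = x.reshape(-1, x.shape[-1])           # (512, 4096), fed to the sparse Marlin matmul
#   out2d = ...                                # 2-D kernel result, shape (512, out_features)
#   out = out2d.view(32, -1, out2d.shape[-1])  # restored to (32, 16, out_features)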
@@ -1113,14 +1142,14 @@ def _(func, types, args, kwargs):
# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
try:
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
# try:
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
# except:
# if isinstance(input_tensor, AffineQuantizedTensor):
# input_tensor = input_tensor.dequantize()
# if isinstance(weight_tensor, AffineQuantizedTensor):
# weight_tensor = weight_tensor.dequantize()
# return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements(aten.addmm.default)
def _(func, types, args, kwargs):
3 changes: 1 addition & 2 deletions torchao/quantization/utils.py
@@ -362,8 +362,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
# Move to cpu, until issue with MPS memory management of temporary tensors is resolved
if int_data_device_type == 'mps':
int_data = int_data.cpu()
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if int_data_device_type == 'mps':
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
int_data = int_data.to(device='mps')
return int_data

99 changes: 82 additions & 17 deletions wip_test_llama2.py
@@ -1,24 +1,89 @@
# This script shows how to accelerate an off-the-shelf 2:4 sparse checkpoint
# using pytorch's `to_sparse_semi_structured`

# Also shows how to use marlin

# It takes advantage of the model checkpoints offered by neuralmagic:
# https://huggingface.co/nm-testing/SparseLlama-3-8B-pruned_50.2of4-FP8

import os
import torch
from torchao import quantize_
from torchao.quantization import int4_weight_only
from torchao.sparsity import sparsify_, semi_sparse_weight

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.utils import benchmark_model, profiler_runner
from torchao.quantization import int4_weight_only, quantize_
from torchao.dtypes import MarlinSparseLayoutType
from transformers import AutoTokenizer, LlamaForCausalLM

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
name = "meta-llama/Llama-2-7b-hf"
token = "your token"
os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling

torch.set_float32_matmul_precision('high')


# Even though we need to pad the matmul shapes from (1, hidden) @ (hidden, output)
# to (8, hidden) @ (hidden, output) we are still able to achieve speedups on
# the mlp.up and mlp.gate linear layers of the FFN.
def is_mlp_up_or_mlp_gate(mod, name):
return isinstance(mod, torch.nn.Linear) and ('mlp.gate' in name or 'mlp.up' in name)

def run_benchmark(compression_config="baseline", dtype=torch.float16):
print(f"\n Running: {compression_config} benchmark with dtype={dtype}\n")

model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", torch_dtype=dtype).cuda()
tokenizer = AutoTokenizer.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4")
prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Specify the max length (including both the prompt and the response).
# When calling `generate` with `cache_implementation="static"` later, this is also used to create a `StaticCache` object
# with sequence length = `max_length`. The longer it is, the more the cache will be re-used.
model.generation_config.max_length = 128
model.generation_config.pad_token_id = tokenizer.eos_token_id
model.generation_config.cache_implementation = "static"

if compression_config == "24_sparse":
sparsify_(model, semi_sparse_weight(), filter_fn=is_mlp_up_or_mlp_gate)
elif compression_config == "int4_wo":
assert dtype == torch.bfloat16, "int4 quantization only works with bf16"
quantize_(model, int4_weight_only())
elif compression_config == "sparse_marlin":
assert dtype == torch.float16, "sparse_marlin only works with fp16"
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
elif compression_config == "baseline":
pass
else:
raise ValueError(f"Unknown compression config: {compression_config}")

# `torch.compile(model, ...)` is not recommended because it also compiles the callbacks
# and the full generate loop. We recommend compiling only the forward pass for now.
# "reduce-overhead" will use CUDA graphs.
torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit = None

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# WARMUP
benchmark_model(lambda: model.generate(**inputs), 5, device_type="cuda")
# res is in ms so multiply by 1000 to get tok/s
res = benchmark_model(lambda: model.generate(**inputs), 25, device_type="cuda")
tokens_per_second = 1000 * (121 / res)
print(f"Average time: {res:.3f}ms | Tokens/second: {tokens_per_second:.3f}")

# sanity check we get same output as non-compiled model
outputs = model.generate(**inputs)
response = tokenizer.batch_decode(outputs)[0]
print(response)

del model

model = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16, token=token).to(device)
tokenizer = AutoTokenizer.from_pretrained(name, token=token)
## baseline
# run_benchmark(compression_config="baseline", dtype=torch.bfloat16)

prompt = "Hey, are you conscious? Can you talk to me? I'm"
inputs = tokenizer(prompt, return_tensors="pt")
# # ## int4_wo
run_benchmark(compression_config="int4_wo", dtype=torch.bfloat16)

# Quantize
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
# ## sparse marlin
# run_benchmark(compression_config="sparse_marlin", dtype=torch.float16)

# Generate
ids = inputs.input_ids.to(device)
generate_ids = model.generate(ids, max_length=30)
out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(out)
## sparse
# run_benchmark(compression_config="24_sparse", dtype=torch.bfloat16)
