Update calls to quantize_ everywhere (pytorch#496)
Summary:
att

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jul 10, 2024
1 parent 25d2dec commit 7d250fd
Showing 4 changed files with 29 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/source/api_ref_quantization.rst
@@ -16,7 +16,7 @@ torchao.quantization
     smooth_fq_linear_to_inference
     Int4WeightOnlyGPTQQuantizer
     Int4WeightOnlyQuantizer
-    quantize
+    quantize_
     int8_dynamic_activation_int4_weight
     int8_dynamic_activation_int8_weight
     int4_weight_only
27 changes: 14 additions & 13 deletions scripts/hf_eval.py
@@ -14,10 +14,11 @@
 """)
     raise # Re-raise the ImportError
 
-from torchao.quantization.quant_api import (
-    change_linear_weights_to_int4_woqtensors,
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
+from torchao.quantization import (
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
     autoquant,
 )

@@ -36,32 +37,32 @@ def format_value(value):
         subtable.sort(key=lambda x: x[0])  # Sort metrics alphabetically
         formatted_subtable = tabulate(subtable, tablefmt='grid')
         main_table.append([task, formatted_subtable])
 
     print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))
 
 def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
     if quantization == "int8dq":
-        change_linear_weights_to_int8_dqtensors(model)
+        quantize_(model, int8_dynamic_activation_int8_weight())
     elif quantization == "int8wo":
-        change_linear_weights_to_int8_woqtensors(model)
+        quantize_(model, int8_weight_only())
     elif quantization == "int4wo":
         # note cannot quantize this model on cpu and run it on cuda at this time
-        change_linear_weights_to_int4_woqtensors(model.to(device=device))
+        quantize_(model.to(device=device), int4_weight_only())
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
     with torch.no_grad():
         result = evaluate(
             HFLM(
-            pretrained=model.to(device),
-            tokenizer=tokenizer,
-            batch_size=batch_size,
+                pretrained=model.to(device),
+                tokenizer=tokenizer,
+                batch_size=batch_size,
                 max_length=max_length),
             get_task_dict(tasks),
             limit = limit,
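For context, the hf_eval.py change above amounts to replacing the per-format `change_linear_weights_to_*` helpers with the single `quantize_` API plus a config object. A minimal sketch of the new-style call (not part of this commit), assuming torch 2.4+ and a toy `nn.Linear` model in place of the Hugging Face model:

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# toy stand-in for the AutoModelForCausalLM loaded in hf_eval.py
model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# was: change_linear_weights_to_int8_dqtensors(model)
quantize_(model, int8_dynamic_activation_int8_weight())

# the other branches follow the same pattern, one config per format:
#   int8_weight_only()  replaces change_linear_weights_to_int8_woqtensors
#   int4_weight_only()  replaces change_linear_weights_to_int4_woqtensors
```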
18 changes: 10 additions & 8 deletions scripts/sam/eval_combo.py
@@ -279,9 +279,9 @@ def run(
         block.attn.use_rel_pos = use_rel_pos
 
     if compress == "int8_dynamic_quant":
-        from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+        from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
         from torchao.utils import unwrap_tensor_subclass
-        predictor.model.image_encoder = quantize(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
+        quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
         predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     elif compress == "sparse_mlp_only":
         def mlp_only(mod, name):
@@ -300,7 +300,7 @@ def mlp_only(mod, name):
         SparseSemiStructuredTensor._FORCE_CUTLASS = False
         from torchao.sparsity import sparsify, apply_fake_sparsity
         from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
-        from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+        from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
         from torchao.utils import unwrap_tensor_subclass
 
         def attn_only(mod, name):
@@ -316,9 +316,11 @@ def mlp_only(mod, name):
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
 
-        predictor.model.image_encoder = quantize(predictor.model.image_encoder,
-                                                 int8_dynamic_activation_int8_weight(),
-                                                 attn_only)
+        quantize_(
+            predictor.model.image_encoder,
+            int8_dynamic_activation_int8_weight(),
+            attn_only
+        )
         predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
 
         predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
@@ -380,7 +382,7 @@ def mlp_only(mod, name):
                     batch_size,
                     use_compile,
                     use_compile_decoder,
-                    pad_input_image_batch,
+                    pad_input_image_batch,
                     compress)

results = [[r[0], r[1], r[2], r[3].item()] for r in results]
@@ -411,6 +413,6 @@ def mlp_only(mod, name):
vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
f.write(vals+"\n")

if __name__ == '__main__':
fire.Fire(run)
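A detail visible in the eval_combo.py hunks above: `quantize_` mutates the module passed to it, so the old reassignment of `predictor.model.image_encoder` is dropped, while `unwrap_tensor_subclass` still returns the module and keeps its assignment. A rough sketch of that pattern (not part of this commit), using a small `nn.Sequential` as a stand-in for SAM's image encoder:

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

# stand-in for predictor.model.image_encoder (the real one is SAM's ViT encoder)
image_encoder = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 256),
)

# in place: the return value is no longer assigned, unlike the old quantize(...)
quantize_(image_encoder, int8_dynamic_activation_int8_weight())

# unwrap_tensor_subclass does return the module, so its assignment stays
image_encoder = unwrap_tensor_subclass(image_encoder)

# a filter can be passed as the third argument to restrict which submodules are
# quantized, as eval_combo.py does with its attn_only filter:
#   quantize_(image_encoder, int8_dynamic_activation_int8_weight(), attn_only)
```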
8 changes: 4 additions & 4 deletions torchao/quantization/README.md
@@ -162,7 +162,7 @@ torch._export.aot_compile(m_unwrapped, example_inputs)
 ```
 
 ### Automatic Inductor Configuration
-The `quantize` and `autoquant` APIs now automatically apply our recommended inductor configuration settings. You can replicate the same settings for your own experiments with `torchao.quantization.utils.recommended_inductor_config_setter`. Alternatively, if you wish to disable these recommended settings, pass the keyword argument `set_inductor_config=False` to `quantize` or `autoquant`. You can also overwrite these settings after they are assigned, as long as you do so before passing any inputs to the compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to set those same configurations manually is unlikely to cause any issues.
+The `quantize_` and `autoquant` APIs now automatically apply our recommended inductor configuration settings. You can replicate the same settings for your own experiments with `torchao.quantization.utils.recommended_inductor_config_setter`. Alternatively, if you wish to disable these recommended settings, pass the keyword argument `set_inductor_config=False` to `quantize_` or `autoquant`. You can also overwrite these settings after they are assigned, as long as you do so before passing any inputs to the compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to set those same configurations manually is unlikely to cause any issues.
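To make the paragraph above concrete, here is a brief sketch (not part of this commit) of opting out of the automatic inductor configuration via the `set_inductor_config` keyword and applying the recommended settings by hand, assuming torch 2.4+ and a toy linear model:

```python
import torch
from torchao.quantization import quantize_, int8_weight_only
from torchao.quantization.utils import recommended_inductor_config_setter

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024))

# skip torchao's automatic inductor configuration for this call
quantize_(model, int8_weight_only(), set_inductor_config=False)

# optionally apply the recommended settings manually, before the compiled
# model sees any inputs
recommended_inductor_config_setter()
model = torch.compile(model, mode="max-autotune")
```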

### Other Available Quantization Techniques
#### A8W8 Dynamic Quantization
@@ -172,7 +172,7 @@
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 
 # for torch 2.4+
-from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
 quantize_(model, int8_dynamic_activation_int8_weight())
 
 # for torch 2.2.2 and 2.3
@@ -184,7 +184,7 @@ change_linear_weights_to_int8_dqtensors(model)
 
 ```python
 # for torch 2.4+
-from torchao.quantization import quantize, int8_weight_only
+from torchao.quantization import quantize_, int8_weight_only
 quantize_(model, int8_weight_only())
 
 # for torch 2.2.2 and 2.3
@@ -199,7 +199,7 @@
 
 ```python
 # for torch 2.4+
-from torchao.quantization import quantize, int4_weight_only
+from torchao.quantization import quantize_, int4_weight_only
 quantize_(model, int4_weight_only())
 
 # for torch 2.2.2 and 2.3
