from typing import Iterable

import torch
import torchvision
from torch.utils.data import DataLoader

from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform
from ppq.api import export_ppq_graph, quantize_torch_model

BATCHSIZE = 32
INPUT_SHAPE = [3, 224, 224]
DEVICE = 'cuda'  # only CUDA is fully tested; other devices may have bugs.
PLATFORM = TargetPlatform.PPL_CUDA_INT8  # target platform for your network.


def load_calibration_dataset() -> Iterable:
    # Random tensors stand in for real calibration data; see the sketch below.
    return [torch.rand(size=INPUT_SHAPE) for _ in range(32)]


def collate_fn(batch: torch.Tensor) -> torch.Tensor:
    return batch.to(DEVICE)
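
# The random tensors above are only placeholders: for meaningful calibration
# statistics you would feed real samples. A minimal sketch, assuming an
# ImageNet-style directory at the hypothetical path 'data/calibration' and
# standard torchvision preprocessing:
#
#   from torchvision import transforms
#   from torchvision.datasets import ImageFolder
#
#   def load_calibration_dataset() -> Iterable:
#       preprocess = transforms.Compose([
#           transforms.Resize(256),
#           transforms.CenterCrop(224),
#           transforms.ToTensor()])
#       # ImageFolder yields (image, label) pairs; calibration only needs images.
#       return [image for image, _ in ImageFolder('data/calibration', transform=preprocess)]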

# Load a pretrained MobileNet V2 model.
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)

# Create a setting for quantizing your network with PPL CUDA.
quant_setting = QuantizationSettingFactory.pplcuda_setting()
quant_setting.equalization = True  # use the layerwise equalization algorithm.
quant_setting.dispatcher = 'conservative'  # dispatch this network in a conservative way.

# Wrap the calibration data in a dataloader.
calibration_dataset = load_calibration_dataset()
calibration_dataloader = DataLoader(
    dataset=calibration_dataset,
    batch_size=BATCHSIZE, shuffle=True)

# Quantize the model.
quantized = quantize_torch_model(
    model=model, calib_dataloader=calibration_dataloader,
    calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE,
    setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM,
    onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0)

# The quantization result is a PPQ BaseGraph instance.
assert isinstance(quantized, BaseGraph)

# Export the quantized graph.
export_ppq_graph(graph=quantized, platform=PLATFORM,
                 graph_save_to='Output/quantized(onnx).onnx',
                 config_save_to='Output/quantized(onnx).json')
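
# Optional sanity check: PPQ can execute a quantized BaseGraph directly, which
# makes it easy to compare outputs against the FP32 model on a sample batch.
# The import path and calls below are assumptions based on PPQ's examples, not
# verified against every PPQ version:
#
#   from ppq.executor import TorchExecutor
#
#   executor = TorchExecutor(graph=quantized, device=DEVICE)
#   sample = torch.rand(size=[BATCHSIZE] + INPUT_SHAPE).to(DEVICE)
#   quant_outputs = executor.forward(inputs=sample)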