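Torch-Pruning (TP) is a library for structural pruning with the following features:
- General-purpose pruning toolkit: TP enables structural pruning for a wide range of deep neural networks, including Large Language Models (LLMs), Segment Anything Model (SAM), Diffusion Models, Yolov7, Yolov8, Vision Transformers, Swin Transformers, BERT, FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, RegNet, DeepLab, etc. Unlike torch.nn.utils.prune, which implements pruning by zeroing out parameters, Torch-Pruning uses an algorithm called DepGraph to physically remove coupled parameters.
- Examples: Pruning pretrained models from Timm, Huggingface Transformers, Torchvision, Yolo, etc.
- Benchmark: Reproduce our results in the DepGraph paper.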
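To make the difference from torch.nn.utils.prune concrete, here is a minimal plain-Python sketch on a toy two-layer linear network (the weights and helper names are illustrative assumptions, not Torch-Pruning code). Zeroing a hidden unit keeps every shape, whereas physically removing it also requires dropping the coupled column of the next layer. Both give the same output, but only the latter actually shrinks the model.

```python
# Toy network y = W2 @ (W1 @ x) with plain lists (illustrative assumption).

def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def zero_hidden_units(W1, idxs):
    """Mask-style pruning: shapes are unchanged, selected rows of W1 become zero."""
    return [[0.0] * len(row) if i in idxs else list(row) for i, row in enumerate(W1)]

def remove_hidden_units(W1, W2, idxs):
    """Structural pruning: drop rows of W1 and the coupled columns of W2."""
    keep = [i for i in range(len(W1)) if i not in idxs]
    new_W1 = [list(W1[i]) for i in keep]
    new_W2 = [[row[i] for i in keep] for row in W2]
    return new_W1, new_W2

W1 = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 hidden units, 2 inputs
W2 = [[1.0, -1.0, 2.0], [0.5, 0.5, 0.5]]    # 2 outputs, 3 hidden units
x = [1.0, 1.0]

y_masked = matvec(W2, matvec(zero_hidden_units(W1, {1}), x))
small_W1, small_W2 = remove_hidden_units(W1, W2, {1})
y_removed = matvec(small_W2, matvec(small_W1, x))

print(y_masked, y_removed)                 # [25.0, 7.0] [25.0, 7.0]
print(len(W1), "->", len(small_W1))        # 3 -> 2 hidden units
print(len(small_W2[0]))                    # 2: W2 lost the coupled column too
```

Removing only the rows of W1 would leave W2 expecting 3 inputs, which is exactly the kind of coupling DepGraph tracks automatically.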
DepGraph: Towards Any Structural Pruning
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang
Learning and Vision Lab, National University of Singapore
- 2023.12.19 🚀 DeepCache: Accelerating Diffusion Models for Free
- 2023.12.19 🚀 SlimSAM: 0.1% Data Makes Segment Anything Slim
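- 2023.09.06 Pruning & finetuning examples for Vision Transformers, Swin Transformers, and Bert
- 2023.07.19 Support for LLaMA, LLaMA-2, Vicuna, Baichuan, and Bloom in LLM-Pruner
- 2023.05.20 LLM-Pruner: On the Structural Pruning of Large Language Models [arXiv]
- 2023.05.19 Structural Pruning for Diffusion Models [arXiv]
- 2023.04.15 Pruning and post-training for YOLOv7 / YOLOv8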
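- High-level pruners: MetaPruner, MagnitudePruner, BNScalePruner, GroupNormPruner, GrowingRegPruner, RandomPruner, etc. A list of related papers can be found on our wiki page.
- Dependency graph for automatic structural pruning
- Low-level pruning functions
- Supported importance criteria: L-p norm, Taylor, Random, BNScaling, etc.
- Supported modules: Linear, (Transposed) Conv, Normalization, PReLU, Embedding, MultiheadAttention, nn.Parameters, customized modules, and nested/combined modules
- Supported operations: split, concatenation, skip connection, flatten, reshape, view, all element-wise ops, etc.
- Benchmarks, tutorials, and examples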
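If you run into any problems with the library or the paper, please feel free to open an issue.
Torch-Pruning is compatible with both PyTorch 1.x and 2.x. PyTorch 2.0 is strongly recommended.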
pip install torch-pruning
git clone https://github.com/VainF/Torch-Pruning.git
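Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the Tutorials.
Please make sure AutoGrad is enabled for your model, i.e., you are not using torch.no_grad or .requires_grad=False.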
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True).eval()
# 1. Build the dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))
# 2. Group the layers coupled with model.conv1
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
# 3. Prune all grouped layers together
if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0
    group.prune()
# 4. Save & load
model.zero_grad() # clear gradients to avoid a large checkpoint
torch.save(model, 'model.pth') # .state_dict alone is not enough, because pruning changes the model structure
model = torch.load('model.pth') # load the pruned model; on PyTorch >= 2.6, pass weights_only=False to load a full model object
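The above example demonstrates the basic pruning pipeline with DepGraph. The target layer model.conv1 is coupled with several other layers, whose corresponding channels must be removed simultaneously during structural pruning. To observe the cascading effect of a pruning operation, we can print the group and see how one pruning operation "triggers" others. In the following output, "A => B" means that pruning operation "A" triggers pruning operation "B". group[0] refers to the pruning root passed to DG.get_pruning_group. For more details about grouping, please refer to Wiki - DepGraph & Group.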
Pruning Group
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
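The cascade above can be mimicked with a breadth-first traversal over "A triggers B" edges. The sketch below uses a hypothetical, hand-written and heavily simplified slice of the resnet18 graph (plain strings instead of modules, edges in one direction only); DepGraph builds the real graph automatically from the model.

```python
from collections import deque

# Hypothetical, simplified "A => B" edges inspired by the printed group.
TRIGGERS = {
    ("prune_out_channels", "conv1"): [("prune_out_channels", "bn1")],
    ("prune_out_channels", "bn1"): [("prune_out_channels", "relu")],
    ("prune_out_channels", "relu"): [("prune_out_channels", "maxpool")],
    ("prune_out_channels", "maxpool"): [
        ("prune_out_channels", "add_1"),
        ("prune_in_channels", "layer1.0.conv1"),
    ],
    ("prune_out_channels", "add_1"): [("prune_out_channels", "layer1.0.bn2")],
    ("prune_out_channels", "layer1.0.bn2"): [("prune_out_channels", "layer1.0.conv2")],
}

def build_group(root, idxs):
    """Breadth-first propagation; returns (trigger, triggered, idxs) records."""
    group = [(root, root, idxs)]          # record [0] is the pruning root
    visited = {root}
    queue = deque([root])
    while queue:
        src = queue.popleft()
        for dst in TRIGGERS.get(src, []):
            if dst not in visited:        # each operation is recorded only once
                visited.add(dst)
                group.append((src, dst, idxs))
                queue.append(dst)
    return group

group = build_group(("prune_out_channels", "conv1"), [2, 6, 9])
for i, (src, dst, idxs) in enumerate(group):
    print(f"[{i}] {src[0]} on {src[1]} => {dst[0]} on {dst[1]}, idxs={idxs}")
```

The same idxs travel along every edge because the coupled channels must be removed at the same positions in every layer.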
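We can use DG.get_all_groups(ignored_layers, root_module_types) to scan all groups sequentially. Each group starts from a layer whose type matches one of root_module_types, and layers in ignored_layers are not used as group roots. DG.get_all_groups only performs the grouping and does not decide which channels to remove, so the pruning indices must be passed explicitly via group.prune(idxs=idxs).
from torch import nn
for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)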
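Conceptually, scanning all groups amounts to merging layers whose channels are coupled and reporting each merged set once. A minimal sketch, assuming the couplings are already known and using union-find; the layer names, the string "types" and the rule that a group containing an ignored layer is skipped are simplifications for illustration, not the library's exact behaviour.

```python
class UnionFind:
    def __init__(self, items):
        self.parent = {x: x for x in items}

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)

def get_all_groups(layers, couplings, ignored_layers=(), root_module_types=("conv", "linear")):
    """layers: {name: type}; couplings: pairs whose output channels must be pruned together."""
    uf = UnionFind(layers)
    for a, b in couplings:
        uf.union(a, b)
    groups = {}
    for name in layers:                   # dicts keep insertion (model) order
        groups.setdefault(uf.find(name), []).append(name)
    for members in groups.values():
        if any(m in ignored_layers for m in members):
            continue                      # simplification: skip groups touching ignored layers
        roots = [m for m in members if layers[m] in root_module_types]
        if roots:
            yield roots[0], members

layers = {
    "conv1": "conv", "bn1": "bn",
    "layer1.0.conv1": "conv", "layer1.0.bn1": "bn",
    "layer1.0.conv2": "conv", "fc": "linear",
}
# conv1 and layer1.0.conv2 are coupled through a (simplified) residual addition
couplings = [("conv1", "bn1"), ("conv1", "layer1.0.conv2"), ("layer1.0.conv1", "layer1.0.bn1")]
groups = list(get_all_groups(layers, couplings, ignored_layers={"fc"}))
for root, members in groups:
    print(root, members)
```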
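Leveraging DepGraph, we developed several high-level pruners in this repository to simplify the pruning process. Given the desired channel pruning ratio, the pruner scans all prunable groups, estimates their importance, and prunes the entire model; the pruned model can then be fine-tuned with your own training code. For details on this process, please refer to this tutorial, which shows how to implement a slimming pruner from scratch. A more practical example is available in benchmarks/main.py.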
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)
# 1. Importance criterion
imp = tp.importance.GroupTaylorImportance() # or GroupNormImportance(p=2), GroupHessianImportance(), etc.
# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!
pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks
    ignored_layers=ignored_layers,
)
# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor expansion requires gradients for importance estimation
    loss = model(example_inputs).sum() # A dummy loss, please replace this line with your loss function and data!
    loss.backward() # before pruner.step()
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
# finetune the pruned model here
# finetune(model)
# ...
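To illustrate what "estimate importance, then prune by ratio" can mean for a single group, here is a hedged plain-Python sketch. It assumes an L2-norm criterion summed over the coupled layers, which is only one possible choice (the library offers several criteria, such as GroupNormImportance and GroupTaylorImportance). It also checks the width arithmetic from the pruning_ratio=0.5 comment.

```python
import math

def channel_l2(weights_per_channel):
    """L2 norm of each channel's weights (one list per channel)."""
    return [math.sqrt(sum(w * w for w in ch)) for ch in weights_per_channel]

def group_importance(layers):
    """Sum per-channel L2 norms over all coupled layers (assumed aggregation)."""
    scores = [channel_l2(layer) for layer in layers]
    return [sum(s[c] for s in scores) for c in range(len(scores[0]))]

def select_channels_to_prune(importance, pruning_ratio):
    """Return the indices of the least important channels, sorted."""
    n_prune = int(len(importance) * pruning_ratio)
    order = sorted(range(len(importance)), key=lambda c: importance[c])
    return sorted(order[:n_prune])

# Two coupled layers (e.g. a conv and the BN after it), 4 channels each.
conv = [[3.0, 4.0], [0.1, 0.1], [1.0, 0.0], [2.0, 2.0]]
bn = [[1.0], [0.1], [0.5], [1.0]]
imp = group_importance([conv, bn])
print(select_channels_to_prune(imp, 0.5))            # [1, 2]

# pruning_ratio=0.5 halves every prunable width of ResNet-18's stages.
print([int(c * (1 - 0.5)) for c in (64, 128, 256, 512)])  # [32, 64, 128, 256]
```

Scoring the whole group, rather than one layer, keeps the decision consistent for all layers that must lose the same channels.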
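Some pruners, such as BNScalePruner and GroupNormPruner, support sparse training before pruning. To use it, call pruner.update_regularizer() at the beginning of each epoch and pruner.regularize(model) after loss.backward() and before optimizer.step():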
import torch.nn.functional as F
# model, pruner, optimizer, train_loader, device and epochs come from your own training script
for epoch in range(epochs):
    model.train()
    pruner.update_regularizer() # <== initialize regularizer
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward() # after loss.backward()
        pruner.regularize(model) # <== for sparse training
        optimizer.step() # before optimizer.step()
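The regularizer's role in sparse training can be sketched as adding a penalty gradient between loss.backward() and optimizer.step(). The example below assumes a network-slimming-style L1 penalty on BN scaling factors (gamma), the usual idea behind BN-scale-based pruning; it is a plain-Python illustration, not the library's implementation of pruner.regularize().

```python
def sign(v):
    return (v > 0) - (v < 0)

def regularize(gammas, grads, reg=1e-4):
    """Add the subgradient of reg * sum(|gamma|) to the existing gradients."""
    return [g + reg * sign(w) for w, g in zip(gammas, grads)]

def sgd_step(params, grads, lr=0.1):
    """Plain SGD update."""
    return [p - lr * g for p, g in zip(params, grads)]

gammas = [0.8, -0.05, 0.3]
for step in range(100):
    grads = [0.0, 0.0, 0.0]                      # stand-in for gradients from loss.backward()
    grads = regularize(gammas, grads, reg=1e-2)  # plays the role of the regularization step
    gammas = sgd_step(gammas, grads)             # optimizer.step()
# Small gammas are driven to (about) zero, making their channels obvious pruning candidates.
print([round(g, 3) for g in gammas])
```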
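The high-level pruners also support interactive pruning: pruner.step(interactive=True) yields the groups one at a time, so you can inspect each group and prune it yourself by calling group.prune():
for i in range(iterative_steps): # iterative_steps: the number of pruning rounds configured for the pruner
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group)
        # do whatever you like with the group
        dep, idxs = group[0] # get the idxs
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...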
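The outer loop above prunes gradually over iterative_steps rounds. A minimal sketch of a linear schedule, assuming the cumulative ratio grows by the same amount each round until it reaches the target (linear_schedule and channels_after are hypothetical helpers, not library functions):

```python
def linear_schedule(pruning_ratio, iterative_steps):
    """Cumulative pruning ratio after each round, growing linearly."""
    return [pruning_ratio * (i + 1) / iterative_steps for i in range(iterative_steps)]

def channels_after(width, ratio):
    """Channels left in a layer of `width` channels at a cumulative ratio."""
    return width - int(width * ratio)

schedule = linear_schedule(0.5, 5)
for step, ratio in enumerate(schedule):
    print(f"step {step}: cumulative ratio {ratio:.2f}, layer width {channels_after(64, ratio)}")
```

Pruning in small steps with fine-tuning in between gives the network a chance to recover before the next round.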
Soft pruning, which zeroes out parameters instead of removing them, can be implemented easily. An example can be found in tests/test_soft_pruning.py.
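A minimal plain-Python sketch of soft pruning with a channel mask. It assumes the mask is re-applied after every parameter update, because the optimizer would otherwise revive the zeroed weights. This only illustrates the idea; it is not the code in tests/test_soft_pruning.py.

```python
def make_channel_mask(n_channels, idxs):
    """1.0 for kept channels, 0.0 for soft-pruned ones."""
    return [0.0 if c in idxs else 1.0 for c in range(n_channels)]

def apply_mask(weight, mask):
    """weight: one row per output channel; zero whole rows where mask == 0."""
    return [[w * m for w in row] for row, m in zip(weight, mask)]

weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
mask = make_channel_mask(len(weight), {1, 3})
weight = apply_mask(weight, mask)

# A fake optimizer update touches every entry ...
weight = [[w - 0.5 for w in row] for row in weight]
# ... so the mask is applied again to keep the pruned channels at zero.
weight = apply_mask(weight, mask)
print(weight)  # still 4 rows; rows 1 and 3 are all zeros
```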
- Pruning a ResNet50 pre-trained on ImageNet-1K without fine-tuning.
- Pruning a Vision Transformer pre-trained on ImageNet-1K without fine-tuning.
Some modules store static attributes that pruning does not update automatically. For example:
import torch
from torch import nn
class Scale(nn.Module):
    """Scale vector by element multiplications."""
    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which must be updated manually after pruning
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)
    def forward(self, x):
        return x * self.scale.view(self.shape) # x * self.scale.view(-1, 1, 1) would adapt to pruning automatically (NCHW case)
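Why does the static self.shape break? The sketch below emulates view-style shape resolution in plain Python (resolve_view is a hypothetical, simplified stand-in for torch.Tensor.view, which also infers a single -1 dimension). After 3 of 64 channels are pruned, the recorded shape (64, 1, 1) no longer matches the 61 remaining elements, while (-1, 1, 1) is resolved correctly.

```python
from math import prod

def resolve_view(numel, shape):
    """Resolve a view shape for `numel` elements, inferring at most one -1."""
    shape = list(shape)
    if shape.count(-1) == 1:
        known = prod(d for d in shape if d != -1)
        if numel % known:
            raise ValueError(f"cannot view {numel} elements as {tuple(shape)}")
        shape[shape.index(-1)] = numel // known
    if prod(shape) != numel:
        raise ValueError(f"cannot view {numel} elements as {tuple(shape)}")
    return tuple(shape)

dim = 64
static_shape = (dim, 1, 1)   # recorded once in __init__
pruned_numel = dim - 3       # e.g. 3 channels removed from self.scale

print(resolve_view(pruned_numel, (-1, 1, 1)))   # (61, 1, 1)
try:
    resolve_view(pruned_numel, static_shape)
except ValueError as err:
    print(err)                                   # the stale static shape no longer fits
```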
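The following script saves the whole model object (structure + weights) as model.pth:
model.zero_grad() # Remove gradients
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model; on PyTorch >= 2.6, pass weights_only=False to load a full model object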
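Alternatively, tp.state_dict saves a pruned state_dict that also records the modified attributes, and tp.load_state_dict loads it into a newly created, unpruned model:
# save the pruned state_dict, which includes both pruned parameters and modified attributes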
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')
# create a new model, e.g. resnet18
new_model = resnet18().eval()
# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
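The idea behind saving attributes alongside weights can be sketched with a toy layer: the fresh model first receives the pruned attributes, then the smaller weights. ToyLinear, save_pruned, load_pruned and the "attributes" key are hypothetical and do not reflect the file format used by tp.state_dict.

```python
class ToyLinear:
    """Hypothetical layer whose structure is described by two attributes."""
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = [[0.0] * in_features for _ in range(out_features)]

def save_pruned(layer):
    """Store the modified attributes next to the (smaller) weights."""
    return {
        "attributes": {"in_features": layer.in_features, "out_features": layer.out_features},
        "weight": [list(row) for row in layer.weight],
    }

def load_pruned(layer, state):
    """Apply the saved attributes first, then copy in weights of the new shape."""
    for name, value in state["attributes"].items():
        setattr(layer, name, value)
    weight = state["weight"]
    if len(weight) != layer.out_features or any(len(r) != layer.in_features for r in weight):
        raise ValueError("saved weights do not match the saved attributes")
    layer.weight = [list(row) for row in weight]
    return layer

pruned = ToyLinear(4, 2)                        # e.g. output features pruned from 4 to 2
pruned.weight = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
state = save_pruned(pruned)

fresh = load_pruned(ToyLinear(4, 4), state)     # a newly created, unpruned layer
print(fresh.out_features, len(fresh.weight))    # 2 2
```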
See tests/test_serialization.py for a ViT example, in which the model is pruned and some attributes, such as model.hidden_dims, are modified.
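The low-level pruning functions can also be called directly, without DepGraph. In this case, the broken dependencies must be fixed manually:
tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )
# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
# ... plus every other coupled layer in the group printed above, e.g. layer1[0].conv1, layer1[0].conv2 and layer2[0].downsample[0]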
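A plain-Python sketch of what "broken dependencies" means: after the output channels of one layer are pruned, every consumer whose input width no longer matches must be fixed too. The layer table is a hypothetical, heavily simplified slice of resnet18 (three layers, channel counts only).

```python
layers = {
    "conv1": {"in": 3, "out": 64},
    "bn1": {"in": 64, "out": 64},
    "layer1.0.conv1": {"in": 64, "out": 64},
}
edges = [("conv1", "bn1"), ("bn1", "layer1.0.conv1")]  # producer -> consumer

def broken_edges(layers, edges):
    """Edges whose producer output width differs from the consumer input width."""
    return [(a, b) for a, b in edges if layers[a]["out"] != layers[b]["in"]]

def prune_out(layers, name, idxs):
    layers[name]["out"] -= len(idxs)

def prune_in(layers, name, idxs):
    layers[name]["in"] -= len(idxs)

idxs = [2, 6, 9]
prune_out(layers, "conv1", idxs)
print(broken_edges(layers, edges))   # [('conv1', 'bn1')]

# BatchNorm is channel-wise: its input and output widths shrink together.
prune_in(layers, "bn1", idxs)
prune_out(layers, "bn1", idxs)
prune_in(layers, "layer1.0.conv1", idxs)
print(broken_edges(layers, edges))   # []
```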
The following pruning functions are available:
See examples/transformers/prune_hf_swin.py, which implements a new pruner for the customized module SwinPatchMerging. A simpler example can be found in tests/test_customized_layer.py.
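Customized modules need their own pruning logic because a generic pruner only knows how to handle the module types it has handlers for. The sketch below shows the dispatch idea with a hypothetical registry and a toy CustomScale module; these names are illustrative and do not correspond to the Torch-Pruning extension API.

```python
HANDLERS = {}

def register(module_type):
    """Hypothetical decorator that registers a pruning handler for a module type."""
    def decorator(fn):
        HANDLERS[module_type] = fn
        return fn
    return decorator

class CustomScale:
    """Toy custom module: one scale per channel plus a static `dim` attribute."""
    def __init__(self, dim):
        self.dim = dim
        self.scale = [1.0 + c for c in range(dim)]

@register(CustomScale)
def prune_custom_scale(module, idxs):
    drop = set(idxs)
    module.scale = [s for c, s in enumerate(module.scale) if c not in drop]
    module.dim = len(module.scale)   # keep the static attribute consistent
    return module

def prune(module, idxs):
    handler = HANDLERS.get(type(module))
    if handler is None:
        raise NotImplementedError(f"no pruning handler for {type(module).__name__}")
    return handler(module, idxs)

m = prune(CustomScale(4), [0, 2])
print(m.dim, m.scale)   # 2 [2.0, 4.0]
```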
| Method | Base (%) | Pruned (%) | Δ Acc (%) | Speed Up |
|---|---|---|---|---|
| NIPS [1] | - | - | -0.03 | 1.76x |
| Geometric [2] | 93.59 | 93.26 | -0.33 | 1.70x |
| Polar [3] | 93.80 | 93.83 | +0.03 | 1.88x |
| CP [4] | 92.80 | 91.80 | -1.00 | 2.00x |
| AMC [5] | 92.80 | 91.90 | -0.90 | 2.00x |
| HRank [6] | 93.26 | 92.17 | -0.09 | 2.00x |
| SFP [7] | 93.59 | 93.36 | +0.23 | 2.11x |
| ResRep [8] | 93.71 | 93.71 | +0.00 | 2.12x |
| Ours-L1 | 93.53 | 92.93 | -0.60 | 2.12x |
| Ours-BN | 93.53 | 93.29 | -0.24 | 2.12x |
| Ours-Group | 93.53 | 93.77 | +0.38 | 2.13x |
Latency test on ResNet-50, Batch Size=64.
[Iter 0] Pruning ratio: 0.00, MACs: 4.12 G, Params: 25.56 M, Latency: 45.22 ms +- 0.03 ms
[Iter 1] Pruning ratio: 0.05, MACs: 3.68 G, Params: 22.97 M, Latency: 46.53 ms +- 0.06 ms
[Iter 2] Pruning ratio: 0.10, MACs: 3.31 G, Params: 20.63 M, Latency: 43.85 ms +- 0.08 ms
[Iter 3] Pruning ratio: 0.15, MACs: 2.97 G, Params: 18.36 M, Latency: 41.22 ms +- 0.10 ms
[Iter 4] Pruning ratio: 0.20, MACs: 2.63 G, Params: 16.27 M, Latency: 39.28 ms +- 0.20 ms
[Iter 5] Pruning ratio: 0.25, MACs: 2.35 G, Params: 14.39 M, Latency: 34.60 ms +- 0.19 ms
[Iter 6] Pruning ratio: 0.30, MACs: 2.02 G, Params: 12.46 M, Latency: 33.38 ms +- 0.27 ms
[Iter 7] Pruning ratio: 0.35, MACs: 1.74 G, Params: 10.75 M, Latency: 31.46 ms +- 0.20 ms
[Iter 8] Pruning ratio: 0.40, MACs: 1.50 G, Params: 9.14 M, Latency: 29.04 ms +- 0.19 ms
[Iter 9] Pruning ratio: 0.45, MACs: 1.26 G, Params: 7.68 M, Latency: 27.47 ms +- 0.28 ms
[Iter 10] Pruning ratio: 0.50, MACs: 1.07 G, Params: 6.41 M, Latency: 20.68 ms +- 0.13 ms
[Iter 11] Pruning ratio: 0.55, MACs: 0.85 G, Params: 5.14 M, Latency: 20.48 ms +- 0.21 ms
[Iter 12] Pruning ratio: 0.60, MACs: 0.67 G, Params: 4.07 M, Latency: 18.12 ms +- 0.15 ms
[Iter 13] Pruning ratio: 0.65, MACs: 0.53 G, Params: 3.10 M, Latency: 15.19 ms +- 0.01 ms
[Iter 14] Pruning ratio: 0.70, MACs: 0.39 G, Params: 2.28 M, Latency: 13.47 ms +- 0.01 ms
[Iter 15] Pruning ratio: 0.75, MACs: 0.29 G, Params: 1.61 M, Latency: 10.07 ms +- 0.01 ms
[Iter 16] Pruning ratio: 0.80, MACs: 0.18 G, Params: 1.01 M, Latency: 8.96 ms +- 0.02 ms
[Iter 17] Pruning ratio: 0.85, MACs: 0.10 G, Params: 0.57 M, Latency: 7.03 ms +- 0.04 ms
[Iter 18] Pruning ratio: 0.90, MACs: 0.05 G, Params: 0.25 M, Latency: 5.81 ms +- 0.03 ms
[Iter 19] Pruning ratio: 0.95, MACs: 0.01 G, Params: 0.06 M, Latency: 5.70 ms +- 0.03 ms
[Iter 20] Pruning ratio: 1.00, MACs: 0.01 G, Params: 0.06 M, Latency: 5.71 ms +- 0.03 ms
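A back-of-the-envelope check on the MACs column: a convolution costs roughly C_in * C_out * K * K * H_out * W_out multiply-accumulates, so removing a fraction p of both its input and output channels scales its cost by about (1 - p)^2. The layer shape below is hypothetical. The reported 4.12 G → 1.07 G at ratio 0.50 is close to the expected quarter; since layers such as the first convolution keep their 3 input channels, the match is only approximate.

```python
def conv_macs(c_in, c_out, k, h_out, w_out):
    """Multiply-accumulates of a KxK convolution producing an h_out x w_out map."""
    return c_in * c_out * k * k * h_out * w_out

def pruned_macs(c_in, c_out, k, h_out, w_out, p):
    """MACs after removing a fraction p of both input and output channels."""
    return conv_macs(c_in - int(c_in * p), c_out - int(c_out * p), k, h_out, w_out)

base = conv_macs(256, 256, 3, 14, 14)
half = pruned_macs(256, 256, 3, 14, 14, 0.5)
print(half / base)               # 0.25

# Ratio of the reported MACs at pruning ratio 0.50
print(round(1.07 / 4.12, 2))     # 0.26
```

Latency shrinks less regularly than MACs in the log above, since actual speed also depends on hardware and kernel efficiency.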
DepGraph: Towards Any Structural Pruning [Project] [Paper]
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang
CVPR 2023
LLM-Pruner: On the Structural Pruning of Large Language Models [Project] [arXiv]
Xinyin Ma, Gongfan Fang, Xinchao Wang
NeurIPS 2023
Structural Pruning for Diffusion Models [Project] [arXiv]
Gongfan Fang, Xinyin Ma, Xinchao Wang
NeurIPS 2023
DeepCache: Accelerating Diffusion Models for Free [Project] [arXiv]
Xinyin Ma, Gongfan Fang, Xinchao Wang
CVPR 2024
0.1% Data Makes Segment Anything Slim [Project] [arXiv]
Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang
Preprint 2023
title={Depgraph: Towards any structural pruning},
author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
