AutoML for model compression (microsoft#2573)
chicm-ms authored Aug 12, 2020
1 parent 3757cf2 commit e9f3cdd
Showing 22 changed files with 2,851 additions and 12 deletions.
4 changes: 4 additions & 0 deletions azure-pipelines.yml
@@ -28,6 +28,7 @@ jobs:
set -e
sudo apt-get install -y pandoc
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==2.2.0 --user
python3 -m pip install keras==2.4.2 --user
python3 -m pip install gym onnx peewee thop --user
@@ -68,6 +69,7 @@ jobs:
- script: |
set -e
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx peewee --user
@@ -117,6 +119,7 @@ jobs:
set -e
# pytorch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user
brew install swig@3
rm -f /usr/local/bin/swig
@@ -144,6 +147,7 @@ jobs:
python -m pip install scikit-learn==0.23.2 --user
python -m pip install keras==2.1.6 --user
python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorboardX==1.9
python -m pip install tensorflow==1.15.2 --user
displayName: 'Install dependencies'
- script: |
34 changes: 34 additions & 0 deletions docs/en_US/Compressor/Pruner.md
@@ -20,6 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
* [NetAdapt Pruner](#netadapt-pruner)
* [SimulatedAnnealing Pruner](#simulatedannealing-pruner)
* [AutoCompress Pruner](#autocompress-pruner)
* [AutoML for Model Compression Pruner](#automl-for-model-compression-pruner)
* [Sensitivity Pruner](#sensitivity-pruner)

**Others**
@@ -476,6 +477,39 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AutoCompressPruner
```

## AutoML for Model Compression Pruner

AutoML for Model Compression Pruner (AMCPruner) leverages reinforcement learning to search for a model compression policy.
This learning-based policy outperforms conventional rule-based policies: it achieves a higher compression ratio, better preserves accuracy, and reduces the human labor involved in tuning per-layer sparsity.

![](../../img/amc_pruner.jpg)

For more details, please refer to [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf).
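The search procedure described in the paper can be summarized as a reward-driven loop: an agent proposes a sparsity ratio per layer, the pruned model is evaluated, and the best-scoring policy is kept. A minimal conceptual sketch (not the NNI implementation; `propose_sparsity`, `prune`, and `evaluate` are hypothetical placeholder callables):

```python
# Conceptual sketch of AMC-style policy search: each episode proposes a
# per-layer sparsity policy, evaluates the pruned model, and keeps the
# best-scoring policy seen so far. All callables are placeholders.
def amc_search(layers, propose_sparsity, prune, evaluate, episodes=10):
    best_policy, best_reward = None, float("-inf")
    for _ in range(episodes):
        policy = [propose_sparsity(layer) for layer in layers]  # action per layer
        reward = evaluate(prune(layers, policy))                # e.g. val accuracy
        if reward > best_reward:
            best_policy, best_reward = policy, reward
    return best_policy, best_reward
```

In the real algorithm the agent is a DDPG reinforcement-learning agent and the reward is validation accuracy under a FLOPs constraint, rather than the exhaustive random sampling this sketch suggests.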


#### Usage

PyTorch code

```python
from nni.compression.torch import AMCPruner
config_list = [{
'op_types': ['Conv2d', 'Linear']
}]
pruner = AMCPruner(model, config_list, evaluator, val_loader, flops_ratio=0.5)
pruner.compress()
```

You can view [example](https://github.com/microsoft/nni/blob/master/examples/model_compress/amc/) for more information.
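The `evaluator` argument is a callable; judging from the example script in this commit, it receives the validation loader and the wrapped model and returns a scalar validation score (top-5 accuracy in the example). A minimal sketch of that contract, with a toy identity model standing in for a real network:

```python
# Sketch of the evaluator contract assumed by the example script:
# evaluator(val_loader, model) -> scalar validation score.
def evaluator(val_loader, model):
    correct, total = 0, 0
    for inputs, targets in val_loader:
        preds = model(inputs)  # batch of predicted labels
        correct += sum(int(p == t) for p, t in zip(preds, targets))
        total += len(targets)
    return correct / total
```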

#### User configuration for AMC Pruner

##### PyTorch

```eval_rst
.. autoclass:: nni.compression.torch.AMCPruner
```

## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique,
Binary file added docs/img/amc_pruner.jpg
136 changes: 136 additions & 0 deletions examples/model_compress/amc/amc_search.py
@@ -0,0 +1,136 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import argparse
import time

import torch
import torch.nn as nn

from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy

sys.path.append('../models')

def parse_args():
parser = argparse.ArgumentParser(description='AMC search script')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
parser.add_argument('--ckpt_path', default=None, type=str, help='manual path of checkpoint')

parser.add_argument('--train_episode', default=800, type=int, help='number of training episode')
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
parser.add_argument('--job', default='train_export', type=str, choices=['train_export', 'export_only'],
help='search best pruning policy and export or just export model with searched policy')
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models')
parser.add_argument('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')

return parser.parse_args()


def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
if model == 'mobilenet' and dataset == 'imagenet':
from mobilenet import MobileNet
net = MobileNet(n_class=1000)
elif model == 'mobilenetv2' and dataset == 'imagenet':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=1000)
elif model == 'mobilenet' and dataset == 'cifar10':
from mobilenet import MobileNet
net = MobileNet(n_class=10)
elif model == 'mobilenetv2' and dataset == 'cifar10':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=10)
else:
raise NotImplementedError
if checkpoint_path:
print('loading {}...'.format(checkpoint_path))
sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a full checkpoint dict, not a bare state_dict
sd = sd['state_dict']
sd = {k.replace('module.', ''): v for k, v in sd.items()}
net.load_state_dict(sd)

if torch.cuda.is_available() and n_gpu > 0:
net = net.cuda()
if n_gpu > 1:
net = torch.nn.DataParallel(net, range(n_gpu))

return net

def init_data(args):
# split the train set into train + val
# for CIFAR, split 5k for val
# for ImageNet, split 3k for val
val_size = 5000 if 'cifar' in args.dataset else 3000
train_loader, val_loader, _ = get_split_dataset(
args.dataset, args.batch_size,
args.n_worker, val_size,
data_root=args.data_root,
shuffle=False
) # same sampling
return train_loader, val_loader

def validate(val_loader, model, verbose=False):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()

    criterion = nn.CrossEntropyLoss().to(device)  # works on CPU as well as GPU
# switch to evaluate mode
model.eval()
end = time.time()

t1 = time.time()
with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            # move the batch to the same device as the model
            input = input.to(device)
            target = target.to(device)

            # compute output
            output = model(input)
            loss = criterion(output, target)

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
t2 = time.time()
if verbose:
print('* Test loss: %.3f top1: %.3f top5: %.3f time: %.3f' %
(losses.avg, top1.avg, top5.avg, t2 - t1))
return top5.avg


if __name__ == "__main__":
args = parse_args()

device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu')

model = get_model_and_checkpoint(args.model_type, args.dataset, checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu)
_, val_loader = init_data(args)

config_list = [{
'op_types': ['Conv2d', 'Linear']
}]
pruner = AMCPruner(
model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
train_episode=args.train_episode, job=args.job, export_path=args.export_path,
searched_model_path=args.searched_model_path,
flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
pruner.compress()
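The `AverageMeter` and `accuracy` helpers are imported from a `utils` module that is not part of this diff. A minimal sketch of the `AverageMeter` pattern, assuming the standard update-with-batch-count interface used in `validate` above:

```python
class AverageMeter:
    """Tracks the running sum, count, and average of a scalar metric."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the batch mean; n is the batch size it was averaged over
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```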
