forked from microsoft/nni
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AutoML for model compression (microsoft#2573)
(cherry picked from commit e9f3cdd)
- Loading branch information
Showing
22 changed files
with
2,851 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import sys | ||
import argparse | ||
import time | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from nni.compression.torch import AMCPruner | ||
from data import get_split_dataset | ||
from utils import AverageMeter, accuracy | ||
|
||
sys.path.append('../models') | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='AMC search script') | ||
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune') | ||
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)') | ||
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size') | ||
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path') | ||
parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model') | ||
parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity') | ||
parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity') | ||
parser.add_argument('--ckpt_path', default=None, type=str, help='manual path of checkpoint') | ||
|
||
parser.add_argument('--train_episode', default=800, type=int, help='number of training episode') | ||
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use') | ||
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker') | ||
parser.add_argument('--job', default='train_export', type=str, choices=['train_export', 'export_only'], | ||
help='search best pruning policy and export or just export model with searched policy') | ||
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models') | ||
parser.add_argument('--searched_model_path', default=None, type=str, help='path for searched best wrapped model') | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1): | ||
if model == 'mobilenet' and dataset == 'imagenet': | ||
from mobilenet import MobileNet | ||
net = MobileNet(n_class=1000) | ||
elif model == 'mobilenetv2' and dataset == 'imagenet': | ||
from mobilenet_v2 import MobileNetV2 | ||
net = MobileNetV2(n_class=1000) | ||
elif model == 'mobilenet' and dataset == 'cifar10': | ||
from mobilenet import MobileNet | ||
net = MobileNet(n_class=10) | ||
elif model == 'mobilenetv2' and dataset == 'cifar10': | ||
from mobilenet_v2 import MobileNetV2 | ||
net = MobileNetV2(n_class=10) | ||
else: | ||
raise NotImplementedError | ||
if checkpoint_path: | ||
print('loading {}...'.format(checkpoint_path)) | ||
sd = torch.load(checkpoint_path, map_location=torch.device('cpu')) | ||
if 'state_dict' in sd: # a checkpoint but not a state_dict | ||
sd = sd['state_dict'] | ||
sd = {k.replace('module.', ''): v for k, v in sd.items()} | ||
net.load_state_dict(sd) | ||
|
||
if torch.cuda.is_available() and n_gpu > 0: | ||
net = net.cuda() | ||
if n_gpu > 1: | ||
net = torch.nn.DataParallel(net, range(n_gpu)) | ||
|
||
return net | ||
|
||
def init_data(args): | ||
# split the train set into train + val | ||
# for CIFAR, split 5k for val | ||
# for ImageNet, split 3k for val | ||
val_size = 5000 if 'cifar' in args.dataset else 3000 | ||
train_loader, val_loader, _ = get_split_dataset( | ||
args.dataset, args.batch_size, | ||
args.n_worker, val_size, | ||
data_root=args.data_root, | ||
shuffle=False | ||
) # same sampling | ||
return train_loader, val_loader | ||
|
||
def validate(val_loader, model, verbose=False): | ||
batch_time = AverageMeter() | ||
losses = AverageMeter() | ||
top1 = AverageMeter() | ||
top5 = AverageMeter() | ||
|
||
criterion = nn.CrossEntropyLoss().cuda() | ||
# switch to evaluate mode | ||
model.eval() | ||
end = time.time() | ||
|
||
t1 = time.time() | ||
with torch.no_grad(): | ||
for i, (input, target) in enumerate(val_loader): | ||
target = target.to(device) | ||
input_var = torch.autograd.Variable(input).to(device) | ||
target_var = torch.autograd.Variable(target).to(device) | ||
|
||
# compute output | ||
output = model(input_var) | ||
loss = criterion(output, target_var) | ||
|
||
# measure accuracy and record loss | ||
prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) | ||
losses.update(loss.item(), input.size(0)) | ||
top1.update(prec1.item(), input.size(0)) | ||
top5.update(prec5.item(), input.size(0)) | ||
|
||
# measure elapsed time | ||
batch_time.update(time.time() - end) | ||
end = time.time() | ||
t2 = time.time() | ||
if verbose: | ||
print('* Test loss: %.3f top1: %.3f top5: %.3f time: %.3f' % | ||
(losses.avg, top1.avg, top5.avg, t2 - t1)) | ||
return top5.avg | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
|
||
device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu') | ||
|
||
model = get_model_and_checkpoint(args.model_type, args.dataset, checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu) | ||
_, val_loader = init_data(args) | ||
|
||
config_list = [{ | ||
'op_types': ['Conv2d', 'Linear'] | ||
}] | ||
pruner = AMCPruner( | ||
model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset, | ||
train_episode=args.train_episode, job=args.job, export_path=args.export_path, | ||
searched_model_path=args.searched_model_path, | ||
flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound) | ||
pruner.compress() |
Oops, something went wrong.