This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Filter prune algo implementation #1655

Merged
47 commits merged on Nov 15, 2019
Changes from 19 commits
Commits
47 commits
3a45961
Merge pull request #31 from microsoft/master
chicm-ms Aug 6, 2019
633db43
Merge pull request #32 from microsoft/master
chicm-ms Sep 9, 2019
3e926f1
Merge pull request #33 from microsoft/master
chicm-ms Oct 8, 2019
f173789
Merge pull request #34 from microsoft/master
chicm-ms Oct 9, 2019
508850a
Merge pull request #35 from microsoft/master
chicm-ms Oct 9, 2019
5a0e9c9
Merge pull request #36 from microsoft/master
chicm-ms Oct 10, 2019
e7df061
Merge pull request #37 from microsoft/master
chicm-ms Oct 23, 2019
e47c923
fpgm pruner pytorch implementation
chicm-ms Oct 23, 2019
c51f688
updates
chicm-ms Oct 25, 2019
b1165da
updates
chicm-ms Oct 28, 2019
cd32a6a
updates
chicm-ms Oct 28, 2019
8717026
updates
chicm-ms Oct 29, 2019
2175cef
Merge pull request #38 from microsoft/master
chicm-ms Oct 29, 2019
2ccbfbb
Merge pull request #39 from microsoft/master
chicm-ms Oct 30, 2019
b29cb0b
Merge pull request #40 from microsoft/master
chicm-ms Oct 30, 2019
e25d9be
Merge branch 'master' into filter_prune
chicm-ms Oct 30, 2019
216a9a7
updates
chicm-ms Oct 31, 2019
a42a067
updates
chicm-ms Oct 31, 2019
8fd58bb
updates
chicm-ms Oct 31, 2019
4a3ba83
Merge pull request #41 from microsoft/master
chicm-ms Nov 4, 2019
c8a1148
Merge pull request #42 from microsoft/master
chicm-ms Nov 4, 2019
73c6101
Merge pull request #43 from microsoft/master
chicm-ms Nov 5, 2019
fef6ec2
Merge branch 'master' into filter_prune
chicm-ms Nov 7, 2019
ec2b3fb
updates per refactored framework
chicm-ms Nov 7, 2019
3040b6e
updates
chicm-ms Nov 7, 2019
0ca60cf
updates
chicm-ms Nov 7, 2019
cd069fd
updates
chicm-ms Nov 7, 2019
a4a999b
update documents
chicm-ms Nov 7, 2019
bd622a2
updates
chicm-ms Nov 7, 2019
8a939b4
updates
chicm-ms Nov 7, 2019
6a518a9
Merge pull request #44 from microsoft/master
chicm-ms Nov 11, 2019
a0d587f
Merge pull request #45 from microsoft/master
chicm-ms Nov 12, 2019
302f1bd
tensorflow 2.0 implementation
chicm-ms Nov 13, 2019
a3e4b90
updates
chicm-ms Nov 13, 2019
20aedfc
updates
chicm-ms Nov 13, 2019
676348b
updates
chicm-ms Nov 13, 2019
2b22a1a
updates
chicm-ms Nov 13, 2019
e905bfe
Merge pull request #46 from microsoft/master
chicm-ms Nov 14, 2019
43bf2b7
Merge branch 'master' into filter_prune
chicm-ms Nov 14, 2019
9e68e2b
updates
chicm-ms Nov 14, 2019
eadc941
updates
chicm-ms Nov 14, 2019
2978b7c
updates
chicm-ms Nov 14, 2019
03d71da
updates
chicm-ms Nov 14, 2019
22b6475
updates
chicm-ms Nov 14, 2019
053e3d1
updates
chicm-ms Nov 14, 2019
08f5237
updates
chicm-ms Nov 14, 2019
ec8bb4e
updates
chicm-ms Nov 14, 2019
130 changes: 130 additions & 0 deletions examples/model_compress/fpgm_tf_mnist.py
@@ -0,0 +1,130 @@
from nni.compression.tensorflow import FPGMPruner
Contributor
is it easy to reproduce one of the experiments in the paper?

Contributor Author
@chicm-ms Oct 29, 2019
The key metrics for pruning are the reduction in FLOPs and the change in accuracy. To reproduce the reduced FLOPs, we need to use the pruned compact model, which is not implemented yet.
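
For reference, a minimal sketch (not part of this PR) of how the theoretical FLOP reduction of a single conv layer could be estimated from the pruning rate; the conv2d_flops helper and the printed numbers are illustrative, and the shapes are those of the first conv layer in the TF MNIST example below:

def conv2d_flops(h_out, w_out, c_in, c_out, k_h, k_w):
    # multiply-accumulate count of a plain conv layer (no bias, stride 1, no grouping)
    return h_out * w_out * c_in * c_out * k_h * k_w

pruning_rate = 0.5
baseline = conv2d_flops(28, 28, 1, 32, 5, 5)  # conv1 of the MNIST example
# removing half of the output filters removes roughly half of this layer's FLOPs,
# and in a compact model it would also shrink the input channels of the next layer
pruned = conv2d_flops(28, 28, 1, int(32 * (1 - pruning_rate)), 5, 5)
print(baseline, pruned)  # 627200 313600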

Contributor
I think we need to offer some test metrics to make sure the implementation is correct.

Contributor Author
@chicm-ms Nov 7, 2019
Sure, I have added verification code to the fpgm_torch_mnist.py example to check the sparsity of the pruned conv kernel weights. By checking the sparsity and the loss, we can verify that:

  1. the configured layers are pruned.
  2. the pruned model has a loss similar to the original model's.

But this code still cannot verify that the implementation matches the paper. I am considering adding some kind of verification in the unit tests.
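
As an illustration of such a sparsity check, here is a minimal sketch (not the actual verification code added in fpgm_torch_mnist.py); count_pruned_filters is an assumed helper name:

def count_pruned_filters(conv):
    # a filter counts as pruned when every one of its weights has been masked to zero
    w = conv.weight.data
    flat = w.view(w.size(0), -1)
    pruned = int((flat.abs().sum(dim=1) == 0).sum().item())
    return pruned, w.size(0)

# e.g. for the example model's conv1 (20 filters) with pruning_rate 0.5,
# roughly half of the filters are expected to be all-zero after pruning:
# print(count_pruned_filters(model.conv1))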

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def weight_variable(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def bias_variable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))


def conv2d(x_input, w_matrix):
    return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')


def max_pool(x_input, pool_size):
    size = [1, pool_size, pool_size, 1]
    return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME')


class Mnist:
    def __init__(self):
        images = tf.placeholder(tf.float32, [None, 784], name='input_x')
        labels = tf.placeholder(tf.float32, [None, 10], name='input_y')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.images = images
        self.labels = labels
        self.keep_prob = keep_prob

        self.train_step = None
        self.accuracy = None

        self.w1 = None
        self.b1 = None
        self.fcw1 = None
        self.cross = None
        with tf.name_scope('reshape'):
            x_image = tf.reshape(images, [-1, 28, 28, 1])
        with tf.name_scope('conv1'):
            w_conv1 = weight_variable([5, 5, 1, 32])
            self.w1 = w_conv1
            b_conv1 = bias_variable([32])
            self.b1 = b_conv1
            h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
        with tf.name_scope('pool1'):
            h_pool1 = max_pool(h_conv1, 2)
        with tf.name_scope('conv2'):
            w_conv2 = weight_variable([5, 5, 32, 64])
            b_conv2 = bias_variable([64])
            h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
        with tf.name_scope('pool2'):
            h_pool2 = max_pool(h_conv2, 2)
        with tf.name_scope('fc1'):
            w_fc1 = weight_variable([7 * 7 * 64, 1024])
            self.fcw1 = w_fc1
            b_fc1 = bias_variable([1024])
            h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
            h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
        with tf.name_scope('dropout'):
            h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
        with tf.name_scope('fc2'):
            w_fc2 = weight_variable([1024, 10])
            b_fc2 = bias_variable([10])
            y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
        with tf.name_scope('loss'):
            cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv))
            self.cross = cross_entropy
        with tf.name_scope('adam_optimizer'):
            self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
        with tf.name_scope('accuracy'):
            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def main():
    tf.set_random_seed(0)

    data = input_data.read_data_sets('data', one_hot=True)

    model = Mnist()

    '''you can change this to LevelPruner to try that algorithm instead:
    pruner = LevelPruner(configure_list)
    '''
    configure_list = [{
        'pruning_rate': 0.5,
        'op_types': ['Conv2D']
    }]
    pruner = FPGMPruner(configure_list)
    # if you want to load the config from a yaml file:
    # configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml', 'AGPruner')
    # configure_list = configure_file.get('config', [])
    # pruner.load_configure(configure_list)
    # you can also handle it yourself and pass in a config list in JSON
    pruner(tf.get_default_graph())
    # you can also use compress(model) or compress_default_graph() for the tensorflow compressor,
    # e.g. pruner.compress(tf.get_default_graph())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for batch_idx in range(2000):
            if batch_idx % 10 == 0:
                pruner.update_epoch(batch_idx / 10, sess)
            batch = data.train.next_batch(2000)
            model.train_step.run(feed_dict={
                model.images: batch[0],
                model.labels: batch[1],
                model.keep_prob: 0.5
            })
            if batch_idx % 10 == 0:
                test_acc = model.accuracy.eval(feed_dict={
                    model.images: data.test.images,
                    model.labels: data.test.labels,
                    model.keep_prob: 1.0
                })
                print('test accuracy', test_acc)

        test_acc = model.accuracy.eval(feed_dict={
            model.images: data.test.images,
            model.labels: data.test.labels,
            model.keep_prob: 1.0
        })
        print('final result is', test_acc)


if __name__ == '__main__':
    main()
92 changes: 92 additions & 0 deletions examples/model_compress/fpgm_torch_mnist.py
@@ -0,0 +1,92 @@
from nni.compression.torch import FPGMPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def main():
    torch.manual_seed(0)
    device = torch.device('cpu')

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()

    '''you can change this to LevelPruner to try that algorithm instead:
    pruner = LevelPruner(configure_list)
    '''
    configure_list = [{
        'pruning_rate': 0.5,
        'op_types': ['Conv2d']
    }]

    pruner = FPGMPruner(configure_list)
    pruner(model)
    # you can also use the compress(model) method,
    # e.g. pruner.compress(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()
82 changes: 81 additions & 1 deletion src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py
@@ -2,7 +2,7 @@
import tensorflow as tf
from .compressor import Pruner

-__all__ = ['LevelPruner', 'AGP_Pruner']
+__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner']

_logger = logging.getLogger(__name__)

@@ -94,3 +94,83 @@ def update_epoch(self, epoch, sess):
        sess.run(tf.assign(self.now_epoch, int(epoch)))
        for k in self.if_init_list:
            self.if_init_list[k] = True

class FPGMPruner(Pruner):
    """A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """

    def __init__(self, config_list):
        """
        config_list: supported keys:
            - pruning_rate: percentage of convolutional filters to be pruned.
        """
        super().__init__(config_list)
        self.mask_list = {}
        self.assign_handler = []

    def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs):
        """Supports Conv1D, Conv2D and Conv3D.
        filter dimensions for Conv1D:
            LEN: filter length
            IN: number of input channels
            OUT: number of output channels

        filter dimensions for Conv2D:
            H: filter height
            W: filter width
            IN: number of input channels
            OUT: number of output channels
        """

        assert 0 <= config.get('pruning_rate') < 1
        # TODO uncomment this
Contributor
Remove?

        #assert op_type in ['Conv1D', 'Conv2D', 'Conv3D']

        if op_type == config['op_type']:
            # TF Conv2D kernels are stored as [H, W, IN, OUT]; transpose to [IN, OUT, H, W]
            weight = tf.stop_gradient(tf.transpose(conv_kernel_weight, [2, 3, 0, 1]))
            masks = tf.Variable(tf.ones_like(weight))

            num_kernels = weight.shape[0].value * weight.shape[1].value
            num_prune = int(num_kernels * config.get('pruning_rate'))
            if num_kernels < 2 or num_prune < 1:
                self.mask_list.update({op_name: masks})
                return masks
            min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
            tf.scatter_nd_update(masks, min_gm_idx, tf.zeros((min_gm_idx.shape[0].value, weight.shape[-2].value, weight.shape[-1].value)))
            masks = tf.transpose(masks, [2, 3, 0, 1])
            self.assign_handler.append(tf.assign(conv_kernel_weight, conv_kernel_weight * masks))
            self.mask_list.update({op_name: masks})
        else:
            masks = tf.Variable(tf.ones_like(conv_kernel_weight))
            self.mask_list.update({op_name: masks})

        return masks

    def _get_min_gm_kernel_idx(self, weight, n):
        assert len(weight.shape) >= 3
        assert weight.shape[0].value * weight.shape[1].value > 2

        dist_list, idx_list = [], []
        for in_i in range(weight.shape[0].value):
            for out_i in range(weight.shape[1].value):
                dist_sum = self._get_distance_sum(weight, in_i, out_i)
                dist_list.append(dist_sum)
                idx_list.append([in_i, out_i])
        dist_tensor = tf.convert_to_tensor(dist_list)
        idx_tensor = tf.constant(idx_list)

        _, idx = tf.math.top_k(dist_tensor, k=n)
        return tf.gather(idx_tensor, idx)

    def _get_distance_sum(self, weight, in_idx, out_idx):
        # total Euclidean distance from the kernel at (in_idx, out_idx) to every kernel in this layer
        w = tf.reshape(weight, (-1, weight.shape[-2].value, weight.shape[-1].value))
        anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0].value, 1, 1])
        x = w - anchor_w
        x = tf.math.reduce_sum((x * x), (-2, -1))
        x = tf.math.sqrt(x)
        return tf.math.reduce_sum(x)

    def update_epoch(self, epoch, sess):
        sess.run(self.assign_handler)