Filter prune algo implementation #1655
Merged
Changes from 19 of 47 commits

Commits (all by chicm-ms):
3a45961 Merge pull request #31 from microsoft/master
633db43 Merge pull request #32 from microsoft/master
3e926f1 Merge pull request #33 from microsoft/master
f173789 Merge pull request #34 from microsoft/master
508850a Merge pull request #35 from microsoft/master
5a0e9c9 Merge pull request #36 from microsoft/master
e7df061 Merge pull request #37 from microsoft/master
e47c923 fpgm pruner pytorch implementation
c51f688 updates
b1165da updates
cd32a6a updates
8717026 updates
2175cef Merge pull request #38 from microsoft/master
2ccbfbb Merge pull request #39 from microsoft/master
b29cb0b Merge pull request #40 from microsoft/master
e25d9be Merge branch 'master' into filter_prune
216a9a7 updates
a42a067 updates
8fd58bb updates
4a3ba83 Merge pull request #41 from microsoft/master
c8a1148 Merge pull request #42 from microsoft/master
73c6101 Merge pull request #43 from microsoft/master
fef6ec2 Merge branch 'master' into filter_prune
ec2b3fb updates per refactored framework
3040b6e updates
0ca60cf updates
cd069fd updates
a4a999b update documents
bd622a2 updates
8a939b4 updates
6a518a9 Merge pull request #44 from microsoft/master
a0d587f Merge pull request #45 from microsoft/master
302f1bd tensorflow 2.0 implementation
a3e4b90 updates
20aedfc updates
676348b updates
2b22a1a updates
e905bfe Merge pull request #46 from microsoft/master
43bf2b7 Merge branch 'master' into filter_prune
9e68e2b updates
eadc941 updates
2978b7c updates
03d71da updates
22b6475 updates
053e3d1 updates
08f5237 updates
ec8bb4e updates
New file: TensorFlow MNIST pruning example

@@ -0,0 +1,130 @@
from nni.compression.tensorflow import FPGMPruner
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def weight_variable(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def bias_variable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))


def conv2d(x_input, w_matrix):
    return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')


def max_pool(x_input, pool_size):
    size = [1, pool_size, pool_size, 1]
    return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME')


class Mnist:
    def __init__(self):
        images = tf.placeholder(tf.float32, [None, 784], name='input_x')
        labels = tf.placeholder(tf.float32, [None, 10], name='input_y')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.images = images
        self.labels = labels
        self.keep_prob = keep_prob

        self.train_step = None
        self.accuracy = None

        self.w1 = None
        self.b1 = None
        self.fcw1 = None
        self.cross = None
        with tf.name_scope('reshape'):
            x_image = tf.reshape(images, [-1, 28, 28, 1])
        with tf.name_scope('conv1'):
            w_conv1 = weight_variable([5, 5, 1, 32])
            self.w1 = w_conv1
            b_conv1 = bias_variable([32])
            self.b1 = b_conv1
            h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
        with tf.name_scope('pool1'):
            h_pool1 = max_pool(h_conv1, 2)
        with tf.name_scope('conv2'):
            w_conv2 = weight_variable([5, 5, 32, 64])
            b_conv2 = bias_variable([64])
            h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
        with tf.name_scope('pool2'):
            h_pool2 = max_pool(h_conv2, 2)
        with tf.name_scope('fc1'):
            w_fc1 = weight_variable([7 * 7 * 64, 1024])
            self.fcw1 = w_fc1
            b_fc1 = bias_variable([1024])
            h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
            h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
        with tf.name_scope('dropout'):
            # the keep probability comes from the keep_prob placeholder fed at run time
            h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
        with tf.name_scope('fc2'):
            w_fc2 = weight_variable([1024, 10])
            b_fc2 = bias_variable([10])
            y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
        with tf.name_scope('loss'):
            cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv))
            self.cross = cross_entropy
        with tf.name_scope('adam_optimizer'):
            self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
        with tf.name_scope('accuracy'):
            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


def main():
    tf.set_random_seed(0)

    data = input_data.read_data_sets('data', one_hot=True)

    model = Mnist()

    # You can switch to LevelPruner here to apply level pruning instead:
    # pruner = LevelPruner(configure_list)
    configure_list = [{
        'pruning_rate': 0.5,
        'op_types': ['Conv2D']
    }]
    pruner = FPGMPruner(configure_list)
    # To load the configuration from a yaml file instead:
    # configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml', 'AGPruner')
    # configure_list = configure_file.get('config', [])
    # pruner.load_configure(configure_list)
    # You can also build the configure list yourself and pass it in as json.
    pruner(tf.get_default_graph())
    # Alternatively, use compress(model) or compress_default_graph():
    # pruner.compress(tf.get_default_graph())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for batch_idx in range(2000):
            if batch_idx % 10 == 0:
                pruner.update_epoch(batch_idx / 10, sess)
            batch = data.train.next_batch(2000)
            model.train_step.run(feed_dict={
                model.images: batch[0],
                model.labels: batch[1],
                model.keep_prob: 0.5
            })
            if batch_idx % 10 == 0:
                test_acc = model.accuracy.eval(feed_dict={
                    model.images: data.test.images,
                    model.labels: data.test.labels,
                    model.keep_prob: 1.0
                })
                print('test accuracy', test_acc)

        test_acc = model.accuracy.eval(feed_dict={
            model.images: data.test.images,
            model.labels: data.test.labels,
            model.keep_prob: 1.0
        })
        print('final result is', test_acc)


if __name__ == '__main__':
    main()
New file: fpgm_torch_mnist.py (PyTorch MNIST pruning example)

@@ -0,0 +1,92 @@
from nni.compression.torch import FPGMPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def main():
    torch.manual_seed(0)
    device = torch.device('cpu')

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()

    # You can switch to LevelPruner here to apply level pruning instead:
    # pruner = LevelPruner(configure_list)
    configure_list = [{
        'pruning_rate': 0.5,
        'op_types': ['Conv2d']
    }]

    pruner = FPGMPruner(configure_list)
    pruner(model)
    # Alternatively, use the compress(model) method:
    # pruner.compress(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()
Modified file: TensorFlow built-in pruners

@@ -2,7 +2,7 @@
 import tensorflow as tf
 from .compressor import Pruner

-__all__ = ['LevelPruner', 'AGP_Pruner']
+__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner']

 _logger = logging.getLogger(__name__)
@@ -94,3 +94,83 @@ def update_epoch(self, epoch, sess):
        sess.run(tf.assign(self.now_epoch, int(epoch)))
        for k in self.if_init_list:
            self.if_init_list[k] = True


class FPGMPruner(Pruner):
    """A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """

    def __init__(self, config_list):
        """
        config_list: supported keys:
            - pruning_rate: percentage of convolutional filters to be pruned.
        """
        super().__init__(config_list)
        self.mask_list = {}
        self.assign_handler = []

    def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs):
        """Supports Conv1D, Conv2D and Conv3D.
        Filter dimensions for Conv1D:
            LEN: filter length
            IN: number of input channels
            OUT: number of output channels

        Filter dimensions for Conv2D:
            H: filter height
            W: filter width
            IN: number of input channels
            OUT: number of output channels
        """
        assert 0 <= config.get('pruning_rate') < 1
        # TODO uncomment this
        #assert op_type in ['Conv1D', 'Conv2D', 'Conv3D']

        if op_type == config['op_type']:
            # reorder the kernel to (IN, OUT, spatial dims) so each 2D kernel
            # slice can be addressed by its channel pair
            weight = tf.stop_gradient(tf.transpose(conv_kernel_weight, [2, 3, 0, 1]))
            masks = tf.Variable(tf.ones_like(weight))

            num_kernels = weight.shape[0].value * weight.shape[1].value
            num_prune = int(num_kernels * config.get('pruning_rate'))
            if num_kernels < 2 or num_prune < 1:
                self.mask_list.update({op_name: masks})
                return masks
            min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
            tf.scatter_nd_update(masks, min_gm_idx, tf.zeros((min_gm_idx.shape[0].value, weight.shape[-2].value, weight.shape[-1].value)))
            masks = tf.transpose(masks, [2, 3, 0, 1])
            self.assign_handler.append(tf.assign(conv_kernel_weight, conv_kernel_weight * masks))
            self.mask_list.update({op_name: masks})
        else:
            masks = tf.Variable(tf.ones_like(conv_kernel_weight))
            self.mask_list.update({op_name: masks})

        return masks

    def _get_min_gm_kernel_idx(self, weight, n):
        assert len(weight.shape) >= 3
        assert weight.shape[0].value * weight.shape[1].value > 2

        dist_list, idx_list = [], []
        for in_i in range(weight.shape[0].value):
            for out_i in range(weight.shape[1].value):
                dist_sum = self._get_distance_sum(weight, in_i, out_i)
                dist_list.append(dist_sum)
                idx_list.append([in_i, out_i])
        dist_tensor = tf.convert_to_tensor(dist_list)
        idx_tensor = tf.constant(idx_list)

        # top_k returns the largest entries, so negate the distance sums to select
        # the n kernels with the smallest summed distance, i.e. those closest to
        # the geometric median
        _, idx = tf.math.top_k(-dist_tensor, k=n)
        return tf.gather(idx_tensor, idx)

    def _get_distance_sum(self, weight, in_idx, out_idx):
        # sum of Euclidean distances from the (in_idx, out_idx) kernel to all kernels
        w = tf.reshape(weight, (-1, weight.shape[-2].value, weight.shape[-1].value))
        anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0].value, 1, 1])
        x = w - anchor_w
        x = tf.math.reduce_sum((x * x), (-2, -1))
        x = tf.math.sqrt(x)
        return tf.math.reduce_sum(x)

    def update_epoch(self, epoch, sess):
        sess.run(self.assign_handler)
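
For intuition about the selection rule, here is a small self-contained NumPy sketch (illustrative only, not part of this PR). Following the paper, it scores whole output filters, whereas the TF code above scores each (input channel, output channel) kernel slice; in both cases, the entries with the smallest summed distance to all others are the ones closest to the geometric median and get pruned.

import numpy as np

def fpgm_prune_indices(weight, num_prune):
    # weight: (out_channels, in_channels, h, w); returns the indices of the
    # num_prune filters closest to the geometric median of all filters
    flat = weight.reshape(weight.shape[0], -1)      # one row per filter
    diff = flat[:, None, :] - flat[None, :, :]      # pairwise differences
    dist = np.sqrt((diff ** 2).sum(axis=-1))        # pairwise Euclidean distances
    dist_sum = dist.sum(axis=1)                     # summed distance per filter
    return np.argsort(dist_sum)[:num_prune]         # smallest sums first

# toy usage: pick 2 of 8 random 3x3 conv filters to prune
rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4, 3, 3))
print(fpgm_prune_indices(w, 2))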
Review conversation

Reviewer: Is it easy to reproduce one of the experiments in the paper?
chicm-ms: The key metrics of the pruning are the reduced FLOPs and the accuracy change. Reproducing the reduced FLOPs requires running the pruned compact model, which is not implemented yet.
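
As a rough illustration of that FLOPs metric (a hypothetical sketch, not code from this PR): removing half of a conv layer's output filters roughly halves that layer's multiply-accumulate count, and it also shrinks the input-channel count of the following layer.

def conv2d_macs(in_ch, out_ch, k_h, k_w, out_h, out_w):
    # multiply-accumulate count of a standard 2D convolution
    return in_ch * out_ch * k_h * k_w * out_h * out_w

# e.g. the TF example's second conv (5x5, 32 -> 64) on 14x14 feature maps
dense = conv2d_macs(32, 64, 5, 5, 14, 14)
pruned = conv2d_macs(32, 32, 5, 5, 14, 14)   # 50% of output filters pruned
print('MACs reduced by {:.0%}'.format(1 - pruned / dense))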
Reviewer: I think we need to offer some test metrics to make sure the implementation is correct.
chicm-ms: Sure, I have added verification code to the fpgm_torch_mnist.py example to verify the pruned conv kernel weight sparsity, which we can check together with the loss. But this code still cannot verify that the implementation is the same as the paper's; I am considering adding some kind of verification in UT.
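
For illustration, such a sparsity check might look like the following (a hypothetical sketch, not necessarily the code that was added to fpgm_torch_mnist.py): after pruning, each masked conv layer's zero ratio should sit close to the configured pruning_rate.

import torch

def check_conv_sparsity(model, expected_rate, tol=0.05):
    # print the fraction of zeroed weights in every Conv2d layer and flag
    # layers that deviate from the configured pruning rate
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            w = module.weight.data
            sparsity = float((w == 0).sum()) / w.numel()
            note = '' if abs(sparsity - expected_rate) <= tol else ' (unexpected)'
            print('{}: sparsity {:.2f}{}'.format(name, sparsity, note))

# usage, after pruner(model) and a training step:
# check_conv_sparsity(model, expected_rate=0.5)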