From e47c923ce9d075909d1d275e8cc2309bb8cdff04 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 23 Oct 2019 19:38:14 +0800 Subject: [PATCH 01/28] fpgm pruner pytorch implementation --- examples/model_compress/fpgm_mnist.py | 94 +++++++++++++++++++ .../nni/compression/torch/builtin_pruners.py | 73 +++++++++++++- 2 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 examples/model_compress/fpgm_mnist.py diff --git a/examples/model_compress/fpgm_mnist.py b/examples/model_compress/fpgm_mnist.py new file mode 100644 index 0000000000..f8e99b6180 --- /dev/null +++ b/examples/model_compress/fpgm_mnist.py @@ -0,0 +1,94 @@ +from nni.compression.torch import FPGMPruner +import torch +import torch.nn.functional as F +from torchvision import datasets, transforms + + +class Mnist(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4 * 4 * 50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def train(model, device, train_loader, optimizer): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + + +def main(): + torch.manual_seed(0) + device = torch.device('cpu') + + trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=True, download=True, transform=trans), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=False, transform=trans), + batch_size=1000, shuffle=True) + + model = Mnist() + + '''you can change this to LevelPruner to implement it + pruner = LevelPruner(configure_list) + ''' + configure_list = [{ + 'start_epoch': 0, + 'end_epoch': 10, + 'pruning_rate': 0.5, + 'op_type': 'Conv2d' + }] + + pruner = FPGMPruner(configure_list) + pruner(model) + # you can also use compress(model) method + # like that pruner.compress(model) + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + for epoch in range(10): + pruner.update_epoch(epoch) + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer) + test(model, device, test_loader) + + +if __name__ == '__main__': + main() diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 0ad09b9135..10c2a0fb55 100644 --- 
a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -2,7 +2,7 @@ import torch from .compressor import Pruner -__all__ = ['LevelPruner', 'AGP_Pruner'] +__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner'] logger = logging.getLogger('torch pruner') @@ -104,3 +104,74 @@ def update_epoch(self, epoch): self.now_epoch = epoch for k in self.if_init_list.keys(): self.if_init_list[k] = True + +class FPGMPruner(Pruner): + """A filter pruner via geometric median. + "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", + https://arxiv.org/pdf/1811.00250.pdf + """ + + def __init__(self, config_list): + """ + config_list: supported keys: + - pruning_rate: percentage of convolutional filters to be pruned. + - start_epoch: start epoch number begin update mask + - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch + """ + super().__init__(config_list) + self.mask_list = {} + #self.pruning_rate = config_list.get('pruning_rate') + print(config_list) + + def calc_mask(self, weight, config, op_name, **kwargs): + #print(weight.size(), type(weight)) + #print('config:', config) + #print('kgargs:', kwargs) + + if kwargs['op_type'] in ['Conv1d', 'Conv2d', 'Conv3d']: + num_kernels = weight.size(0) * weight.size(1) + num_prune = int(num_kernels * config.get('pruning_rate')) + if num_kernels < 3 or num_prune < 1: + return torch.ones(weight.size()) + min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) + masks = torch.ones(weight.size()) + #num_before = masks.sum() + for idx in min_gm_idx: + masks[idx] = 0. + #print('pruned: {}'.format(masks.sum() / num_before)) + return masks + else: + return torch.ones(weight.size()) + + def _get_min_gm_kernel_idx(self, weight, n): + """filter/kernel dimensions for Conv2d: + IN: number of input channel + OUT: number of output channel + H: filter height + W: filter width + """ + assert len(weight.size()) >= 3 # supports Conv1d, Conv2d, Conv3d + assert weight.size(0) * weight.size(1) > 2 + + dist_list = [] + for in_i in range(weight.size(0)): + for out_i in range(weight.size(1)): + dist_sum = self._get_distance_sum_fast(weight, in_i, out_i) + dist_list.append((dist_sum, (in_i, out_i))) + min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] + return [x[1] for x in min_gm_kernels] + + def _get_distance_sum(self, weight, in_idx, out_idx): + w = weight.view(-1, weight.size(-2), weight.size(-1)) + dist_sum = 0. 
+ for k in w: + dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) + return dist_sum + + def _get_distance_sum_fast(self, weight, in_idx, out_idx): + w = weight.view(-1, weight.size(-2), weight.size(-1)) + anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) + x = w - anchor_w + x = x*x + x = torch.sqrt(x) + return x.sum() From c51f688023a2dccbecb1f39654e2494fd16e4a16 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Fri, 25 Oct 2019 19:44:33 +0800 Subject: [PATCH 02/28] updates --- examples/model_compress/fpgm_tf_mnist.py | 132 ++++++++++++++++++ .../{fpgm_mnist.py => fpgm_torch_mnist.py} | 2 - .../compression/tensorflow/builtin_pruners.py | 72 +++++++++- .../nni/compression/torch/builtin_pruners.py | 36 +++-- 4 files changed, 220 insertions(+), 22 deletions(-) create mode 100644 examples/model_compress/fpgm_tf_mnist.py rename examples/model_compress/{fpgm_mnist.py => fpgm_torch_mnist.py} (98%) diff --git a/examples/model_compress/fpgm_tf_mnist.py b/examples/model_compress/fpgm_tf_mnist.py new file mode 100644 index 0000000000..4fc8d6443b --- /dev/null +++ b/examples/model_compress/fpgm_tf_mnist.py @@ -0,0 +1,132 @@ +from nni.compression.tensorflow import FPGMPruner +import tensorflow as tf +from tensorflow.examples.tutorials.mnist import input_data + + +def weight_variable(shape): + return tf.Variable(tf.truncated_normal(shape, stddev=0.1)) + + +def bias_variable(shape): + return tf.Variable(tf.constant(0.1, shape=shape)) + + +def conv2d(x_input, w_matrix): + return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME') + + +def max_pool(x_input, pool_size): + size = [1, pool_size, pool_size, 1] + return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME') + + +class Mnist: + def __init__(self): + images = tf.placeholder(tf.float32, [None, 784], name='input_x') + labels = tf.placeholder(tf.float32, [None, 10], name='input_y') + keep_prob = tf.placeholder(tf.float32, name='keep_prob') + + self.images = images + self.labels = labels + self.keep_prob = keep_prob + + self.train_step = None + self.accuracy = None + + self.w1 = None + self.b1 = None + self.fcw1 = None + self.cross = None + with tf.name_scope('reshape'): + x_image = tf.reshape(images, [-1, 28, 28, 1]) + with tf.name_scope('conv1'): + w_conv1 = weight_variable([5, 5, 1, 32]) + self.w1 = w_conv1 + b_conv1 = bias_variable([32]) + self.b1 = b_conv1 + h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) + with tf.name_scope('pool1'): + h_pool1 = max_pool(h_conv1, 2) + with tf.name_scope('conv2'): + w_conv2 = weight_variable([5, 5, 32, 64]) + b_conv2 = bias_variable([64]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) + with tf.name_scope('pool2'): + h_pool2 = max_pool(h_conv2, 2) + with tf.name_scope('fc1'): + w_fc1 = weight_variable([7 * 7 * 64, 1024]) + self.fcw1 = w_fc1 + b_fc1 = bias_variable([1024]) + h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) + with tf.name_scope('dropout'): + h_fc1_drop = tf.nn.dropout(h_fc1, 0.5) + with tf.name_scope('fc2'): + w_fc2 = weight_variable([1024, 10]) + b_fc2 = bias_variable([10]) + y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2 + with tf.name_scope('loss'): + cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv)) + self.cross = cross_entropy + with tf.name_scope('adam_optimizer'): + self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy) + with tf.name_scope('accuracy'): + 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1)) + self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + + +def main(): + tf.set_random_seed(0) + + data = input_data.read_data_sets('data', one_hot=True) + + model = Mnist() + + '''you can change this to LevelPruner to implement it + pruner = LevelPruner(configure_list) + ''' + configure_list = [{ + 'start_epoch': 0, + 'end_epoch': 10, + 'pruning_rate': 0.5, + 'op_type': 'default' + }] + pruner = FPGMPruner(configure_list) + # if you want to load from yaml file + # configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml','AGPruner') + # configure_list = configure_file.get('config',[]) + # pruner.load_configure(configure_list) + # you can also handle it yourself and input an configure list in json + pruner(tf.get_default_graph()) + # you can also use compress(model) or compress_default_graph() for tensorflow compressor + # pruner.compress(tf.get_default_graph()) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for batch_idx in range(2000): + if batch_idx % 10 == 0: + pruner.update_epoch(batch_idx / 10, sess) + batch = data.train.next_batch(2000) + model.train_step.run(feed_dict={ + model.images: batch[0], + model.labels: batch[1], + model.keep_prob: 0.5 + }) + if batch_idx % 10 == 0: + test_acc = model.accuracy.eval(feed_dict={ + model.images: data.test.images, + model.labels: data.test.labels, + model.keep_prob: 1.0 + }) + print('test accuracy', test_acc) + + test_acc = model.accuracy.eval(feed_dict={ + model.images: data.test.images, + model.labels: data.test.labels, + model.keep_prob: 1.0 + }) + print('final result is', test_acc) + + +if __name__ == '__main__': + main() diff --git a/examples/model_compress/fpgm_mnist.py b/examples/model_compress/fpgm_torch_mnist.py similarity index 98% rename from examples/model_compress/fpgm_mnist.py rename to examples/model_compress/fpgm_torch_mnist.py index f8e99b6180..d38e2955a4 100644 --- a/examples/model_compress/fpgm_mnist.py +++ b/examples/model_compress/fpgm_torch_mnist.py @@ -71,8 +71,6 @@ def main(): pruner = LevelPruner(configure_list) ''' configure_list = [{ - 'start_epoch': 0, - 'end_epoch': 10, 'pruning_rate': 0.5, 'op_type': 'Conv2d' }] diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 88f00b4d52..218dae078c 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -2,7 +2,7 @@ import tensorflow as tf from .compressor import Pruner -__all__ = ['LevelPruner', 'AGP_Pruner'] +__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner'] _logger = logging.getLogger(__name__) @@ -54,6 +54,9 @@ def __init__(self, config_list): self.assign_handler = [] def calc_mask(self, weight, config, op_name, **kwargs): + print('config:', config) + print('kwargs:', kwargs) + print('op_name:', op_name) start_epoch = config.get('start_epoch', 0) freq = config.get('frequency', 1) if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and ( @@ -94,3 +97,70 @@ def update_epoch(self, epoch, sess): sess.run(tf.assign(self.now_epoch, int(epoch))) for k in self.if_init_list.keys(): self.if_init_list[k] = True + +class FPGMPruner(Pruner): + """A filter pruner via geometric median. 
+ "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", + https://arxiv.org/pdf/1811.00250.pdf + """ + + def __init__(self, config_list): + """ + config_list: supported keys: + - pruning_rate: percentage of convolutional filters to be pruned. + - start_epoch: start epoch number begin update mask + - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch + """ + super().__init__(config_list) + self.mask_list = {} + + def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name): + print('config:', config) + print('op:', op) + print('op_type:', op_type) + print('op_name:', op_name) + assert 0 <= config.get('pruning_rate') < 1 + assert config['op_type'] in ['Conv1D', 'Conv2D', 'Conv3D'] + + weight = tf.stop_gradient(conv_kernel_weight) + masks = tf.ones_like(weight) + + if op_type == config['op_type']: + num_kernels = weight.shape[0].value * weight.shape[1].value + num_prune = int(num_kernels * config.get('pruning_rate')) + if num_kernels < 2 or num_prune < 1: + self.mask_list.update({op_name: masks}) + return masks + min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) + for idx in min_gm_idx: + masks[idx] = 0. + + self.mask_list.update({op_name: masks}) + return masks + + def _get_min_gm_kernel_idx(self, weight, n): + """supports Conv1D, Conv2D, Conv3D + filter/kernel dimensions for Conv2d: + IN: number of input channel + OUT: number of output channel + H: filter height + W: filter width + """ + assert len(weight.shape) >= 3 + assert weight.shape[0].value * weight.shape[1].value > 2 + + dist_list = [] + for in_i in range(weight.shape[0].value): + for out_i in range(weight.shape[1].value): + dist_sum = self._get_distance_sum_fast(weight, in_i, out_i) + dist_list.append((dist_sum, (in_i, out_i))) + min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] + return [x[1] for x in min_gm_kernels] + + def _get_distance_sum_fast(self, weight, in_idx, out_idx): + w = tf.reshape(weight, (-1, weight.shape[-2].value, weight.shape[-1].value)) + anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0].value, 1, 1]) + x = w - anchor_w + x = tf.math.reduce_sum((x*x), (-2, -1)) + x = tf.math.sqrt(x) + return tf.math.reduce_sum(x) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 10c2a0fb55..02f55210d0 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -120,37 +120,35 @@ def __init__(self, config_list): """ super().__init__(config_list) self.mask_list = {} - #self.pruning_rate = config_list.get('pruning_rate') - print(config_list) - def calc_mask(self, weight, config, op_name, **kwargs): - #print(weight.size(), type(weight)) - #print('config:', config) - #print('kgargs:', kwargs) - - if kwargs['op_type'] in ['Conv1d', 'Conv2d', 'Conv3d']: + def calc_mask(self, weight, config, op, op_type, op_name): + assert 0 <= config.get('pruning_rate') < 1 + assert config['op_type'] in ['Conv1d', 'Conv2d', 'Conv3d'] + + masks = torch.ones(weight.size()) + + if op_type == config['op_type']: num_kernels = weight.size(0) * weight.size(1) num_prune = int(num_kernels * config.get('pruning_rate')) - if num_kernels < 3 or num_prune < 1: - return torch.ones(weight.size()) + if num_kernels < 2 or num_prune < 1: + self.mask_list.update({op_name: masks}) + return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) - masks = torch.ones(weight.size()) - 
#num_before = masks.sum() for idx in min_gm_idx: masks[idx] = 0. - #print('pruned: {}'.format(masks.sum() / num_before)) - return masks - else: - return torch.ones(weight.size()) + + self.mask_list.update({op_name: masks}) + return masks def _get_min_gm_kernel_idx(self, weight, n): - """filter/kernel dimensions for Conv2d: + """supports Conv1d, Conv2d, Conv3d + filter/kernel dimensions for Conv2d: IN: number of input channel OUT: number of output channel H: filter height W: filter width """ - assert len(weight.size()) >= 3 # supports Conv1d, Conv2d, Conv3d + assert len(weight.size()) >= 3 assert weight.size(0) * weight.size(1) > 2 dist_list = [] @@ -172,6 +170,6 @@ def _get_distance_sum_fast(self, weight, in_idx, out_idx): w = weight.view(-1, weight.size(-2), weight.size(-1)) anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) x = w - anchor_w - x = x*x + x = (x*x).sum((-2,-1)) x = torch.sqrt(x) return x.sum() From b1165da9dd822c5b1533ab90ea25b13c54f5e616 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Mon, 28 Oct 2019 18:29:53 +0800 Subject: [PATCH 03/28] updates --- .../compression/tensorflow/builtin_pruners.py | 76 +++++++++++-------- .../nni/compression/torch/builtin_pruners.py | 41 +++++----- src/sdk/pynni/tests/test_compressor.py | 18 ++++- 3 files changed, 82 insertions(+), 53 deletions(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 218dae078c..18cf4943ef 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -54,9 +54,6 @@ def __init__(self, config_list): self.assign_handler = [] def calc_mask(self, weight, config, op_name, **kwargs): - print('config:', config) - print('kwargs:', kwargs) - print('op_name:', op_name) start_epoch = config.get('start_epoch', 0) freq = config.get('frequency', 1) if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and ( @@ -108,59 +105,74 @@ def __init__(self, config_list): """ config_list: supported keys: - pruning_rate: percentage of convolutional filters to be pruned. 
- - start_epoch: start epoch number begin update mask - - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch """ super().__init__(config_list) self.mask_list = {} + self.assign_handler = [] - def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name): - print('config:', config) - print('op:', op) - print('op_type:', op_type) - print('op_name:', op_name) - assert 0 <= config.get('pruning_rate') < 1 - assert config['op_type'] in ['Conv1D', 'Conv2D', 'Conv3D'] + def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs): + """supports Conv1d, Conv2d, Conv3d + filter/kernel dimensions for Conv1D: + LEN: kernel length + IN: number of input channel + OUT: number of output channel + + filter/kernel dimensions for Conv2D: + H: filter height + W: filter width + IN: number of input channel + OUT: number of output channel + """ - weight = tf.stop_gradient(conv_kernel_weight) - masks = tf.ones_like(weight) + assert 0 <= config.get('pruning_rate') < 1 + # TODO uncomment this + #assert op_type in ['Conv1D', 'Conv2D', 'Conv3D'] if op_type == config['op_type']: + weight = tf.stop_gradient(tf.transpose(conv_kernel_weight, [2,3,0,1])) + + print(weight.shape) + masks = tf.Variable(tf.ones_like(weight)) + num_kernels = weight.shape[0].value * weight.shape[1].value num_prune = int(num_kernels * config.get('pruning_rate')) if num_kernels < 2 or num_prune < 1: self.mask_list.update({op_name: masks}) return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) - for idx in min_gm_idx: - masks[idx] = 0. - - self.mask_list.update({op_name: masks}) + tf.scatter_nd_update(masks, min_gm_idx, tf.zeros((min_gm_idx.shape[0].value, weight.shape[-2].value, weight.shape[-1].value))) + masks = tf.transpose(masks, [2,3,0,1]) + self.assign_handler.append(tf.assign(conv_kernel_weight, conv_kernel_weight*masks)) + self.mask_list.update({op_name: masks}) + else: + masks = tf.Variable(tf.ones_like(conv_kernel_weight)) + self.mask_list.update({op_name: masks}) + return masks def _get_min_gm_kernel_idx(self, weight, n): - """supports Conv1D, Conv2D, Conv3D - filter/kernel dimensions for Conv2d: - IN: number of input channel - OUT: number of output channel - H: filter height - W: filter width - """ assert len(weight.shape) >= 3 assert weight.shape[0].value * weight.shape[1].value > 2 - dist_list = [] + dist_list, idx_list = [], [] for in_i in range(weight.shape[0].value): for out_i in range(weight.shape[1].value): - dist_sum = self._get_distance_sum_fast(weight, in_i, out_i) - dist_list.append((dist_sum, (in_i, out_i))) - min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] - return [x[1] for x in min_gm_kernels] - - def _get_distance_sum_fast(self, weight, in_idx, out_idx): + dist_sum = self._get_distance_sum(weight, in_i, out_i) + dist_list.append(dist_sum) + idx_list.append([in_i, out_i]) + dist_tensor = tf.convert_to_tensor(dist_list) + idx_tensor = tf.constant(idx_list) + + _, idx = tf.math.top_k(dist_tensor, k=n) + return tf.gather(idx_tensor, idx) + + def _get_distance_sum(self, weight, in_idx, out_idx): w = tf.reshape(weight, (-1, weight.shape[-2].value, weight.shape[-1].value)) anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0].value, 1, 1]) x = w - anchor_w x = tf.math.reduce_sum((x*x), (-2, -1)) x = tf.math.sqrt(x) return tf.math.reduce_sum(x) + + def update_epoch(self, epoch, sess): + sess.run(self.assign_handler) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py 
b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 02f55210d0..bbf88e80a5 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -121,9 +121,22 @@ def __init__(self, config_list): super().__init__(config_list) self.mask_list = {} - def calc_mask(self, weight, config, op, op_type, op_name): + def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): + """supports Conv1d, Conv2d, Conv3d + filter/kernel dimensions for Conv1d: + IN: number of input channel + OUT: number of output channel + LEN: kernel length + + filter/kernel dimensions for Conv2d: + IN: number of input channel + OUT: number of output channel + H: filter height + W: filter width + """ + assert 0 <= config.get('pruning_rate') < 1 - assert config['op_type'] in ['Conv1d', 'Conv2d', 'Conv3d'] + assert op_type in ['Conv1d', 'Conv2d', 'Conv3d'] masks = torch.ones(weight.size()) @@ -141,32 +154,26 @@ def calc_mask(self, weight, config, op, op_type, op_name): return masks def _get_min_gm_kernel_idx(self, weight, n): - """supports Conv1d, Conv2d, Conv3d - filter/kernel dimensions for Conv2d: - IN: number of input channel - OUT: number of output channel - H: filter height - W: filter width - """ assert len(weight.size()) >= 3 assert weight.size(0) * weight.size(1) > 2 dist_list = [] for in_i in range(weight.size(0)): for out_i in range(weight.size(1)): - dist_sum = self._get_distance_sum_fast(weight, in_i, out_i) + dist_sum = self._get_distance_sum(weight, in_i, out_i) dist_list.append((dist_sum, (in_i, out_i))) min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] return [x[1] for x in min_gm_kernels] def _get_distance_sum(self, weight, in_idx, out_idx): - w = weight.view(-1, weight.size(-2), weight.size(-1)) - dist_sum = 0. - for k in w: - dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) - return dist_sum - - def _get_distance_sum_fast(self, weight, in_idx, out_idx): + """ Optimized verision of following naive implementation: + def _get_distance_sum(self, weight, in_idx, out_idx): + w = weight.view(-1, weight.size(-2), weight.size(-1)) + dist_sum = 0. 
+ for k in w: + dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) + return dist_sum + """ w = weight.view(-1, weight.size(-2), weight.size(-1)) anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) x = w - anchor_w diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index d1f4a724cb..601b7521b6 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -100,21 +100,31 @@ def forward(self, x): class CompressorTestCase(TestCase): def test_tf_pruner(self): model = TfMnist() - configure_list = [{'sparsity': 0.8, 'op_types': 'default'}] + configure_list = [{'sparsity': 0.8, 'op_type': 'default'}] tf_compressor.LevelPruner(configure_list).compress_default_graph() def test_tf_quantizer(self): model = TfMnist() - tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph() + tf_compressor.NaiveQuantizer([{'op_type': 'default'}]).compress_default_graph() def test_torch_pruner(self): model = TorchMnist() - configure_list = [{'sparsity': 0.8, 'op_types': 'default'}] + configure_list = [{'sparsity': 0.8, 'op_type': 'default'}] torch_compressor.LevelPruner(configure_list).compress(model) + def test_torch_fpgm_pruner(self): + model = TorchMnist() + configure_list = [{'pruning_rate': 0.5, 'op_type': 'Conv2d'}] + torch_compressor.FPGMPruner(configure_list).compress(model) + + def test_tf_fpgm_pruner(self): + model = TfMnist() + configure_list = [{'pruning_rate': 0.5, 'op_type': 'Conv2D'}] + tf_compressor.FPGMPruner(configure_list).compress_default_graph() + def test_torch_quantizer(self): model = TorchMnist() - torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model) + torch_compressor.NaiveQuantizer([{'op_type': 'default'}]).compress(model) if __name__ == '__main__': From cd32a6adfcfc31e04f4736c47e7c356d41b018e0 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Mon, 28 Oct 2019 18:35:26 +0800 Subject: [PATCH 04/28] updates --- examples/model_compress/fpgm_tf_mnist.py | 4 +--- src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/model_compress/fpgm_tf_mnist.py b/examples/model_compress/fpgm_tf_mnist.py index 4fc8d6443b..95307765b7 100644 --- a/examples/model_compress/fpgm_tf_mnist.py +++ b/examples/model_compress/fpgm_tf_mnist.py @@ -86,10 +86,8 @@ def main(): pruner = LevelPruner(configure_list) ''' configure_list = [{ - 'start_epoch': 0, - 'end_epoch': 10, 'pruning_rate': 0.5, - 'op_type': 'default' + 'op_type': 'Conv2D' }] pruner = FPGMPruner(configure_list) # if you want to load from yaml file diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 18cf4943ef..de24aba50c 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -129,9 +129,7 @@ def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs): #assert op_type in ['Conv1D', 'Conv2D', 'Conv3D'] if op_type == config['op_type']: - weight = tf.stop_gradient(tf.transpose(conv_kernel_weight, [2,3,0,1])) - - print(weight.shape) + weight = tf.stop_gradient(tf.transpose(conv_kernel_weight, [2,3,0,1])) masks = tf.Variable(tf.ones_like(weight)) num_kernels = weight.shape[0].value * weight.shape[1].value From 871702667594903bbf8854d361320a6f00876c3d Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Tue, 29 Oct 2019 11:47:14 
+0800 Subject: [PATCH 05/28] updates --- .../pynni/nni/compression/tensorflow/builtin_pruners.py | 6 +++--- src/sdk/pynni/nni/compression/torch/builtin_pruners.py | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index de24aba50c..4f1bde45a4 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -112,12 +112,12 @@ def __init__(self, config_list): def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs): """supports Conv1d, Conv2d, Conv3d - filter/kernel dimensions for Conv1D: - LEN: kernel length + filter dimensions for Conv1D: + LEN: filter length IN: number of input channel OUT: number of output channel - filter/kernel dimensions for Conv2D: + filter dimensions for Conv2D: H: filter height W: filter width IN: number of input channel diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index bbf88e80a5..aeb536d49e 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -115,20 +115,18 @@ def __init__(self, config_list): """ config_list: supported keys: - pruning_rate: percentage of convolutional filters to be pruned. - - start_epoch: start epoch number begin update mask - - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch """ super().__init__(config_list) self.mask_list = {} def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): """supports Conv1d, Conv2d, Conv3d - filter/kernel dimensions for Conv1d: + filter dimensions for Conv1d: IN: number of input channel OUT: number of output channel - LEN: kernel length + LEN: filter length - filter/kernel dimensions for Conv2d: + filter dimensions for Conv2d: IN: number of input channel OUT: number of output channel H: filter height From 216a9a7243921348bcd7cf57627c9af4ef20093c Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 31 Oct 2019 11:11:04 +0800 Subject: [PATCH 06/28] updates --- examples/model_compress/fpgm_tf_mnist.py | 2 +- examples/model_compress/fpgm_torch_mnist.py | 2 +- .../nni/compression/torch/builtin_pruners.py | 17 +++++++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/model_compress/fpgm_tf_mnist.py b/examples/model_compress/fpgm_tf_mnist.py index 95307765b7..b05dd923fd 100644 --- a/examples/model_compress/fpgm_tf_mnist.py +++ b/examples/model_compress/fpgm_tf_mnist.py @@ -87,7 +87,7 @@ def main(): ''' configure_list = [{ 'pruning_rate': 0.5, - 'op_type': 'Conv2D' + 'op_types': ['Conv2D'] }] pruner = FPGMPruner(configure_list) # if you want to load from yaml file diff --git a/examples/model_compress/fpgm_torch_mnist.py b/examples/model_compress/fpgm_torch_mnist.py index d38e2955a4..2b984ae91e 100644 --- a/examples/model_compress/fpgm_torch_mnist.py +++ b/examples/model_compress/fpgm_torch_mnist.py @@ -72,7 +72,7 @@ def main(): ''' configure_list = [{ 'pruning_rate': 0.5, - 'op_type': 'Conv2d' + 'op_types': ['Conv2d'] }] pruner = FPGMPruner(configure_list) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 998ca929af..f4aed553a6 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ 
-135,20 +135,26 @@ def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): assert 0 <= config.get('pruning_rate') < 1 assert op_type in ['Conv1d', 'Conv2d', 'Conv3d'] + assert op_type in config['op_types'] + + if op_name in self.epoch_pruned_layers: + assert op_name in self.mask_list + return self.mask_list.get(op_name) masks = torch.ones(weight.size()) - if op_type == config['op_type']: + try: num_kernels = weight.size(0) * weight.size(1) num_prune = int(num_kernels * config.get('pruning_rate')) if num_kernels < 2 or num_prune < 1: - self.mask_list.update({op_name: masks}) return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) for idx in min_gm_idx: masks[idx] = 0. - - self.mask_list.update({op_name: masks}) + finally: + self.mask_list.update({op_name: masks}) + self.epoch_pruned_layers.add(op_name) + return masks def _get_min_gm_kernel_idx(self, weight, n): @@ -178,3 +184,6 @@ def _get_distance_sum(self, weight, in_idx, out_idx): x = (x*x).sum((-2,-1)) x = torch.sqrt(x) return x.sum() + + def update_epoch(self, epoch): + self.epoch_pruned_layers = set() From a42a06785a764948059631d51fa646f0a80a5142 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 31 Oct 2019 11:59:46 +0800 Subject: [PATCH 07/28] updates --- src/sdk/pynni/nni/compression/torch/builtin_pruners.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index f4aed553a6..8b2311ae44 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -159,7 +159,6 @@ def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): def _get_min_gm_kernel_idx(self, weight, n): assert len(weight.size()) >= 3 - assert weight.size(0) * weight.size(1) > 2 dist_list = [] for in_i in range(weight.size(0)): From 8fd58bbb59c84d263b72e0c658320c162ce020db Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 31 Oct 2019 12:12:17 +0800 Subject: [PATCH 08/28] updates --- .../nni/compression/torch/builtin_pruners.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 8b2311ae44..406e3dda37 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -120,7 +120,7 @@ def __init__(self, config_list): self.mask_list = {} def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): - """supports Conv1d, Conv2d, Conv3d + """supports Conv1d, Conv2d filter dimensions for Conv1d: IN: number of input channel OUT: number of output channel @@ -134,7 +134,7 @@ def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): """ assert 0 <= config.get('pruning_rate') < 1 - assert op_type in ['Conv1d', 'Conv2d', 'Conv3d'] + assert op_type in ['Conv1d', 'Conv2d'] assert op_type in config['op_types'] if op_name in self.epoch_pruned_layers: @@ -158,7 +158,7 @@ def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): return masks def _get_min_gm_kernel_idx(self, weight, n): - assert len(weight.size()) >= 3 + assert len(weight.size()) in [3, 4] dist_list = [] for in_i in range(weight.size(0)): @@ -177,8 +177,14 @@ def _get_distance_sum(self, weight, in_idx, out_idx): dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) return dist_sum """ - w = weight.view(-1, weight.size(-2), weight.size(-1)) - anchor_w = 
weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) + if len(weight.size()) == 4: # Conv2d + w = weight.view(-1, weight.size(-2), weight.size(-1)) + anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) + elif len(weight.size()) == 3: # Conv1d + w = weight.view(-1, weight.size(-1)) + anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1)) + else: + raise RuntimeError('unsupported layer type') x = w - anchor_w x = (x*x).sum((-2,-1)) x = torch.sqrt(x) From ec2b3fbafc743f53e91c1e4945e08d3489fe1540 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 13:18:58 +0800 Subject: [PATCH 09/28] updates per refactored framework --- examples/model_compress/fpgm_torch_mnist.py | 23 ++++++--- .../nni/compression/torch/builtin_pruners.py | 47 ++++++++++--------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/examples/model_compress/fpgm_torch_mnist.py b/examples/model_compress/fpgm_torch_mnist.py index 2b984ae91e..30bc581dc2 100644 --- a/examples/model_compress/fpgm_torch_mnist.py +++ b/examples/model_compress/fpgm_torch_mnist.py @@ -22,6 +22,16 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x, dim=1) + def _get_conv_weight_sparsity(self, conv_layer): + num_zero_filters = (conv_layer.weight.data.sum((2,3)) == 0).sum() + num_filters = conv_layer.weight.data.size(0) * conv_layer.weight.data.size(1) + return num_zero_filters, num_filters, float(num_zero_filters)/num_filters + + def print_conv_filter_sparsity(self): + conv1_data = self._get_conv_weight_sparsity(self.conv1) + conv2_data = self._get_conv_weight_sparsity(self.conv2) + print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2])) + print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2])) def train(model, device, train_loader, optimizer): model.train() @@ -30,11 +40,11 @@ def train(model, device, train_loader, optimizer): optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() if batch_idx % 100 == 0: print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) - + model.print_conv_filter_sparsity() + loss.backward() + optimizer.step() def test(model, device, test_loader): model.eval() @@ -66,6 +76,7 @@ def main(): batch_size=1000, shuffle=True) model = Mnist() + model.print_conv_filter_sparsity() '''you can change this to LevelPruner to implement it pruner = LevelPruner(configure_list) @@ -75,10 +86,8 @@ def main(): 'op_types': ['Conv2d'] }] - pruner = FPGMPruner(configure_list) - pruner(model) - # you can also use compress(model) method - # like that pruner.compress(model) + pruner = FPGMPruner(model, configure_list) + pruner.compress() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) for epoch in range(10): diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 3f0ee1fd57..f8536d1a3f 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -108,40 +108,42 @@ def update_epoch(self, epoch): self.if_init_list[k] = True class FPGMPruner(Pruner): - """A filter pruner via geometric median. - "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", + """ + A filter pruner via geometric median. 
+ "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", https://arxiv.org/pdf/1811.00250.pdf """ - def __init__(self, config_list): + def __init__(self, model, config_list): """ config_list: supported keys: - pruning_rate: percentage of convolutional filters to be pruned. """ - super().__init__(config_list) + super().__init__(model, config_list) self.mask_list = {} - def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): - """supports Conv1d, Conv2d + def calc_mask(self, layer, config): + """ + Supports Conv1d, Conv2d filter dimensions for Conv1d: - IN: number of input channel OUT: number of output channel + IN: number of input channel LEN: filter length filter dimensions for Conv2d: - IN: number of input channel OUT: number of output channel + IN: number of input channel H: filter height W: filter width """ - + weight = layer.module.weight.data assert 0 <= config.get('pruning_rate') < 1 - assert op_type in ['Conv1d', 'Conv2d'] - assert op_type in config['op_types'] + assert layer.type in ['Conv1d', 'Conv2d'] + assert layer.type in config['op_types'] - if op_name in self.epoch_pruned_layers: - assert op_name in self.mask_list - return self.mask_list.get(op_name) + if layer.name in self.epoch_pruned_layers: + assert layer.name in self.mask_list + return self.mask_list.get(layer.name) masks = torch.ones(weight.size()) @@ -154,8 +156,8 @@ def calc_mask(self, weight, config, op, op_type, op_name, **kwargs): for idx in min_gm_idx: masks[idx] = 0. finally: - self.mask_list.update({op_name: masks}) - self.epoch_pruned_layers.add(op_name) + self.mask_list.update({layer.name: masks}) + self.epoch_pruned_layers.add(layer.name) return masks @@ -163,15 +165,16 @@ def _get_min_gm_kernel_idx(self, weight, n): assert len(weight.size()) in [3, 4] dist_list = [] - for in_i in range(weight.size(0)): - for out_i in range(weight.size(1)): - dist_sum = self._get_distance_sum(weight, in_i, out_i) - dist_list.append((dist_sum, (in_i, out_i))) + for out_i in range(weight.size(0)): + for in_i in range(weight.size(1)): + dist_sum = self._get_distance_sum(weight, out_i, in_i) + dist_list.append((dist_sum, (out_i, in_i))) min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] return [x[1] for x in min_gm_kernels] def _get_distance_sum(self, weight, in_idx, out_idx): - """ Optimized verision of following naive implementation: + """ + Optimized verision of following naive implementation: def _get_distance_sum(self, weight, in_idx, out_idx): w = weight.view(-1, weight.size(-2), weight.size(-1)) dist_sum = 0. 
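As a quick, self-contained sanity check on the optimization described in the docstring above, the sketch below compares a naive per-kernel `torch.dist` loop against the broadcast form. The helper names are hypothetical and independent of the NNI classes; the check assumes a Conv2d-style weight of shape (OUT, IN, H, W).

```python
import torch

def dist_sum_naive(weight, out_idx, in_idx):
    # sum of L2 distances between one 2-D kernel and every kernel in the layer
    kernels = weight.view(-1, weight.size(-2), weight.size(-1))
    anchor = weight[out_idx, in_idx]
    return sum(torch.dist(k, anchor, p=2) for k in kernels)

def dist_sum_vectorized(weight, out_idx, in_idx):
    # the same quantity computed with broadcasting instead of a Python loop
    kernels = weight.view(-1, weight.size(-2), weight.size(-1))
    diff = kernels - weight[out_idx, in_idx]   # broadcasts over the kernel dimension
    return torch.sqrt((diff * diff).sum((-2, -1))).sum()

w = torch.randn(8, 4, 5, 5)   # OUT=8, IN=4, 5x5 kernels
assert torch.allclose(dist_sum_naive(w, 0, 0), dist_sum_vectorized(w, 0, 0))
```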
@@ -188,7 +191,7 @@ def _get_distance_sum(self, weight, in_idx, out_idx): else: raise RuntimeError('unsupported layer type') x = w - anchor_w - x = (x*x).sum((-2,-1)) + x = (x*x).sum((-2, -1)) x = torch.sqrt(x) return x.sum() From 3040b6ec7605af7304d65b5c9863fbe423c455f1 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 13:42:35 +0800 Subject: [PATCH 10/28] updates --- .../nni/compression/torch/builtin_pruners.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index f8536d1a3f..5d28c23fd8 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -116,8 +116,13 @@ class FPGMPruner(Pruner): def __init__(self, model, config_list): """ - config_list: supported keys: - - pruning_rate: percentage of convolutional filters to be pruned. + Parameters + ---------- + model : pytorch model + the model user wants to compress + config_list: list + support key for each list item: + - pruning_rate: percentage of convolutional filters to be pruned. """ super().__init__(model, config_list) self.mask_list = {} @@ -135,6 +140,13 @@ def calc_mask(self, layer, config): IN: number of input channel H: filter height W: filter width + + Parameters + ---------- + layer : LayerInfo + calculate mask for `layer`'s weight + config : dict + the configuration for generating the mask """ weight = layer.module.weight.data assert 0 <= config.get('pruning_rate') < 1 @@ -172,7 +184,7 @@ def _get_min_gm_kernel_idx(self, weight, n): min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] return [x[1] for x in min_gm_kernels] - def _get_distance_sum(self, weight, in_idx, out_idx): + def _get_distance_sum(self, weight, out_idx, in_idx): """ Optimized verision of following naive implementation: def _get_distance_sum(self, weight, in_idx, out_idx): @@ -182,12 +194,13 @@ def _get_distance_sum(self, weight, in_idx, out_idx): dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) return dist_sum """ + logger.debug('weight size: %s', weight.size()) if len(weight.size()) == 4: # Conv2d w = weight.view(-1, weight.size(-2), weight.size(-1)) - anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) + anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) elif len(weight.size()) == 3: # Conv1d w = weight.view(-1, weight.size(-1)) - anchor_w = weight[in_idx, out_idx].unsqueeze(0).expand(w.size(0), w.size(1)) + anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1)) else: raise RuntimeError('unsupported layer type') x = w - anchor_w From 0ca60cf86804d448f02d8e37c024b6807c7cae1e Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 13:57:34 +0800 Subject: [PATCH 11/28] updates --- examples/model_compress/fpgm_tf_mnist.py | 9 ++--- .../compression/tensorflow/builtin_pruners.py | 35 +++++++++++++------ 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/examples/model_compress/fpgm_tf_mnist.py b/examples/model_compress/fpgm_tf_mnist.py index b05dd923fd..313dab9926 100644 --- a/examples/model_compress/fpgm_tf_mnist.py +++ b/examples/model_compress/fpgm_tf_mnist.py @@ -82,20 +82,21 @@ def main(): model = Mnist() - '''you can change this to LevelPruner to implement it + """ + You can change this to LevelPruner to implement it pruner = LevelPruner(configure_list) - ''' + """ configure_list = [{ 
'pruning_rate': 0.5, 'op_types': ['Conv2D'] }] - pruner = FPGMPruner(configure_list) + pruner = FPGMPruner(tf.get_default_graph(), configure_list) + pruner.compress() # if you want to load from yaml file # configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml','AGPruner') # configure_list = configure_file.get('config',[]) # pruner.load_configure(configure_list) # you can also handle it yourself and input an configure list in json - pruner(tf.get_default_graph()) # you can also use compress(model) or compress_default_graph() for tensorflow compressor # pruner.compress(tf.get_default_graph()) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 017e1690b1..e96f21ec9b 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -105,16 +105,21 @@ class FPGMPruner(Pruner): https://arxiv.org/pdf/1811.00250.pdf """ - def __init__(self, config_list): + def __init__(self, model, config_list): """ - config_list: supported keys: - - pruning_rate: percentage of convolutional filters to be pruned. + Parameters + ---------- + model : pytorch model + the model user wants to compress + config_list: list + support key for each list item: + - pruning_rate: percentage of convolutional filters to be pruned. """ - super().__init__(config_list) + super().__init__(model, config_list) self.mask_list = {} self.assign_handler = [] - def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs): + def calc_mask(self, layer, config): """supports Conv1d, Conv2d, Conv3d filter dimensions for Conv1D: LEN: filter length @@ -126,14 +131,24 @@ def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs): W: filter width IN: number of input channel OUT: number of output channel + + Parameters + ---------- + layer : LayerInfo + calculate mask for `layer`'s weight + config : dict + the configuration for generating the mask + """ + weight = layer.weight + op_type = layer.type + op_name = layer.name assert 0 <= config.get('pruning_rate') < 1 - # TODO uncomment this - #assert op_type in ['Conv1D', 'Conv2D', 'Conv3D'] + assert op_type in ['Conv1D', 'Conv2D'] if op_type == config['op_type']: - weight = tf.stop_gradient(tf.transpose(conv_kernel_weight, [2,3,0,1])) + weight = tf.stop_gradient(tf.transpose(weight, [2,3,0,1])) masks = tf.Variable(tf.ones_like(weight)) num_kernels = weight.shape[0].value * weight.shape[1].value @@ -144,10 +159,10 @@ def calc_mask(self, conv_kernel_weight, config, op, op_type, op_name, **kwargs): min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) tf.scatter_nd_update(masks, min_gm_idx, tf.zeros((min_gm_idx.shape[0].value, weight.shape[-2].value, weight.shape[-1].value))) masks = tf.transpose(masks, [2,3,0,1]) - self.assign_handler.append(tf.assign(conv_kernel_weight, conv_kernel_weight*masks)) + self.assign_handler.append(tf.assign(weight, weight*masks)) self.mask_list.update({op_name: masks}) else: - masks = tf.Variable(tf.ones_like(conv_kernel_weight)) + masks = tf.Variable(tf.ones_like(weight)) self.mask_list.update({op_name: masks}) return masks From cd069fd226933a4695ee4ccdb8f8f9c16c3180cc Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 14:16:28 +0800 Subject: [PATCH 12/28] updates --- src/sdk/pynni/tests/test_compressor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index f117f7a38b..a6e4e9d601 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -115,12 +115,12 @@ def test_torch_pruner(self): def test_torch_fpgm_pruner(self): model = TorchMnist() configure_list = [{'pruning_rate': 0.5, 'op_types': ['Conv2d']}] - torch_compressor.FPGMPruner(configure_list).compress(model) + torch_compressor.FPGMPruner(model, configure_list).compress() def test_tf_fpgm_pruner(self): model = TfMnist() configure_list = [{'pruning_rate': 0.5, 'op_types': ['Conv2D']}] - tf_compressor.FPGMPruner(configure_list).compress_default_graph() + tf_compressor.FPGMPruner(tf.get_default_graph(), configure_list).compress() def test_torch_quantizer(self): model = TorchMnist() From a4a999b75a7af980269d7933a3b153310c2d6a27 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 16:58:53 +0800 Subject: [PATCH 13/28] update documents --- docs/en_US/Compressor/Overview.md | 1 + docs/en_US/Compressor/Pruner.md | 47 +++++++++++++++++++ .../compression/tensorflow/builtin_pruners.py | 6 +-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/docs/en_US/Compressor/Overview.md b/docs/en_US/Compressor/Overview.md index 7ee603e3e3..1cc47adf43 100644 --- a/docs/en_US/Compressor/Overview.md +++ b/docs/en_US/Compressor/Overview.md @@ -12,6 +12,7 @@ We have provided two naive compression algorithms and three popular ones for use |---|---| | [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights | | [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)| +| [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)| | [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits | | [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)| | [DoReFa Quantizer](./Quantizer.md#dorefa-quantizer) | DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. [Reference Paper](https://arxiv.org/abs/1606.06160)| diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index 731503fc2d..317f2af624 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -92,3 +92,50 @@ You can view example for more information *** +## FPGM Pruner +FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf) + +Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. 
To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance. + +### Usage +You can prune all weight from 0% to 80% sparsity in 10 epoch with the code below. + +First, you should import pruner and add mask to model. + +Tensorflow code +```python +from nni.compression.tensorflow import FPGMPruner +config_list = [{ + 'pruning_rate': 0.5, + 'op_types': ['Conv2D'] +}] +pruner = FPGMPruner(tf.get_default_graph(), config_list) +pruner.compress() +``` +PyTorch code +```python +from nni.compression.torch import FPGMPruner +config_list = [{ + 'pruning_rate': 0.5, + 'op_types': ['Conv2d'] +}] +pruner = FPGMPruner(model, config_list) +pruner.compress() +``` + +Second, you should add code below to update epoch number when you finish one epoch in your training code. + +Tensorflow code +```python +pruner.update_epoch(epoch, sess) +``` +PyTorch code +```python +pruner.update_epoch(epoch) +``` +You can view example for more information + +#### User configuration for FPGM Pruner +* **pruning_rate:** How much percentage of convolutional filters are to be pruned. + +*** diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index e96f21ec9b..ff69ba7347 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -148,7 +148,7 @@ def calc_mask(self, layer, config): assert op_type in ['Conv1D', 'Conv2D'] if op_type == config['op_type']: - weight = tf.stop_gradient(tf.transpose(weight, [2,3,0,1])) + weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1])) masks = tf.Variable(tf.ones_like(weight)) num_kernels = weight.shape[0].value * weight.shape[1].value @@ -158,7 +158,7 @@ def calc_mask(self, layer, config): return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) tf.scatter_nd_update(masks, min_gm_idx, tf.zeros((min_gm_idx.shape[0].value, weight.shape[-2].value, weight.shape[-1].value))) - masks = tf.transpose(masks, [2,3,0,1]) + masks = tf.transpose(masks, [2, 3, 0, 1]) self.assign_handler.append(tf.assign(weight, weight*masks)) self.mask_list.update({op_name: masks}) else: @@ -179,7 +179,7 @@ def _get_min_gm_kernel_idx(self, weight, n): idx_list.append([in_i, out_i]) dist_tensor = tf.convert_to_tensor(dist_list) idx_tensor = tf.constant(idx_list) - + _, idx = tf.math.top_k(dist_tensor, k=n) return tf.gather(idx_tensor, idx) From bd622a2b2e1d165e359e8c1444195ef7ed99a142 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 17:01:42 +0800 Subject: [PATCH 14/28] updates --- docs/en_US/Compressor/Pruner.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index 317f2af624..8c056c6fec 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -98,8 +98,6 @@ FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median f Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. 
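As a minimal illustration of the geometric-median selection rule summarized above, the sketch below ranks whole Conv2d filters by their total L2 distance to every other filter and flags the filters with the smallest totals as redundant. The helper name is hypothetical and not part of the NNI API; note that the pruner in this patch series ranks 2-D kernels per (output, input) channel pair rather than whole filters.

```python
import torch

def fpgm_redundant_filters(conv_weight, pruning_rate=0.5):
    out_channels = conv_weight.size(0)
    flat = conv_weight.view(out_channels, -1)     # one row per output filter
    dist = torch.cdist(flat, flat, p=2)           # pairwise L2 distances
    total = dist.sum(dim=1)                       # distance sum per filter
    num_prune = int(out_channels * pruning_rate)
    # smallest sums are closest to the geometric median, i.e. most redundant
    return torch.argsort(total)[:num_prune]

weight = torch.randn(8, 4, 5, 5)                  # e.g. a Conv2d(4, 8, 5) weight
print(fpgm_redundant_filters(weight))             # indices of filters to mask out
```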
In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance. ### Usage -You can prune all weight from 0% to 80% sparsity in 10 epoch with the code below. - First, you should import pruner and add mask to model. Tensorflow code From 8a939b4abe84147e9b7014261666ebd155bd00ec Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 7 Nov 2019 17:15:15 +0800 Subject: [PATCH 15/28] updates --- .../nni/compression/torch/builtin_pruners.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 5d28c23fd8..315d1a1d85 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -186,6 +186,8 @@ def _get_min_gm_kernel_idx(self, weight, n): def _get_distance_sum(self, weight, out_idx, in_idx): """ + Calculate the total distance between a specified filter (by out_idex and in_idx) and + all other filters. Optimized verision of following naive implementation: def _get_distance_sum(self, weight, in_idx, out_idx): w = weight.view(-1, weight.size(-2), weight.size(-1)) @@ -193,6 +195,21 @@ def _get_distance_sum(self, weight, in_idx, out_idx): for k in w: dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) return dist_sum + + Parameters + ---------- + weight: Tensor + convolutional filter weight + out_idx: int + output channel index of specified filter, this method calculates the total distance + between this specified filter and all other filters. 
+ in_idx: int + input channel index of specified filter + + Returns + ------- + float32 + The total distance """ logger.debug('weight size: %s', weight.size()) if len(weight.size()) == 4: # Conv2d From 302f1bd3c9b56550687f09df4b3ece7ef9cf798e Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 13 Nov 2019 19:23:02 +0800 Subject: [PATCH 16/28] tensorflow 2.0 implementation --- examples/model_compress/fpgm_tf_mnist.py | 155 +++++------------- .../compression/tensorflow/builtin_pruners.py | 38 +++-- .../nni/compression/tensorflow/compressor.py | 80 ++++----- .../compression/tensorflow/default_layers.py | 25 ++- 4 files changed, 120 insertions(+), 178 deletions(-) diff --git a/examples/model_compress/fpgm_tf_mnist.py b/examples/model_compress/fpgm_tf_mnist.py index 313dab9926..0482b79fef 100644 --- a/examples/model_compress/fpgm_tf_mnist.py +++ b/examples/model_compress/fpgm_tf_mnist.py @@ -1,130 +1,53 @@ -from nni.compression.tensorflow import FPGMPruner import tensorflow as tf -from tensorflow.examples.tutorials.mnist import input_data - - -def weight_variable(shape): - return tf.Variable(tf.truncated_normal(shape, stddev=0.1)) - - -def bias_variable(shape): - return tf.Variable(tf.constant(0.1, shape=shape)) - - -def conv2d(x_input, w_matrix): - return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME') - - -def max_pool(x_input, pool_size): - size = [1, pool_size, pool_size, 1] - return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME') - - -class Mnist: - def __init__(self): - images = tf.placeholder(tf.float32, [None, 784], name='input_x') - labels = tf.placeholder(tf.float32, [None, 10], name='input_y') - keep_prob = tf.placeholder(tf.float32, name='keep_prob') - - self.images = images - self.labels = labels - self.keep_prob = keep_prob - - self.train_step = None - self.accuracy = None - - self.w1 = None - self.b1 = None - self.fcw1 = None - self.cross = None - with tf.name_scope('reshape'): - x_image = tf.reshape(images, [-1, 28, 28, 1]) - with tf.name_scope('conv1'): - w_conv1 = weight_variable([5, 5, 1, 32]) - self.w1 = w_conv1 - b_conv1 = bias_variable([32]) - self.b1 = b_conv1 - h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) - with tf.name_scope('pool1'): - h_pool1 = max_pool(h_conv1, 2) - with tf.name_scope('conv2'): - w_conv2 = weight_variable([5, 5, 32, 64]) - b_conv2 = bias_variable([64]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) - with tf.name_scope('pool2'): - h_pool2 = max_pool(h_conv2, 2) - with tf.name_scope('fc1'): - w_fc1 = weight_variable([7 * 7 * 64, 1024]) - self.fcw1 = w_fc1 - b_fc1 = bias_variable([1024]) - h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) - with tf.name_scope('dropout'): - h_fc1_drop = tf.nn.dropout(h_fc1, 0.5) - with tf.name_scope('fc2'): - w_fc2 = weight_variable([1024, 10]) - b_fc2 = bias_variable([10]) - y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2 - with tf.name_scope('loss'): - cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv)) - self.cross = cross_entropy - with tf.name_scope('adam_optimizer'): - self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy) - with tf.name_scope('accuracy'): - correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1)) - self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) +from tensorflow import keras +assert tf.__version__ >= "2.0" +import numpy as np +from tensorflow.keras.layers import 
Conv2D, MaxPooling2D, Flatten, Dense, Dropout +from nni.compression.tensorflow import FPGMPruner +def get_data(): + (X_train_full, y_train_full), _ = keras.datasets.mnist.load_data() + X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:] + y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:] + + X_mean = X_train.mean(axis=0, keepdims=True) + X_std = X_train.std(axis=0, keepdims=True) + 1e-7 + X_train = (X_train - X_mean) / X_std + X_valid = (X_valid - X_mean) / X_std + + X_train = X_train[..., np.newaxis] + X_valid = X_valid[..., np.newaxis] + + return X_train, X_valid, y_train, y_valid + +def get_model(): + model = keras.models.Sequential([ + Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), + MaxPooling2D(pool_size=2), + Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"), + MaxPooling2D(pool_size=2), + Flatten(), + Dense(units=128, activation='relu'), + Dropout(0.5), + Dense(units=10, activation='softmax'), + ]) + model.compile(loss="sparse_categorical_crossentropy", + optimizer=keras.optimizers.SGD(lr=1e-3), + metrics=["accuracy"]) + return model def main(): - tf.set_random_seed(0) - - data = input_data.read_data_sets('data', one_hot=True) + X_train, X_valid, y_train, y_valid = get_data() + model = get_model() - model = Mnist() - - """ - You can change this to LevelPruner to implement it - pruner = LevelPruner(configure_list) - """ configure_list = [{ 'pruning_rate': 0.5, 'op_types': ['Conv2D'] }] - pruner = FPGMPruner(tf.get_default_graph(), configure_list) + pruner = FPGMPruner(model, configure_list) pruner.compress() - # if you want to load from yaml file - # configure_file = nni.compressors.tf_compressor._nnimc_tf._tf_default_load_configure_file('configure_example.yaml','AGPruner') - # configure_list = configure_file.get('config',[]) - # pruner.load_configure(configure_list) - # you can also handle it yourself and input an configure list in json - # you can also use compress(model) or compress_default_graph() for tensorflow compressor - # pruner.compress(tf.get_default_graph()) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - for batch_idx in range(2000): - if batch_idx % 10 == 0: - pruner.update_epoch(batch_idx / 10, sess) - batch = data.train.next_batch(2000) - model.train_step.run(feed_dict={ - model.images: batch[0], - model.labels: batch[1], - model.keep_prob: 0.5 - }) - if batch_idx % 10 == 0: - test_acc = model.accuracy.eval(feed_dict={ - model.images: data.test.images, - model.labels: data.test.labels, - model.keep_prob: 1.0 - }) - print('test accuracy', test_acc) - test_acc = model.accuracy.eval(feed_dict={ - model.images: data.test.images, - model.labels: data.test.labels, - model.keep_prob: 1.0 - }) - print('final result is', test_acc) + model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid)) if __name__ == '__main__': diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index ff69ba7347..7a07396973 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -1,4 +1,5 @@ import logging +import numpy as np import tensorflow as tf from .compressor import Pruner @@ -118,6 +119,7 @@ def __init__(self, model, config_list): super().__init__(model, config_list) self.mask_list = {} self.assign_handler = [] + self.epoch_pruned_layers = set() def calc_mask(self, layer, config): """supports 
Conv1d, Conv2d, Conv3d @@ -146,34 +148,38 @@ def calc_mask(self, layer, config): op_name = layer.name assert 0 <= config.get('pruning_rate') < 1 assert op_type in ['Conv1D', 'Conv2D'] + assert op_type in config['op_types'] - if op_type == config['op_type']: + if layer.name in self.epoch_pruned_layers: + assert layer.name in self.mask_list + return self.mask_list.get(layer.name) + + try: weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1])) - masks = tf.Variable(tf.ones_like(weight)) + masks = np.ones(weight.shape) - num_kernels = weight.shape[0].value * weight.shape[1].value + num_kernels = weight.shape[0] * weight.shape[1] num_prune = int(num_kernels * config.get('pruning_rate')) if num_kernels < 2 or num_prune < 1: - self.mask_list.update({op_name: masks}) return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) - tf.scatter_nd_update(masks, min_gm_idx, tf.zeros((min_gm_idx.shape[0].value, weight.shape[-2].value, weight.shape[-1].value))) - masks = tf.transpose(masks, [2, 3, 0, 1]) - self.assign_handler.append(tf.assign(weight, weight*masks)) - self.mask_list.update({op_name: masks}) - else: - masks = tf.Variable(tf.ones_like(weight)) + for idx in min_gm_idx: + masks[tuple(idx)] = 0. + finally: + masks = np.transpose(masks, [2, 3, 0, 1]) + masks = tf.Variable(masks) self.mask_list.update({op_name: masks}) + self.epoch_pruned_layers.add(layer.name) return masks def _get_min_gm_kernel_idx(self, weight, n): assert len(weight.shape) >= 3 - assert weight.shape[0].value * weight.shape[1].value > 2 + assert weight.shape[0] * weight.shape[1] > 2 dist_list, idx_list = [], [] - for in_i in range(weight.shape[0].value): - for out_i in range(weight.shape[1].value): + for in_i in range(weight.shape[0]): + for out_i in range(weight.shape[1]): dist_sum = self._get_distance_sum(weight, in_i, out_i) dist_list.append(dist_sum) idx_list.append([in_i, out_i]) @@ -184,12 +190,12 @@ def _get_min_gm_kernel_idx(self, weight, n): return tf.gather(idx_tensor, idx) def _get_distance_sum(self, weight, in_idx, out_idx): - w = tf.reshape(weight, (-1, weight.shape[-2].value, weight.shape[-1].value)) - anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0].value, 1, 1]) + w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1])) + anchor_w = tf.tile(tf.expand_dims(weight[in_idx, out_idx], 0), [w.shape[0], 1, 1]) x = w - anchor_w x = tf.math.reduce_sum((x*x), (-2, -1)) x = tf.math.sqrt(x) return tf.math.reduce_sum(x) def update_epoch(self, epoch, sess): - sess.run(self.assign_handler) + self.epoch_pruned_layers = set() diff --git a/src/sdk/pynni/nni/compression/tensorflow/compressor.py b/src/sdk/pynni/nni/compression/tensorflow/compressor.py index 6382c25a8a..bae47b77c4 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/compressor.py +++ b/src/sdk/pynni/nni/compression/tensorflow/compressor.py @@ -1,18 +1,20 @@ import logging import tensorflow as tf from . 
import default_layers +tf.config.experimental_run_functions_eagerly(True) _logger = logging.getLogger(__name__) class LayerInfo: - def __init__(self, op, weight, weight_op): - self.op = op - self.name = op.name - self.type = op.type - self.weight = weight - self.weight_op = weight_op - + def __init__(self, keras_layer): + self.keras_layer = keras_layer + self.name = keras_layer.name + self.type = default_layers.get_op_type(type(keras_layer)) + self.weight_index = default_layers.get_weight_index(self.type) + if self.weight_index is not None: + self.weight = keras_layer.weights[self.weight_index] + self._call = None class Compressor: """ @@ -35,25 +37,13 @@ def __init__(self, model, config_list): self.modules_to_compress = [] def compress(self): - """ - Compress the model with algorithm implemented by subclass. - - The model will be instrumented and user should never edit it after calling this method. - `self.modules_to_compress` records all the to-be-compressed layers - """ - for op in self.bound_model.get_operations(): - weight_index = _detect_weight_index(op) - if weight_index is None: - _logger.warning('Failed to detect weight for layer %s', op.name) - return - weight_op = op.inputs[weight_index].op - weight = weight_op.inputs[0] - - layer = LayerInfo(op, weight, weight_op) + for keras_layer in self.bound_model.layers: + layer = LayerInfo(keras_layer) config = self.select_config(layer) if config is not None: self._instrument_layer(layer, config) self.modules_to_compress.append((layer, config)) + return self.bound_model def get_modules_to_compress(self): @@ -74,7 +64,7 @@ def select_config(self, layer): Parameters ---------- - layer : LayerInfo + op : LayerInfo one layer Returns @@ -84,11 +74,12 @@ def select_config(self, layer): not be compressed """ ret = None + if layer.type is None: + return None for config in self.config_list: - op_types = config.get('op_types') - if op_types == 'default': - op_types = default_layers.op_weight_index.keys() - if op_types and layer.type not in op_types: + config = config.copy() + config['op_types'] = self._expand_config_op_types(config) + if layer.type not in config['op_types']: continue if config.get('op_names') and layer.name not in config['op_names']: continue @@ -127,6 +118,18 @@ def _instrument_layer(self, layer, config): """ raise NotImplementedError() + def _expand_config_op_types(self, config): + if config is None: + return [] + op_types = [] + + for op_type in config.get('op_types', []): + if op_type == 'default': + op_types.extend(default_layers.default_layers) + else: + op_types.append(op_type) + return op_types + class Pruner(Compressor): """ @@ -160,10 +163,17 @@ def _instrument_layer(self, layer, config): config : dict the configuration for generating the mask """ - mask = self.calc_mask(layer, config) - new_weight = layer.weight * mask - tf.contrib.graph_editor.swap_outputs(layer.weight_op, new_weight.op) + layer._call = layer.keras_layer.call + def new_call(*inputs): + weights = [x.numpy() for x in layer.keras_layer.weights] + mask = self.calc_mask(layer, config) + weights[layer.weight_index] = weights[layer.weight_index] * mask + layer.keras_layer.set_weights(weights) + ret = layer._call(*inputs) + return ret + + layer.keras_layer.call = new_call class Quantizer(Compressor): """ @@ -182,13 +192,3 @@ def _instrument_layer(self, layer, config): weight = weight_op.inputs[0] new_weight = self.quantize_weight(weight, config, op=layer.op, op_type=layer.type, op_name=layer.name) tf.contrib.graph_editor.swap_outputs(weight_op, new_weight.op) - 
- -def _detect_weight_index(layer): - index = default_layers.op_weight_index.get(layer.type) - if index is not None: - return index - weight_indices = [i for i, op in enumerate(layer.inputs) if op.name.endswith('Variable/read')] - if len(weight_indices) == 1: - return weight_indices[0] - return None diff --git a/src/sdk/pynni/nni/compression/tensorflow/default_layers.py b/src/sdk/pynni/nni/compression/tensorflow/default_layers.py index 0f44ca2987..bdf9627fba 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/default_layers.py +++ b/src/sdk/pynni/nni/compression/tensorflow/default_layers.py @@ -1,8 +1,21 @@ -op_weight_index = { - 'Conv2D': None, - 'Conv3D': None, - 'DepthwiseConv2dNative': None, +from tensorflow import keras - 'Mul': None, - 'MatMul': None, +supported_layers = { + keras.layers.Dense: ('Dense', 0), + keras.layers.Conv1D: ('Conv1D', 0), + keras.layers.Conv2D: ('Conv2D', 0), } + +default_layers = ['Dense', 'Conv2D'] + +def get_op_type(layer_type): + if layer_type in supported_layers: + return supported_layers[layer_type][0] + else: + return None + +def get_weight_index(op_type): + for k in supported_layers: + if supported_layers[k][0] == op_type: + return supported_layers[k][1] + return None From a3e4b903f39d8be264fd60827367b72a152d7234 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 13 Nov 2019 19:55:12 +0800 Subject: [PATCH 17/28] updates --- docs/en_US/Compressor/Pruner.md | 8 ++++---- examples/model_compress/fpgm_tf_mnist.py | 10 ++++++---- examples/model_compress/fpgm_torch_mnist.py | 2 +- .../compression/tensorflow/builtin_pruners.py | 16 ++++++++-------- .../nni/compression/torch/builtin_pruners.py | 17 +++++++++-------- src/sdk/pynni/tests/test_compressor.py | 4 ++-- 6 files changed, 30 insertions(+), 27 deletions(-) diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index 8c056c6fec..42162edc19 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -104,17 +104,17 @@ Tensorflow code ```python from nni.compression.tensorflow import FPGMPruner config_list = [{ - 'pruning_rate': 0.5, + 'sparsity': 0.5, 'op_types': ['Conv2D'] }] -pruner = FPGMPruner(tf.get_default_graph(), config_list) +pruner = FPGMPruner(model, config_list) pruner.compress() ``` PyTorch code ```python from nni.compression.torch import FPGMPruner config_list = [{ - 'pruning_rate': 0.5, + 'sparsity': 0.5, 'op_types': ['Conv2d'] }] pruner = FPGMPruner(model, config_list) @@ -134,6 +134,6 @@ pruner.update_epoch(epoch) You can view example for more information #### User configuration for FPGM Pruner -* **pruning_rate:** How much percentage of convolutional filters are to be pruned. +* **sparsity:** How much percentage of convolutional filters are to be pruned. 
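For intuition, the sketch below shows the kernel-selection criterion behind this configuration in plain PyTorch: for every 2-D kernel of a convolutional weight, sum its Euclidean distances to all other kernels, then zero out the `sparsity` fraction with the smallest sums (those closest to the geometric median, i.e. the most redundant ones). The helper name `fpgm_kernel_mask` is illustrative only and is not part of the NNI API; the built-in pruner computes the same quantity inside `calc_mask`.

```python
import torch

def fpgm_kernel_mask(weight, sparsity):
    """Illustrative FPGM-style mask for a Conv2d weight of shape [out, in, kH, kW]."""
    out_ch, in_ch = weight.size(0), weight.size(1)
    num_kernels = out_ch * in_ch
    num_prune = int(num_kernels * sparsity)
    mask = torch.ones_like(weight)
    if num_kernels < 2 or num_prune < 1:
        return mask

    # Flatten every (out, in) kernel into a vector: [num_kernels, kH * kW]
    flat = weight.reshape(num_kernels, -1)
    # Sum of Euclidean distances from each kernel to all other kernels
    dist_sum = torch.cdist(flat, flat, p=2).sum(dim=1)
    # Kernels with the smallest total distance are the most replaceable -> mask them
    for idx in torch.argsort(dist_sum)[:num_prune]:
        i = idx.item()
        mask[i // in_ch, i % in_ch] = 0.
    return mask

# Example: mask half of the kernels of a 16x8x3x3 convolution
mask = fpgm_kernel_mask(torch.randn(16, 8, 3, 3), sparsity=0.5)
```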
*** diff --git a/examples/model_compress/fpgm_tf_mnist.py b/examples/model_compress/fpgm_tf_mnist.py index 0482b79fef..3cf15ed501 100644 --- a/examples/model_compress/fpgm_tf_mnist.py +++ b/examples/model_compress/fpgm_tf_mnist.py @@ -32,8 +32,8 @@ def get_model(): Dense(units=10, activation='softmax'), ]) model.compile(loss="sparse_categorical_crossentropy", - optimizer=keras.optimizers.SGD(lr=1e-3), - metrics=["accuracy"]) + optimizer=keras.optimizers.SGD(lr=1e-3), + metrics=["accuracy"]) return model def main(): @@ -41,13 +41,15 @@ def main(): model = get_model() configure_list = [{ - 'pruning_rate': 0.5, + 'sparsity': 0.5, 'op_types': ['Conv2D'] }] pruner = FPGMPruner(model, configure_list) pruner.compress() - model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid)) + update_epoch_callback = keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: pruner.update_epoch(epoch)) + + model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid), callbacks=[update_epoch_callback]) if __name__ == '__main__': diff --git a/examples/model_compress/fpgm_torch_mnist.py b/examples/model_compress/fpgm_torch_mnist.py index 30bc581dc2..0d85e5d229 100644 --- a/examples/model_compress/fpgm_torch_mnist.py +++ b/examples/model_compress/fpgm_torch_mnist.py @@ -82,7 +82,7 @@ def main(): pruner = LevelPruner(configure_list) ''' configure_list = [{ - 'pruning_rate': 0.5, + 'sparsity': 0.5, 'op_types': ['Conv2d'] }] diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 7a07396973..fd28ec6aae 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -114,10 +114,10 @@ def __init__(self, model, config_list): the model user wants to compress config_list: list support key for each list item: - - pruning_rate: percentage of convolutional filters to be pruned. + - sparsity: percentage of convolutional filters to be pruned. 
""" super().__init__(model, config_list) - self.mask_list = {} + self.mask_dict = {} self.assign_handler = [] self.epoch_pruned_layers = set() @@ -146,20 +146,20 @@ def calc_mask(self, layer, config): weight = layer.weight op_type = layer.type op_name = layer.name - assert 0 <= config.get('pruning_rate') < 1 + assert 0 <= config.get('sparsity') < 1 assert op_type in ['Conv1D', 'Conv2D'] assert op_type in config['op_types'] if layer.name in self.epoch_pruned_layers: - assert layer.name in self.mask_list - return self.mask_list.get(layer.name) + assert layer.name in self.mask_dict + return self.mask_dict.get(layer.name) try: weight = tf.stop_gradient(tf.transpose(weight, [2, 3, 0, 1])) masks = np.ones(weight.shape) num_kernels = weight.shape[0] * weight.shape[1] - num_prune = int(num_kernels * config.get('pruning_rate')) + num_prune = int(num_kernels * config.get('sparsity')) if num_kernels < 2 or num_prune < 1: return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) @@ -168,7 +168,7 @@ def calc_mask(self, layer, config): finally: masks = np.transpose(masks, [2, 3, 0, 1]) masks = tf.Variable(masks) - self.mask_list.update({op_name: masks}) + self.mask_dict.update({op_name: masks}) self.epoch_pruned_layers.add(layer.name) return masks @@ -197,5 +197,5 @@ def _get_distance_sum(self, weight, in_idx, out_idx): x = tf.math.sqrt(x) return tf.math.reduce_sum(x) - def update_epoch(self, epoch, sess): + def update_epoch(self, epoch): self.epoch_pruned_layers = set() diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 315d1a1d85..2b0a0391a3 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -122,10 +122,11 @@ def __init__(self, model, config_list): the model user wants to compress config_list: list support key for each list item: - - pruning_rate: percentage of convolutional filters to be pruned. + - sparsity: percentage of convolutional filters to be pruned. """ super().__init__(model, config_list) - self.mask_list = {} + self.mask_dict = {} + self.epoch_pruned_layers = set() def calc_mask(self, layer, config): """ @@ -149,26 +150,26 @@ def calc_mask(self, layer, config): the configuration for generating the mask """ weight = layer.module.weight.data - assert 0 <= config.get('pruning_rate') < 1 + assert 0 <= config.get('sparsity') < 1 assert layer.type in ['Conv1d', 'Conv2d'] assert layer.type in config['op_types'] if layer.name in self.epoch_pruned_layers: - assert layer.name in self.mask_list - return self.mask_list.get(layer.name) + assert layer.name in self.mask_dict + return self.mask_dict.get(layer.name) - masks = torch.ones(weight.size()) + masks = torch.ones(weight.size()).type_as(weight) try: num_kernels = weight.size(0) * weight.size(1) - num_prune = int(num_kernels * config.get('pruning_rate')) + num_prune = int(num_kernels * config.get('sparsity')) if num_kernels < 2 or num_prune < 1: return masks min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune) for idx in min_gm_idx: masks[idx] = 0. 
finally: - self.mask_list.update({layer.name: masks}) + self.mask_dict.update({layer.name: masks}) self.epoch_pruned_layers.add(layer.name) return masks diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index a6e4e9d601..9a9eb07c54 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -114,12 +114,12 @@ def test_torch_pruner(self): def test_torch_fpgm_pruner(self): model = TorchMnist() - configure_list = [{'pruning_rate': 0.5, 'op_types': ['Conv2d']}] + configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}] torch_compressor.FPGMPruner(model, configure_list).compress() def test_tf_fpgm_pruner(self): model = TfMnist() - configure_list = [{'pruning_rate': 0.5, 'op_types': ['Conv2D']}] + configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}] tf_compressor.FPGMPruner(tf.get_default_graph(), configure_list).compress() def test_torch_quantizer(self): From 20aedfccbac1f6136ae3eaf704641e71a7e3a9fb Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 13 Nov 2019 19:56:32 +0800 Subject: [PATCH 18/28] updates --- src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index fd28ec6aae..4a95d236cf 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -102,7 +102,7 @@ def update_epoch(self, epoch, sess): class FPGMPruner(Pruner): """A filter pruner via geometric median. - "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", + "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", https://arxiv.org/pdf/1811.00250.pdf """ From 676348b94018cb37139bbecfa836035e00aef6c7 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 13 Nov 2019 20:01:00 +0800 Subject: [PATCH 19/28] updates --- src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 4a95d236cf..026af0d357 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -122,7 +122,7 @@ def __init__(self, model, config_list): self.epoch_pruned_layers = set() def calc_mask(self, layer, config): - """supports Conv1d, Conv2d, Conv3d + """supports Conv1D, Conv2D filter dimensions for Conv1D: LEN: filter length IN: number of input channel From 2b22a1ae5266adcf0bd6475cf2aae0ffb009892b Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Wed, 13 Nov 2019 20:06:27 +0800 Subject: [PATCH 20/28] updates --- src/sdk/pynni/nni/compression/tensorflow/compressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/compressor.py b/src/sdk/pynni/nni/compression/tensorflow/compressor.py index bae47b77c4..36b0119f71 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/compressor.py +++ b/src/sdk/pynni/nni/compression/tensorflow/compressor.py @@ -64,7 +64,7 @@ def select_config(self, layer): Parameters ---------- - op : LayerInfo + layer: LayerInfo one layer Returns From 9e68e2bf7da6c5afa13ad6a446a48f41be171063 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 11:36:11 +0800 Subject: 
[PATCH 21/28] updates --- docs/en_US/Compressor/Pruner.md | 4 +- .../nni/compression/tensorflow/compressor.py | 45 ++++--- src/sdk/pynni/tests/test_compressor.py | 123 ++++++------------ 3 files changed, 65 insertions(+), 107 deletions(-) diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index 42162edc19..37aa9cf437 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -95,7 +95,7 @@ You can view example for more information ## FPGM Pruner FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf) -Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance. +>Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance. ### Usage First, you should import pruner and add mask to model. @@ -121,7 +121,7 @@ pruner = FPGMPruner(model, config_list) pruner.compress() ``` -Second, you should add code below to update epoch number when you finish one epoch in your training code. +Second, you should add code below to update epoch number at beginning of each epoch. Tensorflow code ```python diff --git a/src/sdk/pynni/nni/compression/tensorflow/compressor.py b/src/sdk/pynni/nni/compression/tensorflow/compressor.py index 36b0119f71..c0c9ee9c14 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/compressor.py +++ b/src/sdk/pynni/nni/compression/tensorflow/compressor.py @@ -27,7 +27,7 @@ def __init__(self, model, config_list): Parameters ---------- - model : pytorch model + model : keras model the model user wants to compress config_list : list the configurations that users specify for compression @@ -36,14 +36,31 @@ def __init__(self, model, config_list): self.config_list = config_list self.modules_to_compress = [] + def detect_modules_to_compress(self): + """ + detect all modules should be compressed, and save the result in `self.modules_to_compress`. + + The model will be instrumented and user should never edit it after calling this method. 
+ """ + if self.modules_to_compress is None: + self.modules_to_compress = [] + for keras_layer in self.bound_model.layers: + layer = LayerInfo(keras_layer) + config = self.select_config(layer) + if config is not None: + self.modules_to_compress.append((layer, config)) + return self.modules_to_compress + def compress(self): - for keras_layer in self.bound_model.layers: - layer = LayerInfo(keras_layer) - config = self.select_config(layer) - if config is not None: - self._instrument_layer(layer, config) - self.modules_to_compress.append((layer, config)) + """ + Compress the model with algorithm implemented by subclass. + The model will be instrumented and user should never edit it after calling this method. + `self.modules_to_compress` records all the to-be-compressed layers + """ + modules_to_compress = self.detect_modules_to_compress() + for layer, config in modules_to_compress: + self._instrument_layer(layer, config) return self.bound_model def get_modules_to_compress(self): @@ -88,7 +105,7 @@ def select_config(self, layer): return None return ret - def update_epoch(self, epoch, sess): + def update_epoch(self, epoch): """ If user want to update model every epoch, user can override this method. This method should be called at the beginning of each epoch @@ -99,7 +116,7 @@ def update_epoch(self, epoch, sess): the current epoch number """ - def step(self, sess): + def step(self): """ If user want to update mask every step, user can override this method """ @@ -182,13 +199,3 @@ class Quantizer(Compressor): def quantize_weight(self, weight, config, op, op_type, op_name): raise NotImplementedError("Quantizer must overload quantize_weight()") - - def _instrument_layer(self, layer, config): - weight_index = _detect_weight_index(layer) - if weight_index is None: - _logger.warning('Failed to detect weight for layer %s', layer.name) - return - weight_op = layer.op.inputs[weight_index].op - weight = weight_op.inputs[0] - new_weight = self.quantize_weight(weight, config, op=layer.op, op_type=layer.type, op_name=layer.name) - tf.contrib.graph_editor.swap_outputs(weight_op, new_weight.op) diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 5bec118321..5e10c075d0 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -5,78 +5,21 @@ import nni.compression.tensorflow as tf_compressor import nni.compression.torch as torch_compressor - -def weight_variable(shape): - return tf.Variable(tf.truncated_normal(shape, stddev=0.1)) - - -def bias_variable(shape): - return tf.Variable(tf.constant(0.1, shape=shape)) - - -def conv2d(x_input, w_matrix): - return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME') - - -def max_pool(x_input, pool_size): - size = [1, pool_size, pool_size, 1] - return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME') - - -class TfMnist: - def __init__(self): - images = tf.placeholder(tf.float32, [None, 784], name='input_x') - labels = tf.placeholder(tf.float32, [None, 10], name='input_y') - keep_prob = tf.placeholder(tf.float32, name='keep_prob') - - self.images = images - self.labels = labels - self.keep_prob = keep_prob - - self.train_step = None - self.accuracy = None - - self.w1 = None - self.b1 = None - self.fcw1 = None - self.cross = None - with tf.name_scope('reshape'): - x_image = tf.reshape(images, [-1, 28, 28, 1]) - with tf.name_scope('conv1'): - w_conv1 = weight_variable([5, 5, 1, 32]) - self.w1 = w_conv1 - b_conv1 = bias_variable([32]) - self.b1 = b_conv1 - 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) - with tf.name_scope('pool1'): - h_pool1 = max_pool(h_conv1, 2) - with tf.name_scope('conv2'): - w_conv2 = weight_variable([5, 5, 32, 64]) - b_conv2 = bias_variable([64]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) - with tf.name_scope('pool2'): - h_pool2 = max_pool(h_conv2, 2) - with tf.name_scope('fc1'): - w_fc1 = weight_variable([7 * 7 * 64, 1024]) - self.fcw1 = w_fc1 - b_fc1 = bias_variable([1024]) - h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) - with tf.name_scope('dropout'): - h_fc1_drop = tf.nn.dropout(h_fc1, 0.5) - with tf.name_scope('fc2'): - w_fc2 = weight_variable([1024, 10]) - b_fc2 = bias_variable([10]) - y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2 - with tf.name_scope('loss'): - cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv)) - self.cross = cross_entropy - with tf.name_scope('adam_optimizer'): - self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy) - with tf.name_scope('accuracy'): - correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1)) - self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - +def get_tf_mnist_model(): + model = keras.models.Sequential([ + tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), + tf.keras.layers.MaxPooling2D(pool_size=2), + tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"), + tf.keras.layers.MaxPooling2D(pool_size=2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(units=128, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(units=10, activation='softmax'), + ]) + model.compile(loss="sparse_categorical_crossentropy", + optimizer=keras.optimizers.SGD(lr=1e-3), + metrics=["accuracy"]) + return model class TorchMnist(torch.nn.Module): def __init__(self): @@ -96,17 +39,13 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x, dim=1) +def tf2(func): + def test_tf2_func(self): + if tf.__version__ >= '2.0': + func() + return test_tf20_func class CompressorTestCase(TestCase): - def test_tf_pruner(self): - model = TfMnist() - configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] - tf_compressor.LevelPruner(tf.get_default_graph(), configure_list).compress() - - def test_tf_quantizer(self): - model = TfMnist() - tf_compressor.NaiveQuantizer(tf.get_default_graph(), [{'op_types': ['default']}]).compress() - def test_torch_pruner(self): model = TorchMnist() configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] @@ -117,11 +56,6 @@ def test_torch_fpgm_pruner(self): configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}] torch_compressor.FPGMPruner(model, configure_list).compress() - def test_tf_fpgm_pruner(self): - model = TfMnist() - configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}] - tf_compressor.FPGMPruner(tf.get_default_graph(), configure_list).compress() - def test_torch_quantizer(self): model = TorchMnist() configure_list = [{ @@ -133,6 +67,23 @@ def test_torch_quantizer(self): }] torch_compressor.NaiveQuantizer(model, configure_list).compress() + @tf2 + def test_tf_pruner(self): + model = TfMnist() + configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] + tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress() + + @tf2 + def test_tf_quantizer(self): + model = TfMnist() + 
tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress() + + @tf2 + def test_tf_fpgm_pruner(self): + model = TfMnist() + configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}] + tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress() + if __name__ == '__main__': main() From eadc94175e5cf2e35765a77895de31d125244409 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 11:42:01 +0800 Subject: [PATCH 22/28] updates --- src/sdk/pynni/tests/test_compressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 5e10c075d0..c885e60a9d 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -43,7 +43,7 @@ def tf2(func): def test_tf2_func(self): if tf.__version__ >= '2.0': func() - return test_tf20_func + return test_tf2_func class CompressorTestCase(TestCase): def test_torch_pruner(self): From 2978b7c6961df83c11cc55d0cb770b4e196c162f Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 11:44:37 +0800 Subject: [PATCH 23/28] updates --- src/sdk/pynni/tests/test_compressor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index c885e60a9d..60b07c15aa 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -69,18 +69,15 @@ def test_torch_quantizer(self): @tf2 def test_tf_pruner(self): - model = TfMnist() configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress() @tf2 def test_tf_quantizer(self): - model = TfMnist() tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress() @tf2 def test_tf_fpgm_pruner(self): - model = TfMnist() configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}] tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress() From 03d71dae384eee420e88a8b602a6cf6f881b6a90 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 11:50:12 +0800 Subject: [PATCH 24/28] updates --- src/sdk/pynni/nni/compression/torch/compressor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sdk/pynni/nni/compression/torch/compressor.py b/src/sdk/pynni/nni/compression/torch/compressor.py index bb06524fba..2f1c8da2cc 100644 --- a/src/sdk/pynni/nni/compression/torch/compressor.py +++ b/src/sdk/pynni/nni/compression/torch/compressor.py @@ -91,6 +91,7 @@ def select_config(self, layer): """ ret = None for config in self.config_list: + config = config.copy() config['op_types'] = self._expand_config_op_types(config) if layer.type not in config['op_types']: continue From 22b64751180e4eaca639546e0b007ab038d6ccf3 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 15:40:41 +0800 Subject: [PATCH 25/28] updates --- .../nni/compression/tensorflow/default_layers.py | 11 +++++++++-- src/sdk/pynni/tests/test_compressor.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/default_layers.py b/src/sdk/pynni/nni/compression/tensorflow/default_layers.py index bdf9627fba..306d75a3e3 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/default_layers.py +++ b/src/sdk/pynni/nni/compression/tensorflow/default_layers.py @@ -1,12 +1,19 @@ from tensorflow import keras supported_layers = { - keras.layers.Dense: ('Dense', 0), keras.layers.Conv1D: ('Conv1D', 0), 
keras.layers.Conv2D: ('Conv2D', 0), + keras.layers.Conv2DTranspose: ('Conv2DTranspose', 0), + keras.layers.Conv3D: ('Conv3D', 0), + keras.layers.Conv3DTranspose: ('Conv3DTranspose', 0), + keras.layers.ConvLSTM2D: ('ConvLSTM2D', 0), + keras.layers.Dense: ('Dense', 0), + keras.layers.Embedding: ('Embedding', 0), + keras.layers.GRU: ('GRU', 0), + keras.layers.LSTM: ('LSTM', 0), } -default_layers = ['Dense', 'Conv2D'] +default_layers = [x[0] for x in supported_layers.values()] def get_op_type(layer_type): if layer_type in supported_layers: diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 60b07c15aa..6332fdbc97 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -6,7 +6,7 @@ import nni.compression.torch as torch_compressor def get_tf_mnist_model(): - model = keras.models.Sequential([ + model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), tf.keras.layers.MaxPooling2D(pool_size=2), tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"), From 053e3d1f4127d54405adb1dd280fb51f64402156 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 15:54:48 +0800 Subject: [PATCH 26/28] updates --- src/sdk/pynni/tests/test_compressor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 6332fdbc97..f40bd9485b 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -2,9 +2,11 @@ import tensorflow as tf import torch import torch.nn.functional as F -import nni.compression.tensorflow as tf_compressor import nni.compression.torch as torch_compressor +if tf.__version__ >= '2.0': + import nni.compression.tensorflow as tf_compressor + def get_tf_mnist_model(): model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), @@ -17,7 +19,7 @@ def get_tf_mnist_model(): tf.keras.layers.Dense(units=10, activation='softmax'), ]) model.compile(loss="sparse_categorical_crossentropy", - optimizer=keras.optimizers.SGD(lr=1e-3), + optimizer=tf.keras.optimizers.SGD(lr=1e-3), metrics=["accuracy"]) return model From 08f523778c8422d65a4c147de3fd708ead810fa8 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 18:14:35 +0800 Subject: [PATCH 27/28] updates --- docs/en_US/Compressor/Pruner.md | 1 + .../pynni/nni/compression/tensorflow/builtin_pruners.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index 37aa9cf437..4fcaa9d455 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -120,6 +120,7 @@ config_list = [{ pruner = FPGMPruner(model, config_list) pruner.compress() ``` +Note: FPGM Pruner is used to prune convolutional layers within deep neural network, therefore the `op_types` field supports only convolutional layers. Second, you should add code below to update epoch number at beginning of each epoch. 
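With the Keras-based implementation, `update_epoch` no longer needs a TensorFlow session, so the per-epoch update can be attached through a Keras callback, as the `fpgm_tf_mnist.py` example does. Below is a minimal, self-contained sketch of that pattern; the small `Sequential` model and the random tensors are placeholders standing in for a real network and dataset.

```python
import numpy as np
from tensorflow import keras
from nni.compression.tensorflow import FPGMPruner

# Small Conv2D model, mirroring the MNIST example in this patch series
model = keras.models.Sequential([
    keras.layers.Conv2D(32, 7, activation='relu', padding='SAME', input_shape=[28, 28, 1]),
    keras.layers.MaxPooling2D(2),
    keras.layers.Conv2D(64, 3, activation='relu', padding='SAME'),
    keras.layers.MaxPooling2D(2),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}]
pruner = FPGMPruner(model, config_list)
pruner.compress()

# Reset the pruner's per-epoch state at the beginning of every epoch
update_epoch_callback = keras.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: pruner.update_epoch(epoch))

# Dummy data just to make the sketch executable; replace with a real dataset
x_train = np.random.rand(256, 28, 28, 1).astype('float32')
y_train = np.random.randint(0, 10, size=(256,))
model.fit(x_train, y_train, epochs=2, callbacks=[update_epoch_callback])
```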
diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index 026af0d357..b43195c945 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -101,7 +101,8 @@ def update_epoch(self, epoch, sess): self.if_init_list[k] = True class FPGMPruner(Pruner): - """A filter pruner via geometric median. + """ + A filter pruner via geometric median. "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", https://arxiv.org/pdf/1811.00250.pdf """ @@ -122,7 +123,8 @@ def __init__(self, model, config_list): self.epoch_pruned_layers = set() def calc_mask(self, layer, config): - """supports Conv1D, Conv2D + """ + Supports Conv1D, Conv2D filter dimensions for Conv1D: LEN: filter length IN: number of input channel @@ -140,7 +142,6 @@ def calc_mask(self, layer, config): calculate mask for `layer`'s weight config : dict the configuration for generating the mask - """ weight = layer.weight From ec8bb4e85df13127f9286fe59a32150728dd69b0 Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Thu, 14 Nov 2019 18:20:56 +0800 Subject: [PATCH 28/28] updates --- docs/en_US/Compressor/Pruner.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index 4fcaa9d455..5e06c02cd4 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -120,7 +120,7 @@ config_list = [{ pruner = FPGMPruner(model, config_list) pruner.compress() ``` -Note: FPGM Pruner is used to prune convolutional layers within deep neural network, therefore the `op_types` field supports only convolutional layers. +Note: FPGM Pruner is used to prune convolutional layers within deep neural networks, therefore the `op_types` field supports only convolutional layers. Second, you should add code below to update epoch number at beginning of each epoch.
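For completeness, the PyTorch side of the same workflow after this series of changes looks like the sketch below: a `sparsity`/`op_types` config, `FPGMPruner(model, config_list)`, `compress()`, and `update_epoch(epoch)` at the beginning of every epoch. The small `nn.Sequential` network is only a stand-in; any `torch.nn.Module` containing `Conv2d` layers is handled the same way, and the actual training and evaluation loop is omitted.

```python
import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner

# Stand-in network with two Conv2d layers (the layers FPGM Pruner operates on)
model = nn.Sequential(
    nn.Conv2d(1, 20, 5), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(20, 50, 5), nn.ReLU(), nn.MaxPool2d(2),
)

# Prune 50% of the Conv2d kernels using the geometric-median criterion
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
pruner = FPGMPruner(model, config_list)
pruner.compress()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
    pruner.update_epoch(epoch)  # reset the per-epoch pruning state before training
    # ... run one epoch of training and evaluation here ...
```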