diff --git a/docs/en_US/Compressor/L1FilterPruner.md b/docs/en_US/Compressor/L1FilterPruner.md new file mode 100644 index 0000000000..2906fde271 --- /dev/null +++ b/docs/en_US/Compressor/L1FilterPruner.md @@ -0,0 +1,54 @@ +L1FilterPruner on NNI Compressor +=== + +## 1. Introduction + +L1FilterPruner is a general structured pruning algorithm for pruning filters in the convolutional layers. + +In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf. + +![](../../img/l1filter_pruner.png) + +> L1Filter Pruner prunes filters in the **convolution layers** +> +> The procedure of pruning m filters from the ith convolutional layer is as follows: +> +> 1. For each filter ![](http://latex.codecogs.com/gif.latex?F_{i,j}), calculate the sum of its absolute kernel weights![](http://latex.codecogs.com/gif.latex?s_j=\sum_{l=1}^{n_i}\sum|K_l|) +> 2. Sort the filters by ![](http://latex.codecogs.com/gif.latex?s_j). +> 3. Prune ![](http://latex.codecogs.com/gif.latex?m) filters with the smallest sum values and their corresponding feature maps. The +> kernels in the next convolutional layer corresponding to the pruned feature maps are also +> removed. +> 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel +> weights are copied to the new model. + +## 2. Usage + +PyTorch code + +``` +from nni.compression.torch import L1FilterPruner +config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'], 'op_names': ['conv1', 'conv2'] }] +pruner = L1FilterPruner(model, config_list) +pruner.compress() +``` + +#### User configuration for L1Filter Pruner + +- **sparsity:** This is to specify the sparsity operations to be compressed to +- **op_types:** Only Conv2d is supported in L1Filter Pruner + +## 3. Experiment + +We implemented one of the experiments in ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), we pruned **VGG-16** for CIFAR-10 to **VGG-16-pruned-A** in the paper, in which $64\%$ parameters are pruned. Our experiments results are as follows: + +| Model | Error(paper/ours) | Parameters | Pruned | +| --------------- | ----------------- | --------------- | -------- | +| VGG-16 | 6.75/6.49 | 1.5x10^7 | | +| VGG-16-pruned-A | 6.60/6.47 | 5.4x10^6 | 64.0% | + +The experiments code can be found at [examples/model_compress]( https://github.com/microsoft/nni/tree/master/examples/model_compress/) + + + + + diff --git a/docs/en_US/Compressor/Overview.md b/docs/en_US/Compressor/Overview.md index f992117ffa..279cbe6f89 100644 --- a/docs/en_US/Compressor/Overview.md +++ b/docs/en_US/Compressor/Overview.md @@ -12,6 +12,8 @@ We have provided two naive compression algorithms and three popular ones for use |---|---| | [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights | | [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)| +| [L1Filter Pruner](./Pruner.md#l1filter-pruner) | Pruning least important filters in convolution layers(PRUNING FILTERS FOR EFFICIENT CONVNETS)[Reference Paper](https://arxiv.org/abs/1608.08710) | +| [Slim Pruner](./Pruner.md#slim-pruner) | Pruning channels in convolution layers by pruning scaling factors in BN layers(Learning Efficient Convolutional Networks through Network Slimming)[Reference Paper](https://arxiv.org/abs/1708.06519) | | [Lottery Ticket Pruner](./Pruner.md#agp-pruner) | The pruning process used by "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks". It prunes a model iteratively. [Reference Paper](https://arxiv.org/abs/1803.03635)| | [FPGM Pruner](./Pruner.md#fpgm-pruner) | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration [Reference Paper](https://arxiv.org/pdf/1811.00250.pdf)| | [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits | diff --git a/docs/en_US/Compressor/Pruner.md b/docs/en_US/Compressor/Pruner.md index ce50d579ed..298ade1d1f 100644 --- a/docs/en_US/Compressor/Pruner.md +++ b/docs/en_US/Compressor/Pruner.md @@ -3,7 +3,7 @@ Pruner on NNI Compressor ## Level Pruner -This is one basic pruner: you can set a target sparsity level (expressed as a fraction, 0.6 means we will prune 60%). +This is one basic one-shot pruner: you can set a target sparsity level (expressed as a fraction, 0.6 means we will prune 60%). We first sort the weights in the specified layer by their absolute values. And then mask to zero the smallest magnitude weights until the desired sparsity level is reached. @@ -31,7 +31,7 @@ pruner.compress() *** ## AGP Pruner -In [To prune, or not to prune: exploring the efficacy of pruning for model compression](https://arxiv.org/abs/1710.01878), authors Michael Zhu and Suyog Gupta provide an algorithm to prune the weight gradually. +This is an iterative pruner, In [To prune, or not to prune: exploring the efficacy of pruning for model compression](https://arxiv.org/abs/1710.01878), authors Michael Zhu and Suyog Gupta provide an algorithm to prune the weight gradually. >We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value si (usually 0) to a final sparsity value sf over a span of n pruning steps, starting at training step t0 and with pruning frequency ∆t: ![](../../img/agp_pruner.png) @@ -65,7 +65,7 @@ config_list = [{ 'start_epoch': 0, 'end_epoch': 10, 'frequency': 1, - 'op_types': 'default' + 'op_types': ['default'] }] pruner = AGP_Pruner(model, config_list) pruner.compress() @@ -134,7 +134,7 @@ The above configuration means that there are 5 times of iterative pruning. As th *** ## FPGM Pruner -FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf) +This is an one-shot pruner, FPGM Pruner is an implementation of paper [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/pdf/1811.00250.pdf) >Previous works utilized “smaller-norm-less-important” criterion to prune filters with smaller norm values in a convolutional neural network. In this paper, we analyze this norm-based criterion and point out that its effectiveness depends on two requirements that are not always met: (1) the norm deviation of the filters should be large; (2) the minimum norm of the filters should be small. To solve this problem, we propose a novel filter pruning method, namely Filter Pruning via Geometric Median (FPGM), to compress the model regardless of those two requirements. Unlike previous methods, FPGM compresses CNN models by pruning filters with redundancy, rather than those with “relatively less” importance. @@ -179,3 +179,57 @@ You can view example for more information * **sparsity:** How much percentage of convolutional filters are to be pruned. *** + +## L1Filter Pruner + +This is an one-shot pruner, In ['PRUNING FILTERS FOR EFFICIENT CONVNETS'](https://arxiv.org/abs/1608.08710), authors Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf. + +![](../../img/l1filter_pruner.png) + +> L1Filter Pruner prunes filters in the **convolution layers** +> +> The procedure of pruning m filters from the ith convolutional layer is as follows: +> +> 1. For each filter ![](http://latex.codecogs.com/gif.latex?F_{i,j}), calculate the sum of its absolute kernel weights![](http://latex.codecogs.com/gif.latex?s_j=\sum_{l=1}^{n_i}\sum|K_l|) +> 2. Sort the filters by ![](http://latex.codecogs.com/gif.latex?s_j). +> 3. Prune ![](http://latex.codecogs.com/gif.latex?m) filters with the smallest sum values and their corresponding feature maps. The +> kernels in the next convolutional layer corresponding to the pruned feature maps are also +> removed. +> 4. A new kernel matrix is created for both the ![](http://latex.codecogs.com/gif.latex?i)th and ![](http://latex.codecogs.com/gif.latex?i+1)th layers, and the remaining kernel +> weights are copied to the new model. + +``` +from nni.compression.torch import L1FilterPruner +config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }] +pruner = L1FilterPruner(model, config_list) +pruner.compress() +``` + +#### User configuration for L1Filter Pruner + +- **sparsity:** This is to specify the sparsity operations to be compressed to +- **op_types:** Only Conv2d is supported in L1Filter Pruner + +## Slim Pruner + +This is an one-shot pruner, In ['Learning Efficient Convolutional Networks through Network Slimming'](https://arxiv.org/pdf/1708.06519.pdf), authors Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang. + +![](../../img/slim_pruner.png) + +> Slim Pruner **prunes channels in the convolution layers by masking corresponding scaling factors in the later BN layers**, L1 regularization on the scaling factors should be applied in batch normalization (BN) layers while training, scaling factors of BN layers are **globally ranked** while pruning, so the sparse model can be automatically found given sparsity. + +### Usage + +PyTorch code + +``` +from nni.compression.torch import SlimPruner +config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }] +pruner = SlimPruner(model, config_list) +pruner.compress() +``` + +#### User configuration for Slim Pruner + +- **sparsity:** This is to specify the sparsity operations to be compressed to +- **op_types:** Only BatchNorm2d is supported in Slim Pruner diff --git a/docs/en_US/Compressor/SlimPruner.md b/docs/en_US/Compressor/SlimPruner.md new file mode 100644 index 0000000000..e15112711a --- /dev/null +++ b/docs/en_US/Compressor/SlimPruner.md @@ -0,0 +1,39 @@ +SlimPruner on NNI Compressor +=== + +## 1. Slim Pruner + +SlimPruner is a structured pruning algorithm for pruning channels in the convolutional layers by pruning corresponding scaling factors in the later BN layers. + +In ['Learning Efficient Convolutional Networks through Network Slimming'](https://arxiv.org/pdf/1708.06519.pdf), authors Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang. + +![](../../img/slim_pruner.png) + +> Slim Pruner **prunes channels in the convolution layers by masking corresponding scaling factors in the later BN layers**, L1 regularization on the scaling factors should be applied in batch normalization (BN) layers while training, scaling factors of BN layers are **globally ranked** while pruning, so the sparse model can be automatically found given sparsity. + +## 2. Usage + +PyTorch code + +``` +from nni.compression.torch import SlimPruner +config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }] +pruner = SlimPruner(model, config_list) +pruner.compress() +``` + +#### User configuration for Filter Pruner + +- **sparsity:** This is to specify the sparsity operations to be compressed to +- **op_types:** Only BatchNorm2d is supported in Slim Pruner + +## 3. Experiment + +We implemented one of the experiments in ['Learning Efficient Convolutional Networks through Network Slimming'](https://arxiv.org/pdf/1708.06519.pdf), we pruned $70\%$ channels in the **VGGNet** for CIFAR-10 in the paper, in which $88.5\%$ parameters are pruned. Our experiments results are as follows: + +| Model | Error(paper/ours) | Parameters | Pruned | +| ------------- | ----------------- | ---------- | --------- | +| VGGNet | 6.34/6.40 | 20.04M | | +| Pruned-VGGNet | 6.20/6.39 | 2.03M | 88.5% | + +The experiments code can be found at [examples/model_compress]( https://github.com/microsoft/nni/tree/master/examples/model_compress/) diff --git a/docs/img/l1filter_pruner.PNG b/docs/img/l1filter_pruner.PNG new file mode 100644 index 0000000000..a4d6c498ed Binary files /dev/null and b/docs/img/l1filter_pruner.PNG differ diff --git a/docs/img/slim_pruner.PNG b/docs/img/slim_pruner.PNG new file mode 100644 index 0000000000..e7fe52f67a Binary files /dev/null and b/docs/img/slim_pruner.PNG differ diff --git a/examples/model_compress/L1_filter_pruner_torch_vgg16.py b/examples/model_compress/L1_filter_pruner_torch_vgg16.py new file mode 100644 index 0000000000..c54fc12119 --- /dev/null +++ b/examples/model_compress/L1_filter_pruner_torch_vgg16.py @@ -0,0 +1,173 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +from nni.compression.torch import L1FilterPruner + + +class vgg(nn.Module): + def __init__(self, init_weights=True): + super(vgg, self).__init__() + cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512] + self.cfg = cfg + self.feature = self.make_layers(cfg, True) + num_classes = 10 + self.classifier = nn.Sequential( + nn.Linear(cfg[-1], 512), + nn.BatchNorm1d(512), + nn.ReLU(inplace=True), + nn.Linear(512, num_classes) + ) + if init_weights: + self._initialize_weights() + + def make_layers(self, cfg, batch_norm=True): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + def forward(self, x): + x = self.feature(x) + x = nn.AvgPool2d(2)(x) + x = x.view(x.size(0), -1) + y = self.classifier(x) + return y + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(0.5) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def train(model, device, train_loader, optimizer): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + acc = 100 * correct / len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, acc)) + return acc + + +def main(): + torch.manual_seed(0) + device = torch.device('cuda') + train_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=True, download=True, + transform=transforms.Compose([ + transforms.Pad(4), + transforms.RandomCrop(32), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=200, shuffle=False) + + model = vgg() + model.to(device) + + # Train the base VGG-16 model + print('=' * 10 + 'Train the unpruned base model' + '=' * 10) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0) + for epoch in range(160): + train(model, device, train_loader, optimizer) + test(model, device, test_loader) + lr_scheduler.step(epoch) + torch.save(model.state_dict(), 'vgg16_cifar10.pth') + + # Test base model accuracy + print('=' * 10 + 'Test on the original model' + '=' * 10) + model.load_state_dict(torch.load('vgg16_cifar10.pth')) + test(model, device, test_loader) + # top1 = 93.51% + + # Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS', + # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A' + configure_list = [{ + 'sparsity': 0.5, + 'op_types': ['default'], + 'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] + }] + + # Prune model and test accuracy without fine tuning. + print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10) + pruner = L1FilterPruner(model, configure_list) + model = pruner.compress() + test(model, device, test_loader) + # top1 = 88.19% + + # Fine tune the pruned model for 40 epochs and test accuracy + print('=' * 10 + 'Fine tuning' + '=' * 10) + optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) + best_top1 = 0 + for epoch in range(40): + pruner.update_epoch(epoch) + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer_finetune) + top1 = test(model, device, test_loader) + if top1 > best_top1: + best_top1 = top1 + # Export the best model, 'model_path' stores state_dict of the pruned model, + # mask_path stores mask_dict of the pruned model + pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth') + + # Test the exported model + print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10) + new_model = vgg() + new_model.to(device) + new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth')) + test(new_model, device, test_loader) + # top1 = 93.53% + + +if __name__ == '__main__': + main() diff --git a/examples/model_compress/slim_pruner_torch_vgg19.py b/examples/model_compress/slim_pruner_torch_vgg19.py new file mode 100644 index 0000000000..a227e7a731 --- /dev/null +++ b/examples/model_compress/slim_pruner_torch_vgg19.py @@ -0,0 +1,176 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +from nni.compression.torch import SlimPruner + + +class vgg(nn.Module): + def __init__(self, init_weights=True): + super(vgg, self).__init__() + cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512] + self.feature = self.make_layers(cfg, True) + num_classes = 10 + self.classifier = nn.Linear(cfg[-1], num_classes) + if init_weights: + self._initialize_weights() + + def make_layers(self, cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + def forward(self, x): + x = self.feature(x) + x = nn.AvgPool2d(2)(x) + x = x.view(x.size(0), -1) + y = self.classifier(x) + return y + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(0.5) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def updateBN(model): + for m in model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.weight.grad.data.add_(0.0001 * torch.sign(m.weight.data)) # L1 + + +def train(model, device, train_loader, optimizer, sparse_bn=False): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + # L1 regularization on BN layer + if sparse_bn: + updateBN(model) + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + acc = 100 * correct / len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, acc)) + return acc + + +def main(): + torch.manual_seed(0) + device = torch.device('cuda') + train_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=True, download=True, + transform=transforms.Compose([ + transforms.Pad(4), + transforms.RandomCrop(32), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])), + batch_size=200, shuffle=False) + + model = vgg() + model.to(device) + + # Train the base VGG-19 model + print('=' * 10 + 'Train the unpruned base model' + '=' * 10) + epochs = 160 + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) + for epoch in range(epochs): + if epoch in [epochs * 0.5, epochs * 0.75]: + for param_group in optimizer.param_groups: + param_group['lr'] *= 0.1 + train(model, device, train_loader, optimizer, True) + test(model, device, test_loader) + torch.save(model.state_dict(), 'vgg19_cifar10.pth') + + # Test base model accuracy + print('=' * 10 + 'Test the original model' + '=' * 10) + model.load_state_dict(torch.load('vgg19_cifar10.pth')) + test(model, device, test_loader) + # top1 = 93.60% + + # Pruning Configuration, in paper 'Learning efficient convolutional networks through network slimming', + configure_list = [{ + 'sparsity': 0.7, + 'op_types': ['BatchNorm2d'], + }] + + # Prune model and test accuracy without fine tuning. + print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10) + pruner = SlimPruner(model, configure_list) + model = pruner.compress() + test(model, device, test_loader) + # top1 = 93.55% + + # Fine tune the pruned model for 40 epochs and test accuracy + print('=' * 10 + 'Fine tuning' + '=' * 10) + optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) + best_top1 = 0 + for epoch in range(40): + pruner.update_epoch(epoch) + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer_finetune) + top1 = test(model, device, test_loader) + if top1 > best_top1: + best_top1 = top1 + # Export the best model, 'model_path' stores state_dict of the pruned model, + # mask_path stores mask_dict of the pruned model + pruner.export_model(model_path='pruned_vgg19_cifar10.pth', mask_path='mask_vgg19_cifar10.pth') + + # Test the exported model + print('=' * 10 + 'Test the export pruned model after fine tune' + '=' * 10) + new_model = vgg() + new_model.to(device) + new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth')) + test(new_model, device, test_loader) + # top1 = 93.61% + + +if __name__ == '__main__': + main() diff --git a/src/nni_manager/common/log.ts b/src/nni_manager/common/log.ts index e2ca62f9c6..275fb76ffe 100644 --- a/src/nni_manager/common/log.ts +++ b/src/nni_manager/common/log.ts @@ -155,11 +155,7 @@ class Logger { } } -function getLogger(fileName?: string): Logger { - component.Container.bind(Logger).provider({ - get: (): Logger => new Logger(fileName) - }); - +function getLogger(): Logger { return component.get(Logger); } diff --git a/src/nni_manager/main.ts b/src/nni_manager/main.ts index fec5a8819e..758694be32 100644 --- a/src/nni_manager/main.ts +++ b/src/nni_manager/main.ts @@ -49,7 +49,7 @@ function initStartupInfo( setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly); } -async function initContainer(platformMode: string): Promise { +async function initContainer(platformMode: string, logFileName?: string): Promise { if (platformMode === 'local') { Container.bind(TrainingService) .to(LocalTrainingService) @@ -82,6 +82,9 @@ async function initContainer(platformMode: string): Promise { Container.bind(DataStore) .to(NNIDataStore) .scope(Scope.Singleton); + Container.bind(Logger).provider({ + get: (): Logger => new Logger(logFileName) + }); const ds: DataStore = component.get(DataStore); await ds.init(); @@ -145,13 +148,14 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly); mkDirP(getLogDir()) .then(async () => { - const log: Logger = getLogger(); try { await initContainer(mode); const restServer: NNIRestServer = component.get(NNIRestServer); await restServer.start(); + const log: Logger = getLogger(); log.info(`Rest server listening on: ${restServer.endPoint}`); } catch (err) { + const log: Logger = getLogger(); log.error(`${err.stack}`); throw err; } diff --git a/src/nni_manager/training_service/common/gpuData.ts b/src/nni_manager/training_service/common/gpuData.ts index fd09f8212f..68968cb8f7 100644 --- a/src/nni_manager/training_service/common/gpuData.ts +++ b/src/nni_manager/training_service/common/gpuData.ts @@ -59,14 +59,6 @@ export class GPUSummary { } } -export const GPU_INFO_COLLECTOR_FORMAT_LINUX: string = -` -#!/bin/bash -export METRIC_OUTPUT_DIR={0} -echo $$ >{1} -python3 -m nni_gpu_tool.gpu_metrics_collector -`; - export const GPU_INFO_COLLECTOR_FORMAT_WINDOWS: string = ` $env:METRIC_OUTPUT_DIR="{0}" diff --git a/src/nni_manager/training_service/common/util.ts b/src/nni_manager/training_service/common/util.ts index 294728ee6d..0deb58e1ad 100644 --- a/src/nni_manager/training_service/common/util.ts +++ b/src/nni_manager/training_service/common/util.ts @@ -27,7 +27,7 @@ import * as path from 'path'; import { String } from 'typescript-string-operations'; import { countFilesRecursively, getNewLine, validateFileNameRecursively } from '../../common/utils'; import { file } from '../../node_modules/@types/tmp'; -import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPU_INFO_COLLECTOR_FORMAT_WINDOWS } from './gpuData'; +import { GPU_INFO_COLLECTOR_FORMAT_WINDOWS } from './gpuData'; /** * Validate codeDir, calculate file count recursively under codeDir, and throw error if any rule is broken @@ -219,22 +219,16 @@ export function getScriptName(fileNamePrefix: string): string { } } -/** - * generate script file - * @param gpuMetricCollectorScriptFolder - */ -export function getgpuMetricsCollectorScriptContent(gpuMetricCollectorScriptFolder: string): string { +export function getGpuMetricsCollectorBashScriptContent(scriptFolder: string): string { + return `echo $$ > ${scriptFolder}/pid ; METRIC_OUTPUT_DIR=${scriptFolder} python3 -m nni_gpu_tool.gpu_metrics_collector`; +} + +export function runGpuMetricsCollector(scriptFolder: string): void { if (process.platform === 'win32') { - return String.Format( - GPU_INFO_COLLECTOR_FORMAT_WINDOWS, - gpuMetricCollectorScriptFolder, - path.join(gpuMetricCollectorScriptFolder, 'pid') - ); + const scriptPath = path.join(scriptFolder, 'gpu_metrics_collector.ps1'); + const content = String.Format(GPU_INFO_COLLECTOR_FORMAT_WINDOWS, scriptFolder, path.join(scriptFolder, 'pid')); + fs.writeFile(scriptPath, content, { encoding: 'utf8' }, () => { runScript(scriptPath); }); } else { - return String.Format( - GPU_INFO_COLLECTOR_FORMAT_LINUX, - gpuMetricCollectorScriptFolder, - path.join(gpuMetricCollectorScriptFolder, 'pid') - ); + cp.exec(getGpuMetricsCollectorBashScriptContent(scriptFolder), { shell: '/bin/bash' }); } } diff --git a/src/nni_manager/training_service/local/gpuScheduler.ts b/src/nni_manager/training_service/local/gpuScheduler.ts index bf05220da0..87fbacd1b9 100644 --- a/src/nni_manager/training_service/local/gpuScheduler.ts +++ b/src/nni_manager/training_service/local/gpuScheduler.ts @@ -28,7 +28,7 @@ import { String } from 'typescript-string-operations'; import { getLogger, Logger } from '../../common/log'; import { delay } from '../../common/utils'; import { GPUInfo, GPUSummary } from '../common/gpuData'; -import { execKill, execMkdir, execRemove, execTail, getgpuMetricsCollectorScriptContent, getScriptName, runScript } from '../common/util'; +import { execKill, execMkdir, execRemove, execTail, runGpuMetricsCollector } from '../common/util'; /** * GPUScheduler for local training service @@ -43,7 +43,7 @@ class GPUScheduler { constructor() { this.stopping = false; this.log = getLogger(); - this.gpuMetricCollectorScriptFolder = `${os.tmpdir()}/nni/script`; + this.gpuMetricCollectorScriptFolder = `${os.tmpdir()}/${os.userInfo().username}/nni/script`; } public async run(): Promise { @@ -101,12 +101,7 @@ class GPUScheduler { */ private async runGpuMetricsCollectorScript(): Promise { await execMkdir(this.gpuMetricCollectorScriptFolder, true); - //generate gpu_metrics_collector script - const gpuMetricsCollectorScriptPath: string = - path.join(this.gpuMetricCollectorScriptFolder, getScriptName('gpu_metrics_collector')); - const gpuMetricsCollectorScriptContent: string = getgpuMetricsCollectorScriptContent(this.gpuMetricCollectorScriptFolder); - await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' }); - runScript(gpuMetricsCollectorScriptPath); + runGpuMetricsCollector(this.gpuMetricCollectorScriptFolder); } // tslint:disable:non-literal-fs-path diff --git a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts index 4733df6809..11fc85f829 100644 --- a/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts +++ b/src/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts @@ -42,10 +42,10 @@ import { getVersion, uniqueString, unixPathJoin } from '../../common/utils'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; -import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPUSummary } from '../common/gpuData'; +import { GPUSummary } from '../common/gpuData'; import { TrialConfig } from '../common/trialConfig'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; -import { execCopydir, execMkdir, execRemove, validateCodeDir } from '../common/util'; +import { execCopydir, execMkdir, execRemove, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util'; import { GPUScheduler } from './gpuScheduler'; import { HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta, @@ -334,8 +334,6 @@ class RemoteMachineTrainingService implements TrainingService { break; case TrialConfigMetadataKey.MACHINE_LIST: await this.setupConnections(value); - //remove local temp files - await execRemove(this.getLocalGpuMetricCollectorDir()); break; case TrialConfigMetadataKey.TRIAL_CONFIG: const remoteMachineTrailConfig: TrialConfig = JSON.parse(value); @@ -428,34 +426,6 @@ class RemoteMachineTrainingService implements TrainingService { return Promise.resolve(); } - /** - * Generate gpu metric collector directory to store temp gpu metric collector script files - */ - private getLocalGpuMetricCollectorDir(): string { - const userName: string = path.basename(os.homedir()); //get current user name of os - - return path.join(os.tmpdir(), userName, 'nni', 'scripts'); - } - - /** - * Generate gpu metric collector shell script in local machine, - * used to run in remote machine, and will be deleted after uploaded from local. - */ - private async generateGpuMetricsCollectorScript(userName: string): Promise { - const gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir(); - await execMkdir(path.join(gpuMetricCollectorScriptFolder, userName)); - //generate gpu_metrics_collector.sh - const gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh'); - // This directory is used to store gpu_metrics and pid created by script - const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName); - const gpuMetricsCollectorScriptContent: string = String.Format( - GPU_INFO_COLLECTOR_FORMAT_LINUX, - remoteGPUScriptsDir, - unixPathJoin(remoteGPUScriptsDir, 'pid') - ); - await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' }); - } - private async setupConnections(machineList: string): Promise { this.log.debug(`Connecting to remote machines: ${machineList}`); const deferred: Deferred = new Deferred(); @@ -479,24 +449,18 @@ class RemoteMachineTrainingService implements TrainingService { private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise { // Create root working directory after ssh connection is ready - // generate gpu script in local machine first, will copy to remote machine later - await this.generateGpuMetricsCollectorScript(rmMeta.username); const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni'); await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn); - // Copy NNI scripts to remote expeirment working directory - const localGpuScriptCollectorDir: string = this.getLocalGpuMetricCollectorDir(); // the directory to store temp scripts in remote machine const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username); - await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteGpuScriptCollectorDir}`, conn); + await SSHClientUtility.remoteExeCommand(`(umask 0 ; mkdir -p ${remoteGpuScriptCollectorDir})`, conn); await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn); - //copy gpu_metrics_collector.sh to remote - await SSHClientUtility.copyFileToRemote(path.join(localGpuScriptCollectorDir, rmMeta.username, 'gpu_metrics_collector.sh'), - unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh'), conn); //Begin to execute gpu_metrics_collection scripts // tslint:disable-next-line: no-floating-promises - SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh')}`, conn); + const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir); + SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn); const disposable: Rx.IDisposable = this.timer.subscribe( async (tick: number) => { diff --git a/src/sdk/pynni/nni/common.py b/src/sdk/pynni/nni/common.py index f57c458b1d..1388b4c023 100644 --- a/src/sdk/pynni/nni/common.py +++ b/src/sdk/pynni/nni/common.py @@ -68,6 +68,27 @@ def init_logger(logger_file_path, log_level_name='info'): sys.stdout = _LoggerFileWrapper(logger_file) +def init_standalone_logger(): + """ + Initialize root logger for standalone mode. + This will set NNI's log level to INFO and print its log to stdout. + """ + fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s' + formatter = logging.Formatter(fmt, _time_format) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + nni_logger = logging.getLogger('nni') + nni_logger.addHandler(handler) + nni_logger.setLevel(logging.INFO) + nni_logger.propagate = False + + # Following line does not affect NNI loggers, but without this user's logger won't be able to + # print log even it's level is set to INFO, so we do it for user's convenience. + # If this causes any issue in future, remove it and use `logging.info` instead of + # `logging.getLogger('xxx')` in all examples. + logging.basicConfig() + + _multi_thread = False _multi_phase = False diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index b43195c945..b56cbf3ad9 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -34,7 +34,6 @@ def calc_mask(self, layer, config): class AGP_Pruner(Pruner): """An automated gradual pruning algorithm that prunes the smallest magnitude weights to achieve a preset level of network sparsity. - Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the efficacy of pruning for model compression", 2017 NIPS Workshop on Machine Learning of Phones and other Consumer Devices, @@ -178,17 +177,13 @@ def _get_min_gm_kernel_idx(self, weight, n): assert len(weight.shape) >= 3 assert weight.shape[0] * weight.shape[1] > 2 - dist_list, idx_list = [], [] + dist_list = [] for in_i in range(weight.shape[0]): for out_i in range(weight.shape[1]): dist_sum = self._get_distance_sum(weight, in_i, out_i) - dist_list.append(dist_sum) - idx_list.append([in_i, out_i]) - dist_tensor = tf.convert_to_tensor(dist_list) - idx_tensor = tf.constant(idx_list) - - _, idx = tf.math.top_k(dist_tensor, k=n) - return tf.gather(idx_tensor, idx) + dist_list.append((dist_sum, (in_i, out_i))) + min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n] + return [x[1] for x in min_gm_kernels] def _get_distance_sum(self, weight, in_idx, out_idx): w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1])) diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 2b0a0391a3..6a080e488c 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -2,24 +2,44 @@ import torch from .compressor import Pruner -__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner'] +__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner', 'L1FilterPruner', 'SlimPruner'] logger = logging.getLogger('torch pruner') class LevelPruner(Pruner): - """Prune to an exact pruning level specification + """ + Prune to an exact pruning level specification """ def __init__(self, model, config_list): """ - config_list: supported keys: - - sparsity + Parameters + ---------- + model : torch.nn.module + Model to be pruned + config_list : list + List on pruning configs """ + super().__init__(model, config_list) self.if_init_list = {} def calc_mask(self, layer, config): + """ + Calculate the mask of given layer + Parameters + ---------- + layer : LayerInfo + the layer to instrument the compression operation + config : dict + layer's pruning config + Returns + ------- + torch.Tensor + mask of the layer's weight + """ + weight = layer.module.weight.data op_name = layer.name if self.if_init_list.get(op_name, True): @@ -37,9 +57,9 @@ def calc_mask(self, layer, config): class AGP_Pruner(Pruner): - """An automated gradual pruning algorithm that prunes the smallest magnitude + """ + An automated gradual pruning algorithm that prunes the smallest magnitude weights to achieve a preset level of network sparsity. - Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the efficacy of pruning for model compression", 2017 NIPS Workshop on Machine Learning of Phones and other Consumer Devices, @@ -48,24 +68,39 @@ class AGP_Pruner(Pruner): def __init__(self, model, config_list): """ - config_list: supported keys: - - initial_sparsity - - final_sparsity: you should make sure initial_sparsity <= final_sparsity - - start_epoch: start epoch number begin update mask - - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch - - frequency: if you want update every 2 epoch, you can set it 2 + Parameters + ---------- + model : torch.nn.module + Model to be pruned + config_list : list + List on pruning configs """ + super().__init__(model, config_list) self.now_epoch = 0 self.if_init_list = {} def calc_mask(self, layer, config): + """ + Calculate the mask of given layer + Parameters + ---------- + layer : LayerInfo + the layer to instrument the compression operation + config : dict + layer's pruning config + Returns + ------- + torch.Tensor + mask of the layer's weight + """ + weight = layer.module.weight.data op_name = layer.name start_epoch = config.get('start_epoch', 0) freq = config.get('frequency', 1) - if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and ( - self.now_epoch - start_epoch) % freq == 0: + if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \ + and (self.now_epoch - start_epoch) % freq == 0: mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight)) target_sparsity = self.compute_target_sparsity(config) k = int(weight.numel() * target_sparsity) @@ -82,6 +117,18 @@ def calc_mask(self, layer, config): return new_mask def compute_target_sparsity(self, config): + """ + Calculate the sparsity for pruning + Parameters + ---------- + config : dict + Layer's pruning config + Returns + ------- + float + Target sparsity to be pruned + """ + end_epoch = config.get('end_epoch', 1) start_epoch = config.get('start_epoch', 0) freq = config.get('frequency', 1) @@ -102,11 +149,20 @@ def compute_target_sparsity(self, config): return target_sparsity def update_epoch(self, epoch): + """ + Update epoch + Parameters + ---------- + epoch : int + current training epoch + """ + if epoch > 0: self.now_epoch = epoch - for k in self.if_init_list: + for k in self.if_init_list.keys(): self.if_init_list[k] = True + class FPGMPruner(Pruner): """ A filter pruner via geometric median. @@ -135,13 +191,11 @@ def calc_mask(self, layer, config): OUT: number of output channel IN: number of input channel LEN: filter length - filter dimensions for Conv2d: OUT: number of output channel IN: number of input channel H: filter height W: filter width - Parameters ---------- layer : LayerInfo @@ -196,7 +250,6 @@ def _get_distance_sum(self, weight, in_idx, out_idx): for k in w: dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) return dist_sum - Parameters ---------- weight: Tensor @@ -206,25 +259,151 @@ def _get_distance_sum(self, weight, in_idx, out_idx): between this specified filter and all other filters. in_idx: int input channel index of specified filter - Returns ------- float32 The total distance """ logger.debug('weight size: %s', weight.size()) - if len(weight.size()) == 4: # Conv2d + if len(weight.size()) == 4: # Conv2d w = weight.view(-1, weight.size(-2), weight.size(-1)) anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) - elif len(weight.size()) == 3: # Conv1d + elif len(weight.size()) == 3: # Conv1d w = weight.view(-1, weight.size(-1)) anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1)) else: raise RuntimeError('unsupported layer type') x = w - anchor_w - x = (x*x).sum((-2, -1)) + x = (x * x).sum((-2, -1)) x = torch.sqrt(x) return x.sum() def update_epoch(self, epoch): self.epoch_pruned_layers = set() + + +class L1FilterPruner(Pruner): + """ + A structured pruning algorithm that prunes the filters of smallest magnitude + weights sum in the convolution layers to achieve a preset level of network sparsity. + Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf, + "PRUNING FILTERS FOR EFFICIENT CONVNETS", 2017 ICLR + https://arxiv.org/abs/1608.08710 + """ + + def __init__(self, model, config_list): + """ + Parameters + ---------- + model : torch.nn.module + Model to be pruned + config_list : list + support key for each list item: + - sparsity: percentage of convolutional filters to be pruned. + """ + + super().__init__(model, config_list) + self.mask_calculated_ops = set() + + def calc_mask(self, layer, config): + """ + Calculate the mask of given layer. + Filters with the smallest sum of its absolute kernel weights are masked. + Parameters + ---------- + layer : LayerInfo + the layer to instrument the compression operation + config : dict + layer's pruning config + Returns + ------- + torch.Tensor + mask of the layer's weight + """ + + weight = layer.module.weight.data + op_name = layer.name + op_type = layer.type + assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning' + if op_name in self.mask_calculated_ops: + assert op_name in self.mask_dict + return self.mask_dict.get(op_name) + mask = torch.ones(weight.size()).type_as(weight) + try: + filters = weight.shape[0] + w_abs = weight.abs() + k = int(filters * config['sparsity']) + if k == 0: + return torch.ones(weight.shape).type_as(weight) + w_abs_structured = w_abs.view(filters, -1).sum(dim=1) + threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max() + mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight) + finally: + self.mask_dict.update({layer.name: mask}) + self.mask_calculated_ops.add(layer.name) + + return mask + + +class SlimPruner(Pruner): + """ + A structured pruning algorithm that prunes channels by pruning the weights of BN layers. + Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang + "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV + https://arxiv.org/pdf/1708.06519.pdf + """ + + def __init__(self, model, config_list): + """ + Parameters + ---------- + config_list : list + support key for each list item: + - sparsity: percentage of convolutional filters to be pruned. + """ + + super().__init__(model, config_list) + self.mask_calculated_ops = set() + weight_list = [] + if len(config_list) > 1: + logger.warning('Slim pruner only supports 1 configuration') + config = config_list[0] + for (layer, config) in self.detect_modules_to_compress(): + assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning' + weight_list.append(layer.module.weight.data.clone()) + all_bn_weights = torch.cat(weight_list) + k = int(all_bn_weights.shape[0] * config['sparsity']) + self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max() + + def calc_mask(self, layer, config): + """ + Calculate the mask of given layer. + Scale factors with the smallest absolute value in the BN layer are masked. + Parameters + ---------- + layer : LayerInfo + the layer to instrument the compression operation + config : dict + layer's pruning config + Returns + ------- + torch.Tensor + mask of the layer's weight + """ + + weight = layer.module.weight.data + op_name = layer.name + op_type = layer.type + assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning' + if op_name in self.mask_calculated_ops: + assert op_name in self.mask_dict + return self.mask_dict.get(op_name) + mask = torch.ones(weight.size()).type_as(weight) + try: + w_abs = weight.abs() + mask = torch.gt(w_abs, self.global_threshold).type_as(weight) + finally: + self.mask_dict.update({layer.name: mask}) + self.mask_calculated_ops.add(layer.name) + + return mask diff --git a/src/sdk/pynni/nni/platform/standalone.py b/src/sdk/pynni/nni/platform/standalone.py index 7f752786b7..554fc976fa 100644 --- a/src/sdk/pynni/nni/platform/standalone.py +++ b/src/sdk/pynni/nni/platform/standalone.py @@ -22,14 +22,26 @@ import logging import json_tricks +from ..common import init_standalone_logger -# print INFO log to stdout -logging.basicConfig() -logging.getLogger('nni').setLevel(logging.INFO) +__all__ = [ + 'get_next_parameter', + 'get_experiment_id', + 'get_trial_id', + 'get_sequence_id', + 'send_metric', +] + +init_standalone_logger() +_logger = logging.getLogger('nni') def get_next_parameter(): - pass + _logger.warning('Requesting parameter without NNI framework, returning empty dict') + return { + 'parameter_id': None, + 'parameters': {} + } def get_experiment_id(): pass @@ -43,6 +55,8 @@ def get_sequence_id(): def send_metric(string): metric = json_tricks.loads(string) if metric['type'] == 'FINAL': - print('Final result:', metric['value']) + _logger.info('Final result: %s', metric['value']) elif metric['type'] == 'PERIODICAL': - print('Intermediate result:', metric['value']) + _logger.info('Intermediate result: %s (Index %s)', metric['value'], metric['sequence']) + else: + _logger.error('Unexpected metric: %s', string) diff --git a/src/sdk/pynni/nni/trial.py b/src/sdk/pynni/nni/trial.py index e0c7cde163..586b7a913a 100644 --- a/src/sdk/pynni/nni/trial.py +++ b/src/sdk/pynni/nni/trial.py @@ -126,9 +126,10 @@ def report_intermediate_result(metric): serializable object. """ global _intermediate_seq - assert _params is not None, 'nni.get_next_parameter() needs to be called before report_intermediate_result' + assert _params or trial_env_vars.NNI_PLATFORM is None, \ + 'nni.get_next_parameter() needs to be called before report_intermediate_result' metric = json_tricks.dumps({ - 'parameter_id': _params['parameter_id'], + 'parameter_id': _params['parameter_id'] if _params else None, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'type': 'PERIODICAL', 'sequence': _intermediate_seq, @@ -147,9 +148,10 @@ def report_final_result(metric): metric: serializable object. """ - assert _params is not None, 'nni.get_next_parameter() needs to be called before report_final_result' + assert _params or trial_env_vars.NNI_PLATFORM is None, \ + 'nni.get_next_parameter() needs to be called before report_final_result' metric = json_tricks.dumps({ - 'parameter_id': _params['parameter_id'], + 'parameter_id': _params['parameter_id'] if _params else None, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'type': 'FINAL', 'sequence': 0, diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index f40bd9485b..0f714a76a2 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -1,4 +1,5 @@ from unittest import TestCase, main +import numpy as np import tensorflow as tf import torch import torch.nn.functional as F @@ -7,11 +8,11 @@ if tf.__version__ >= '2.0': import nni.compression.tensorflow as tf_compressor -def get_tf_mnist_model(): +def get_tf_model(): model = tf.keras.models.Sequential([ - tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), + tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"), tf.keras.layers.MaxPooling2D(pool_size=2), - tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"), + tf.keras.layers.Conv2D(filters=10, kernel_size=3, activation='relu', padding="SAME"), tf.keras.layers.MaxPooling2D(pool_size=2), tf.keras.layers.Flatten(), tf.keras.layers.Dense(units=128, activation='relu'), @@ -23,43 +24,51 @@ def get_tf_mnist_model(): metrics=["accuracy"]) return model -class TorchMnist(torch.nn.Module): +class TorchModel(torch.nn.Module): def __init__(self): super().__init__() - self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) - self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) - self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) - self.fc2 = torch.nn.Linear(500, 10) + self.conv1 = torch.nn.Conv2d(1, 5, 5, 1) + self.conv2 = torch.nn.Conv2d(5, 10, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 10, 100) + self.fc2 = torch.nn.Linear(100, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2, 2) - x = x.view(-1, 4 * 4 * 50) + x = x.view(-1, 4 * 4 * 10) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) def tf2(func): - def test_tf2_func(self): + def test_tf2_func(*args): if tf.__version__ >= '2.0': - func() + func(*args) return test_tf2_func +k1 = [[1]*3]*3 +k2 = [[2]*3]*3 +k3 = [[3]*3]*3 +k4 = [[4]*3]*3 +k5 = [[5]*3]*3 + +w = [[k1, k2, k3, k4, k5]] * 10 + class CompressorTestCase(TestCase): - def test_torch_pruner(self): - model = TorchMnist() + def test_torch_level_pruner(self): + model = TorchModel() configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] torch_compressor.LevelPruner(model, configure_list).compress() - def test_torch_fpgm_pruner(self): - model = TorchMnist() - configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}] - torch_compressor.FPGMPruner(model, configure_list).compress() + @tf2 + def test_tf_level_pruner(self): + configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] + tf_compressor.LevelPruner(get_tf_model(), configure_list).compress() - def test_torch_quantizer(self): - model = TorchMnist() + def test_torch_naive_quantizer(self): + model = TorchModel() configure_list = [{ 'quant_types': ['weight'], 'quant_bits': { @@ -70,18 +79,59 @@ def test_torch_quantizer(self): torch_compressor.NaiveQuantizer(model, configure_list).compress() @tf2 - def test_tf_pruner(self): - configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] - tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress() + def test_tf_naive_quantizer(self): + tf_compressor.NaiveQuantizer(get_tf_model(), [{'op_types': ['default']}]).compress() - @tf2 - def test_tf_quantizer(self): - tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress() + def test_torch_fpgm_pruner(self): + """ + With filters(kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median + which minimize the total geometric distance by defination of Geometric Median in this paper: + Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration, + https://arxiv.org/pdf/1811.00250.pdf + + So if sparsity is 0.2, the expected masks should mask out all k3, this can be verified through: + `all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))` + + If sparsity is 0.6, the expected masks should mask out all k2, k3, k4, this can be verified through: + `all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))` + """ + + model = TorchModel() + config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d']}, {'sparsity': 0.6, 'op_types': ['Conv2d']}] + pruner = torch_compressor.FPGMPruner(model, config_list) + + model.conv2.weight.data = torch.tensor(w).float() + layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2) + masks = pruner.calc_mask(layer, config_list[0]) + assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.])) + + pruner.update_epoch(1) + model.conv2.weight.data = torch.tensor(w).float() + masks = pruner.calc_mask(layer, config_list[1]) + assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.])) @tf2 def test_tf_fpgm_pruner(self): - configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}] - tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress() + model = get_tf_model() + config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}, {'sparsity': 0.6, 'op_types': ['Conv2D']}] + + pruner = tf_compressor.FPGMPruner(model, config_list) + weights = model.layers[2].weights + weights[0] = np.array(w).astype(np.float32).transpose([2, 3, 0, 1]).transpose([0, 1, 3, 2]) + model.layers[2].set_weights([weights[0], weights[1].numpy()]) + + layer = tf_compressor.compressor.LayerInfo(model.layers[2]) + masks = pruner.calc_mask(layer, config_list[0]).numpy() + masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3]) + + assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.])) + + pruner.update_epoch(1) + model.layers[2].set_weights([weights[0], weights[1].numpy()]) + masks = pruner.calc_mask(layer, config_list[1]).numpy() + masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3]) + + assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.])) if __name__ == '__main__': diff --git a/tools/nni_gpu_tool/gpu_metrics_collector.py b/tools/nni_gpu_tool/gpu_metrics_collector.py index 436e1edaaf..591352acc2 100644 --- a/tools/nni_gpu_tool/gpu_metrics_collector.py +++ b/tools/nni_gpu_tool/gpu_metrics_collector.py @@ -35,7 +35,7 @@ def check_ready_to_run(): pidList.remove(os.getpid()) return not pidList else: - pgrep_output = subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True) + pgrep_output = subprocess.check_output('pgrep -fxu "$(whoami)" \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True) pidList = [] for pid in pgrep_output.splitlines(): pidList.append(int(pid))