diff --git a/docs/en_US/Compressor/CompressionReference.md b/docs/en_US/Compressor/CompressionReference.md new file mode 100644 index 0000000000..c190a46eb6 --- /dev/null +++ b/docs/en_US/Compressor/CompressionReference.md @@ -0,0 +1,23 @@ +# Python API Reference of Compression Utilities + +```eval_rst +.. contents:: +``` + +## Sensitivity Utilities + +```eval_rst +.. autoclass:: nni.compression.torch.utils.sensitivity_analysis.SensitivityAnalysis + :members: + +``` + +## Topology Utilities + +```eval_rst +.. autoclass:: nni.compression.torch.utils.shape_dependency.ChannelDependency + :members: + +.. autoclass:: nni.compression.torch.utils.mask_conflict.MaskConflict + :members: +``` diff --git a/docs/en_US/Compressor/CompressionUtils.md b/docs/en_US/Compressor/CompressionUtils.md new file mode 100644 index 0000000000..09418912b9 --- /dev/null +++ b/docs/en_US/Compressor/CompressionUtils.md @@ -0,0 +1,123 @@ +# Analysis Utils for Model Compression + +```eval_rst +.. contents:: +``` + +We provide several easy-to-use tools for users to analyze their model during model compression. + +## Sensitivity Analysis +First, we provide a sensitivity analysis tool (**SensitivityAnalysis**) for users to analyze the sensitivity of each convolutional layer in their model. Specifically, the SensitiviyAnalysis gradually prune each layer of the model, and test the accuracy of the model at the same time. Note that, SensitivityAnalysis only prunes a layer once a time, and the other layers are set to their original weights. According to the accuracies of different convolutional layers under different sparsities, we can easily find out which layers the model accuracy is more sensitive to. + +### Usage + +The following codes show the basic usage of the SensitivityAnalysis. +```python +from nni.compression.torch.utils.sensitivity_analysis import SensitivityAnalysis + +def val(model): + model.eval() + total = 0 + correct = 0 + with torch.no_grad(): + for batchid, (data, label) in enumerate(val_loader): + data, label = data.cuda(), label.cuda() + out = model(data) + _, predicted = out.max(1) + total += data.size(0) + correct += predicted.eq(label).sum().item() + return correct / total + +s_analyzer = SensitivityAnalysis(model=net, val_func=val) +sensitivity = s_analyzer.analysis(val_args=[net]) +os.makedir(outdir) +s_analyzer.export(os.path.join(outdir, filename)) +``` + +Two key parameters of SensitivityAnalysis are `model`, and `val_func`. `model` is the neural network that to be analyzed and the `val_func` is the validation function that returns the model accuracy/loss/ or other metrics on the validation dataset. Due to different scenarios may have different ways to calculate the loss/accuracy, so users should prepare a function that returns the model accuracy/loss on the dataset and pass it to SensitivityAnalysis. +SensitivityAnalysis can export the sensitivity results as a csv file usage is shown in the example above. + +Futhermore, users can specify the sparsities values used to prune for each layer by optional parameter `sparsities`. +```python +s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75]) +``` +the SensitivityAnalysis will prune 25% 50% 75% weights gradually for each layer, and record the model's accuracy at the same time (SensitivityAnalysis only prune a layer once a time, the other layers are set to their original weights). If the sparsities is not set, SensitivityAnalysis will use the numpy.arange(0.1, 1.0, 0.1) as the default sparsity values. + +Users can also speed up the progress of sensitivity analysis by the early_stop_mode and early_stop_value option. By default, the SensitivityAnalysis will test the accuracy under all sparsities for each layer. In contrast, when the early_stop_mode and early_stop_value are set, the sensitivity analysis for a layer will stop, when the accuracy/loss has already met the threshold set by early_stop_value. We support four early stop modes: minimize, maximize, dropped, raised. + +minimize: The analysis stops when the validation metric return by the val_func lower than `early_stop_value`. + +maximize: The analysis stops when the validation metric return by the val_func larger than `early_stop_value`. + +dropped: The analysis stops when the validation metric has dropped by `early_stop_value`. + +raised: The analysis stops when the validation metric has raised by `early_stop_value`. + +```python +s_analyzer = SensitivityAnalysis(model=net, val_func=val, sparsities=[0.25, 0.5, 0.75], early_stop_mode='dropped', early_stop_value=0.1) +``` +If users only want to analyze several specified convolutional layers, users can specify the target conv layers by the `specified_layers` in analysis function. `specified_layers` is a list that consists of the Pytorch module names of the conv layers. For example +```python +sensitivity = s_analyzer.analysis(val_args=[net], specified_layers=['Conv1']) +``` +In this example, only the `Conv1` layer is analyzed. In addtion, users can quickly and easily achieve the analysis parallelization by launching multiple processes and assigning different conv layers of the same model to each process. + + +### Output example +The following lines are the example csv file exported from SensitivityAnalysis. The first line is constructed by 'layername' and sparsity list. Here the sparsity value means how much weight SensitivityAnalysis prune for each layer. Each line below records the model accuracy when this layer is under different sparsities. Note that, due to the early_stop option, some layers may +not have model accuracies/losses under all sparsities, for example, its accuracy drop has already exceeded the threshold set by the user. +``` +layername,0.05,0.1,0.2,0.3,0.4,0.5,0.7,0.85,0.95 +features.0,0.54566,0.46308,0.06978,0.0374,0.03024,0.01512,0.00866,0.00492,0.00184 +features.3,0.54878,0.51184,0.37978,0.19814,0.07178,0.02114,0.00438,0.00442,0.00142 +features.6,0.55128,0.53566,0.4887,0.4167,0.31178,0.19152,0.08612,0.01258,0.00236 +features.8,0.55696,0.54194,0.48892,0.42986,0.33048,0.2266,0.09566,0.02348,0.0056 +features.10,0.55468,0.5394,0.49576,0.4291,0.3591,0.28138,0.14256,0.05446,0.01578 +``` + +## Topology Analysis +We also provide several tools for the topology analysis during the model compression. These tools are to help users compress their model better. Because of the complex topology of the network, when compressing the model, users often need to spend a lot of effort to check whether the compression configuration is reasonable. So we provide these tools for topology analysis to reduce the burden on users. + +### ChannelDependency +Complicated models may have residual connection/concat operations in their models. When the user prunes these models, they need to be careful about the channel-count dependencies between the convolution layers in the model. Taking the following residual block in the resnet18 as an example. The output features of the `layer2.0.conv2` and `layer2.0.downsample.0` are added together, so the number of the output channels of `layer2.0.conv2` and `layer2.0.downsample.0` should be the same, or there may be a tensor shape conflict. + +![](../../img/channel_dependency_example.jpg) + + +If the layers have channel dependency are assigned with different sparsities (here we only discuss the structured pruning by L1FilterPruner/L2FilterPruner), then there will be a shape conflict during these layers. Even the pruned model with mask works fine, the pruned model cannot be speedup to the final model directly that runs on the devices, because there will be a shape conflict when the model tries to add/concat the outputs of these layers. This tool is to find the layers that have channel count dependencies to help users better prune their model. + +#### Usage +```python +from nni.compression.torch.utils.shape_dependency import ChannelDependency +data = torch.ones(1, 3, 224, 224).cuda() +channel_depen = ChannelDependency(net, data) +channel_depen.export('dependency.csv') +``` + +#### Output Example +The following lines are the output example of torchvision.models.resnet18 exported by ChannelDependency. The layers at the same line have output channel dependencies with each other. For example, layer1.1.conv2, conv1, and layer1.0.conv2 have output channel dependencies with each other, which means the output channel(filters) numbers of these three layers should be same with each other, otherwise, the model may have shape conflict. +``` +Dependency Set,Convolutional Layers +Set 1,layer1.1.conv2,layer1.0.conv2,conv1 +Set 2,layer1.0.conv1 +Set 3,layer1.1.conv1 +Set 4,layer2.0.conv1 +Set 5,layer2.1.conv2,layer2.0.conv2,layer2.0.downsample.0 +Set 6,layer2.1.conv1 +Set 7,layer3.0.conv1 +Set 8,layer3.0.downsample.0,layer3.1.conv2,layer3.0.conv2 +Set 9,layer3.1.conv1 +Set 10,layer4.0.conv1 +Set 11,layer4.0.downsample.0,layer4.1.conv2,layer4.0.conv2 +Set 12,layer4.1.conv1 +``` + +### MaskConflict +When the masks of different layers in a model have conflict (for example, assigning different sparsities for the layers that have channel dependency), we can fix the mask conflict by MaskConflict. Specifically, the MaskConflict loads the masks exported by the pruners(L1FilterPruner, etc), and check if there is mask conflict, if so, MaskConflict sets the conflicting masks to the same value. + +``` +from nni.compression.torch.utils.mask_conflict import MaskConflict +mc = MaskConflict('./resnet18_mask', net, data) +mc.fix_mask_conflict() +mc.export('./resnet18_fixed_mask') +``` \ No newline at end of file diff --git a/docs/en_US/model_compression.rst b/docs/en_US/model_compression.rst index 457acfadce..f6821e3045 100644 --- a/docs/en_US/model_compression.rst +++ b/docs/en_US/model_compression.rst @@ -22,3 +22,4 @@ For details, please refer to the following tutorials: Model Speedup Automatic Model Compression Implementation + Compression Utilities diff --git a/docs/en_US/sdk_reference.rst b/docs/en_US/sdk_reference.rst index 49d47a2ffd..2602e257b9 100644 --- a/docs/en_US/sdk_reference.rst +++ b/docs/en_US/sdk_reference.rst @@ -7,4 +7,5 @@ Python API Reference :maxdepth: 1 Auto Tune - NAS \ No newline at end of file + NAS + Compression Utilities \ No newline at end of file diff --git a/docs/img/channel_dependency_example.jpg b/docs/img/channel_dependency_example.jpg new file mode 100644 index 0000000000..6fb517fe00 Binary files /dev/null and b/docs/img/channel_dependency_example.jpg differ diff --git a/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py b/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py new file mode 100644 index 0000000000..626283d43d --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import logging +import torch +import numpy as np +from .shape_dependency import ChannelDependency +# logging.basicConfig(level = logging.DEBUG) +_logger = logging.getLogger('FixMaskConflict') + +class MaskConflict: + def __init__(self, mask_file, model=None, dummy_input=None, graph=None): + """ + MaskConflict fix the mask conflict between the layers that + has channel dependecy with each other. + + Parameters + ---------- + model : torch.nn.Module + model to fix the mask conflict + dummy_input : torch.Tensor + input example to trace the model + mask_file : str + the path of the original mask file + graph : torch._C.Graph + the traced graph of the target model, is this parameter is not None, + we donnot use the model and dummpy_input to get the trace graph. + """ + # check if the parameters are valid + parameter_valid = False + if graph is not None: + parameter_valid = True + elif (model is not None) and (dummy_input is not None): + parameter_valid = True + if not parameter_valid: + raise Exception('The input parameters is invalid!') + self.model = model + self.dummy_input = dummy_input + self.graph = graph + self.mask_file = mask_file + self.masks = torch.load(self.mask_file) + + def fix_mask_conflict(self): + """ + Fix the mask conflict before the mask inference for the layers that + has shape dependencies. This function should be called before the + mask inference of the 'speedup' module. + """ + channel_depen = ChannelDependency(self.model, self.dummy_input, self.graph) + depen_sets = channel_depen.dependency_sets + for dset in depen_sets: + if len(dset) == 1: + # This layer has no channel dependency with other layers + continue + channel_remain = set() + fine_grained = False + for name in dset: + if name not in self.masks: + # this layer is not pruned + continue + w_mask = self.masks[name]['weight'] + shape = w_mask.size() + count = np.prod(shape[1:]) + all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist() + all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist() + if len(all_ones) + len(all_zeros) < w_mask.size(0): + # In fine-grained pruning, there is no need to check + # the shape conflict + _logger.info('Layers %s using fine-grained pruning', ','.join(dset)) + fine_grained = True + break + channel_remain.update(all_ones) + _logger.debug('Layer: %s ', name) + _logger.debug('Original pruned filters: %s', str(all_zeros)) + # Update the masks for the layers in the dependency set + if fine_grained: + continue + ori_channels = 0 + for name in dset: + mask = self.masks[name] + w_shape = mask['weight'].size() + ori_channels = w_shape[0] + for i in channel_remain: + mask['weight'][i] = torch.ones(w_shape[1:]) + if hasattr(mask, 'bias'): + mask['bias'][i] = 1 + _logger.info(','.join(dset)) + _logger.info('Pruned Filters after fixing conflict:') + pruned_filters = set(list(range(ori_channels)))-channel_remain + _logger.info(str(sorted(pruned_filters))) + return self.masks + + def export(self, path): + """ + Export the masks after fixing the conflict to file. + """ + torch.save(self.masks, path) diff --git a/src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py b/src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py new file mode 100644 index 0000000000..fc259833b6 --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/utils/sensitivity_analysis.py @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import copy +import csv +import logging +from collections import OrderedDict + +import numpy as np +import torch.nn as nn + +from nni.compression.torch import LevelPruner +from nni.compression.torch import L1FilterPruner +from nni.compression.torch import L2FilterPruner + +SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d'] +SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME] + +logger = logging.getLogger('Sensitivity_Analysis') +logger.setLevel(logging.INFO) + + +class SensitivityAnalysis: + def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None): + """ + Perform sensitivity analysis for this model. + Parameters + ---------- + model : torch.nn.Module + the model to perform sensitivity analysis + val_func : function + validation function for the model. Due to + different models may need different dataset/criterion + , therefore the user need to cover this part by themselves. + In the val_func, the model should be tested on the validation dateset, + and the validation accuracy/loss should be returned as the output of val_func. + There are no restrictions on the input parameters of the val_function. + User can use the val_args, val_kwargs parameters in analysis + to pass all the parameters that val_func needed. + sparsities : list + The sparsity list provided by users. This parameter is set when the user + only wants to test some specific sparsities. In the sparsity list, each element + is a sparsity value which means how much weight the pruner should prune. Take + [0.25, 0.5, 0.75] for an example, the SensitivityAnalysis will prune 25% 50% 75% + weights gradually for each layer. + prune_type : str + The pruner type used to prune the conv layers, default is 'l1', + and 'l2', 'fine-grained' is also supported. + early_stop_mode : str + If this flag is set, the sensitivity analysis + for a conv layer will early stop when the validation metric( + for example, accurracy/loss) has alreay meet the threshold. We + support four different early stop modes: minimize, maximize, dropped, + raised. The default value is None, which means the analysis won't stop + until all given sparsities are tested. This option should be used with + early_stop_value together. + + minimize: The analysis stops when the validation metric return by the val_func + lower than early_stop_value. + maximize: The analysis stops when the validation metric return by the val_func + larger than early_stop_value. + dropped: The analysis stops when the validation metric has dropped by early_stop_value. + raised: The analysis stops when the validation metric has raised by early_stop_value. + early_stop_value : float + This value is used as the threshold for different earlystop modes. + This value is effective only when the early_stop_mode is set. + + """ + self.model = model + self.val_func = val_func + self.target_layer = OrderedDict() + self.ori_state_dict = copy.deepcopy(self.model.state_dict()) + self.target_layer = {} + self.sensitivities = {} + if sparsities is not None: + self.sparsities = sorted(sparsities) + else: + self.sparsities = np.arange(0.1, 1.0, 0.1) + self.sparsities = [np.round(x, 2) for x in self.sparsities] + self.Pruner = L1FilterPruner + if prune_type == 'l2': + self.Pruner = L2FilterPruner + elif prune_type == 'fine-grained': + self.Pruner = LevelPruner + self.early_stop_mode = early_stop_mode + self.early_stop_value = early_stop_value + self.ori_metric = None # original validation metric for the model + # already_pruned is for the iterative sensitivity analysis + # For example, sensitivity_pruner iteratively prune the target + # model according to the sensitivity. After each round of + # pruning, the sensitivity_pruner will test the new sensitivity + # for each layer + self.already_pruned = {} + self.model_parse() + + @property + def layers_count(self): + return len(self.target_layer) + + def model_parse(self): + for name, submodel in self.model.named_modules(): + for op_type in SUPPORTED_OP_TYPE: + if isinstance(submodel, op_type): + self.target_layer[name] = submodel + self.already_pruned[name] = 0 + + def _need_to_stop(self, ori_metric, cur_metric): + """ + Judge if meet the stop conditon(early_stop, min_threshold, + max_threshold). + Parameters + ---------- + ori_metric : float + original validation metric + cur_metric : float + current validation metric + + Returns + ------- + stop : bool + if stop the sensitivity analysis + """ + if self.early_stop_mode is None: + # early stop mode is not enable + return False + assert self.early_stop_value is not None + if self.early_stop_mode == 'minimize': + if cur_metric < self.early_stop_value: + return True + elif self.early_stop_mode == 'maximize': + if cur_metric > self.early_stop_value: + return True + elif self.early_stop_mode == 'dropped': + if cur_metric < ori_metric - self.early_stop_value: + return True + elif self.early_stop_mode == 'raised': + if cur_metric > ori_metric + self.early_stop_value: + return True + return False + + def analysis(self, val_args=None, val_kwargs=None, specified_layers=None): + """ + This function analyze the sensitivity to pruning for + each conv layer in the target model. + If start and end are not set, we analyze all the conv + layers by default. Users can specify several layers to + analyze or parallelize the analysis process easily through + the start and end parameter. + + Parameters + ---------- + val_args : list + args for the val_function + val_kwargs : dict + kwargs for the val_funtion + specified_layers : list + list of layer names to analyze sensitivity. + If this variable is set, then only analyze + the conv layers that specified in the list. + User can also use this option to parallelize + the sensitivity analysis easily. + Returns + ------- + sensitivities : dict + dict object that stores the trajectory of the + accuracy/loss when the prune ratio changes + """ + if val_args is None: + val_args = [] + if val_kwargs is None: + val_kwargs = {} + # Get the original validation metric(accuracy/loss) before pruning + if self.ori_metric is None: + self.ori_metric = self.val_func(*val_args, **val_kwargs) + namelist = list(self.target_layer.keys()) + if specified_layers is not None: + # only analyze several specified conv layers + namelist = list(filter(lambda x: x in specified_layers, namelist)) + for name in namelist: + self.sensitivities[name] = {} + for sparsity in self.sparsities: + # Calculate the actual prune ratio based on the already pruned ratio + sparsity = ( + 1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name] + # TODO In current L1/L2 Filter Pruner, the 'op_types' is still necessary + # I think the L1/L2 Pruner should specify the op_types automaticlly + # according to the op_names + cfg = [{'sparsity': sparsity, 'op_names': [ + name], 'op_types': ['Conv2d']}] + pruner = self.Pruner(self.model, cfg) + pruner.compress() + val_metric = self.val_func(*val_args, **val_kwargs) + logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f', + name, sparsity, val_metric) + + self.sensitivities[name][sparsity] = val_metric + pruner._unwrap_model() + del pruner + # check if the current metric meet the stop condition + if self._need_to_stop(self.ori_metric, val_metric): + break + + # reset the weights pruned by the pruner, because the + # input sparsities is sorted, so we donnot need to reset + # weight of the layer when the sparsity changes, instead, + # we only need reset the weight when the pruning layer changes. + self.model.load_state_dict(self.ori_state_dict) + + return self.sensitivities + + def export(self, filepath): + """ + Export the results of the sensitivity analysis + to a csv file. The firstline of the csv file describe the content + structure. The first line is constructed by 'layername' and sparsity + list. Each line below records the validation metric returned by val_func + when this layer is under different sparsities. Note that, due to the early_stop + option, some layers may not have the metrics under all sparsities. + + layername, 0.25, 0.5, 0.75 + conv1, 0.6, 0.55 + conv2, 0.61, 0.57, 0.56 + + Parameters + ---------- + filepath : str + Path of the output file + """ + str_sparsities = [str(x) for x in self.sparsities] + header = ['layername'] + str_sparsities + with open(filepath, 'w') as csvf: + csv_w = csv.writer(csvf) + csv_w.writerow(header) + for layername in self.sensitivities: + row = [] + row.append(layername) + for sparsity in sorted(self.sensitivities[layername].keys()): + row.append(self.sensitivities[layername][sparsity]) + csv_w.writerow(row) + + def update_already_pruned(self, layername, ratio): + """ + Set the already pruned ratio for the target layer. + """ + self.already_pruned[layername] = ratio + + def load_state_dict(self, state_dict): + """ + Update the weight of the model + """ + self.ori_state_dict = copy.deepcopy(state_dict) + self.model.load_state_dict(self.ori_state_dict) diff --git a/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py b/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py new file mode 100644 index 0000000000..8922ec483e --- /dev/null +++ b/src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import csv +import logging + +from nni._graph_utils import TorchModuleGraph + + +CONV_TYPE = 'aten::_convolution' +ADD_TYPES = ['aten::add', 'aten::add_'] +CAT_TYPE = 'aten::cat' +logger = logging.getLogger('Shape_Dependency') + + +class ChannelDependency: + def __init__(self, model=None, dummy_input=None, traced_model=None): + """ + This model analyze the channel dependencis between the conv + layers in a model. + + Parameters + ---------- + model : torch.nn.Module + The model to be analyzed. + data : torch.Tensor + The example input data to trace the network architecture. + traced_model : torch._C.Graph + if we alreay has the traced graph of the target model, we donnot + need to trace the model again. + """ + # check if the input is legal + if traced_model is None: + # user should provide model & dummy_input to trace the model or a already traced model + assert model is not None and dummy_input is not None + self.graph = TorchModuleGraph(model, dummy_input, traced_model) + self.dependency = dict() + self.build_channel_dependency() + + def _get_parent_layers(self, node): + """ + Find the nearest father conv layers for the target node. + + Parameters + --------- + node : torch._C.Node + target node. + + Returns + ------- + parent_layers: list + nearest father conv/linear layers for the target worknode. + """ + parent_layers = [] + queue = [] + queue.append(node) + while queue: + curnode = queue.pop(0) + if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear': + # find the first met conv + parent_layers.append(curnode.name) + continue + parents = self.graph.find_predecessors(curnode.unique_name) + parents = [self.graph.name_to_node[name] for name in parents] + for parent in parents: + queue.append(parent) + return parent_layers + + def build_channel_dependency(self): + """ + Build the channel dependency for the conv layers + in the model. + """ + for node in self.graph.nodes_py.nodes_op: + parent_layers = [] + # find the node that contains aten::add + # or aten::cat operations + if node.op_type in ADD_TYPES: + parent_layers = self._get_parent_layers(node) + elif node.op_type == CAT_TYPE: + # To determine if this cat operation will introduce channel + # dependency, we need the specific input parameters of the cat + # opertion. To get the input parameters of the cat opertion, we + # need to traverse all the cpp_nodes included by this NodePyGroup, + # because, TorchModuleGraph merges the important nodes and the adjacent + # unimportant nodes (nodes started with prim::attr, for example) into a + # NodepyGroup. + cat_dim = None + for cnode in node.node_cpps: + if cnode.kind() == CAT_TYPE: + cat_dim = list(cnode.inputs())[1].toIValue() + break + if cat_dim != 1: + parent_layers = self._get_parent_layers(node) + dependency_set = set(parent_layers) + # merge the dependencies + for parent in parent_layers: + if parent in self.dependency: + dependency_set.update(self.dependency[parent]) + # save the dependencies + for _node in dependency_set: + self.dependency[_node] = dependency_set + + + def export(self, filepath): + """ + export the channel dependencies as a csv file. + The layers at the same line have output channel + dependencies with each other. For example, + layer1.1.conv2, conv1, and layer1.0.conv2 have + output channel dependencies with each other, which + means the output channel(filters) numbers of these + three layers should be same with each other, otherwise + the model may has shape conflict. + + Output example: + Dependency Set,Convolutional Layers + Set 1,layer1.1.conv2,layer1.0.conv2,conv1 + Set 2,layer1.0.conv1 + Set 3,layer1.1.conv1 + """ + header = ['Dependency Set', 'Convolutional Layers'] + setid = 0 + visited = set() + with open(filepath, 'w') as csvf: + csv_w = csv.writer(csvf, delimiter=',') + csv_w.writerow(header) + for node in self.graph.nodes_py.nodes_op: + if node.op_type != 'Conv2d' or node in visited: + continue + setid += 1 + row = ['Set %d' % setid] + if node.name not in self.dependency: + visited.add(node) + row.append(node.name) + else: + for other in self.dependency[node.name]: + visited.add(self.graph.name_to_node[other]) + row.append(other) + csv_w.writerow(row) + + @property + def dependency_sets(self): + """ + Get the list of the dependency set. + + Returns + ------- + dependency_sets : list + list of the dependency sets. For example, + [set(['conv1', 'conv2']), set(['conv3', 'conv4'])] + + """ + d_sets = [] + visited = set() + for node in self.graph.nodes_py.nodes_op: + if node.op_type != 'Conv2d' or node in visited: + continue + tmp_set = set() + if node.name not in self.dependency: + visited.add(node) + tmp_set.add(node.name) + else: + for other in self.dependency[node.name]: + visited.add(self.graph.name_to_node[other]) + tmp_set.add(other) + d_sets.append(tmp_set) + return d_sets diff --git a/src/sdk/pynni/tests/test_compression_utils.py b/src/sdk/pynni/tests/test_compression_utils.py new file mode 100644 index 0000000000..803666a50c --- /dev/null +++ b/src/sdk/pynni/tests/test_compression_utils.py @@ -0,0 +1,138 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import unittest +from unittest import TestCase, main +import torch +import torch.nn as nn +import torchvision.models as models +import numpy as np + +from nni.compression.torch import L1FilterPruner +from nni.compression.torch.utils.shape_dependency import ChannelDependency +from nni.compression.torch.utils.mask_conflict import MaskConflict + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +prefix = 'analysis_test' +model_names = ['alexnet', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg19', + 'resnet18', 'resnet34', 'squeezenet1_1', + 'shufflenet_v2_x1_0', 'mobilenet_v2', 'wide_resnet50_2'] + +channel_dependency_ground_truth = { + 'resnet18': [{'layer1.0.conv2', 'layer1.1.conv2', 'conv1'}, + {'layer2.1.conv2', 'layer2.0.conv2', 'layer2.0.downsample.0'}, + {'layer3.0.downsample.0', 'layer3.1.conv2', 'layer3.0.conv2'}, + {'layer4.0.downsample.0', 'layer4.1.conv2', 'layer4.0.conv2'}], + 'resnet34': [{'conv1', 'layer1.2.conv2', 'layer1.1.conv2', 'layer1.0.conv2'}, + {'layer2.3.conv2', 'layer2.0.conv2', 'layer2.0.downsample.0', + 'layer2.1.conv2', 'layer2.2.conv2'}, + {'layer3.3.conv2', 'layer3.0.conv2', 'layer3.4.conv2', 'layer3.0.downsample.0', + 'layer3.5.conv2', 'layer3.1.conv2', 'layer3.2.conv2'}, + {'layer4.0.downsample.0', 'layer4.1.conv2', 'layer4.2.conv2', 'layer4.0.conv2'}], + 'mobilenet_v2': [{'features.3.conv.2', 'features.2.conv.2'}, + {'features.6.conv.2', 'features.4.conv.2', 'features.5.conv.2'}, + {'features.8.conv.2', 'features.7.conv.2', + 'features.10.conv.2', 'features.9.conv.2'}, + {'features.11.conv.2', 'features.13.conv.2', + 'features.12.conv.2'}, + {'features.14.conv.2', 'features.16.conv.2', 'features.15.conv.2'}], + 'wide_resnet50_2': [{'layer1.2.conv3', 'layer1.1.conv3', 'layer1.0.conv3', 'layer1.0.downsample.0'}, + {'layer2.1.conv3', 'layer2.0.conv3', 'layer2.0.downsample.0', + 'layer2.2.conv3', 'layer2.3.conv3'}, + {'layer3.3.conv3', 'layer3.0.conv3', 'layer3.2.conv3', 'layer3.0.downsample.0', + 'layer3.1.conv3', 'layer3.4.conv3', 'layer3.5.conv3'}, + {'layer4.1.conv3', 'layer4.2.conv3', 'layer4.0.downsample.0', 'layer4.0.conv3'}], + 'alexnet': [], + 'vgg11': [], + 'vgg11_bn': [], + 'vgg13': [], + 'vgg19': [], + 'squeezenet1_1': [], + 'googlenet': [], + 'shufflenet_v2_x1_0': [] +} + +unittest.TestLoader.sortTestMethodsUsing = None + + +class AnalysisUtilsTest(TestCase): + @unittest.skipIf(torch.__version__ < "1.3.0", "not supported") + def test_channel_dependency(self): + outdir = os.path.join(prefix, 'dependency') + os.makedirs(outdir, exist_ok=True) + for name in model_names: + print('Analyze channel dependency for %s' % name) + model = getattr(models, name) + net = model().to(device) + dummy_input = torch.ones(1, 3, 224, 224).to(device) + channel_depen = ChannelDependency(net, dummy_input) + depen_sets = channel_depen.dependency_sets + d_set_count = 0 + for d_set in depen_sets: + if len(d_set) > 1: + d_set_count += 1 + assert d_set in channel_dependency_ground_truth[name] + assert d_set_count == len(channel_dependency_ground_truth[name]) + fpath = os.path.join(outdir, name) + channel_depen.export(fpath) + + def get_pruned_index(self, mask): + pruned_indexes = [] + shape = mask.size() + for i in range(shape[0]): + if torch.sum(mask[i]).item() == 0: + pruned_indexes.append(i) + + return pruned_indexes + + @unittest.skipIf(torch.__version__ < "1.3.0", "not supported") + def test_mask_conflict(self): + outdir = os.path.join(prefix, 'masks') + os.makedirs(outdir, exist_ok=True) + for name in model_names: + print('Test mask conflict for %s' % name) + model = getattr(models, name) + net = model().to(device) + dummy_input = torch.ones(1, 3, 224, 224).to(device) + # random generate the prune sparsity for each layer + cfglist = [] + for layername, layer in net.named_modules(): + if isinstance(layer, nn.Conv2d): + # pruner cannot allow the sparsity to be 0 or 1 + sparsity = np.random.uniform(0.01, 0.99) + cfg = {'op_types': ['Conv2d'], 'op_names': [ + layername], 'sparsity': sparsity} + cfglist.append(cfg) + pruner = L1FilterPruner(net, cfglist) + pruner.compress() + ck_file = os.path.join(outdir, '%s.pth' % name) + mask_file = os.path.join(outdir, '%s_mask' % name) + pruner.export_model(ck_file, mask_file) + pruner._unwrap_model() + # Fix the mask conflict + mf = MaskConflict(mask_file, net, dummy_input) + fixed_mask = mf.fix_mask_conflict() + mf.export(os.path.join(outdir, '%s_fixed_mask' % name)) + # use the channel dependency groud truth to check if + # fix the mask conflict successfully + for dset in channel_dependency_ground_truth[name]: + lset = list(dset) + for i, _ in enumerate(lset): + assert fixed_mask[lset[0]]['weight'].size( + 0) == fixed_mask[lset[i]]['weight'].size(0) + w_index1 = self.get_pruned_index( + fixed_mask[lset[0]]['weight']) + w_index2 = self.get_pruned_index( + fixed_mask[lset[i]]['weight']) + assert w_index1 == w_index2 + if hasattr(fixed_mask[lset[0]], 'bias'): + b_index1 = self.get_pruned_index( + fixed_mask[lset[0]]['bias']) + b_index2 = self.get_pruned_index( + fixed_mask[lset[i]]['bias']) + assert b_index1 == b_index2 + + +if __name__ == '__main__': + main()