diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py
index 75046ae..9094641 100644
--- a/torch_pruning/pruner/importance.py
+++ b/torch_pruning/pruner/importance.py
@@ -37,7 +37,6 @@ class Importance(abc.ABC):
     It should accept a group as inputs, and return a 1-D tensor with the same length as the number of channels.
     All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups.
-    Just ignore the ch_groups if you are not familar with grouping.

     Example:
         ```python
@@ -362,7 +361,7 @@ def __call__(self, group, **kwargs):
                 # repeat importance for group convolutions
                 if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1:
-                    local_imp = local_imp.repeat(ch_groups)
+                    local_imp = local_imp.repeat(layer.groups)
                 local_imp = local_imp[idxs]
                 similar_matrix = torch.cdist(local_imp.unsqueeze(0), local_imp.unsqueeze(0), p=2).squeeze(0)
                 similar_sum = torch.sum(torch.abs(similar_matrix), dim=0)
@@ -582,7 +581,7 @@ def adjust_fisher(self, group, idxs):
             indices_to_keep = list(range(self.Fisher[layer].shape[1]))
             for idx in idxs:
                 indices_to_keep = [i for i in indices_to_keep if not (idx*kernel_size <= i < (idx+1)*kernel_size)]
-            self.Fisher[layer] = torch.index_select(self.Fisher[layer], 1, torch.tensor(indices_to_keep).to(self.Fisher[layer].device))
+            self.Fisher[layer] = torch.index_select(self.Fisher[layer], 1, torch.LongTensor(indices_to_keep).to(self.Fisher[layer].device))

     def _rm_hooks(self, model):
@@ -615,7 +614,7 @@ def _clear_buffer(self):
         self.steps = 0

     @torch.no_grad()
-    def __call__(self, group, ch_groups=1):
+    def __call__(self, group):
         group_imp = []
         group_idxs = []
         for i, (dep, idxs) in enumerate(group):
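Two minimal sketches of the behavior these hunks fix (the layer sizes and tensor shapes below are illustrative assumptions, not taken from the repo).

For the `repeat(layer.groups)` change: a grouped, non-depthwise conv only stores `in_channels // groups` input slices in its weight, so a per-slice importance vector must be tiled `layer.groups` times to line up with the layer's full set of input channels. The old code tiled by `ch_groups`, an argument that defaulted to 1 and is removed from the `__call__` signature in the last hunk, leaving the vector too short.

```python
import torch

# Hypothetical grouped conv: 6 in_channels split across 2 groups,
# so the weight only stores in_channels // groups = 3 input slices.
conv = torch.nn.Conv2d(6, 8, kernel_size=3, groups=2)
print(conv.weight.shape)  # torch.Size([8, 3, 3, 3])

# A per-slice input importance has length 3, but the layer has 6
# input channels; tiling by conv.groups restores the full length.
local_imp = conv.weight.abs().mean(dim=(0, 2, 3))  # shape: (3,)
local_imp = local_imp.repeat(conv.groups)          # shape: (6,)
```

For the `torch.LongTensor` change: `torch.tensor(...)` infers its dtype from the list contents, and an empty list defaults to `float32`, which `torch.index_select` rejects as an index. `torch.LongTensor` always yields `int64`, so the edge case where every channel of a layer is pruned (`indices_to_keep == []`) keeps working.

```python
import torch

fisher = torch.randn(4, 6)  # stand-in for a Fisher buffer

print(torch.tensor([]).dtype)      # torch.float32 -> index_select raises
print(torch.LongTensor([]).dtype)  # torch.int64   -> index_select works
print(torch.index_select(fisher, 1, torch.LongTensor([])).shape)  # (4, 0)
```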