Skip to content

Commit

Permalink
Fixed a name bug
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF committed Mar 9, 2024
1 parent 2691a9d commit 319656b
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torch_pruning/pruner/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class Importance(abc.ABC):
It should accept a group as inputs, and return a 1-D tensor with the same length as the number of channels.
All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups.
Just ignore the ch_groups if you are not familar with grouping.
Example:
```python
Expand Down Expand Up @@ -362,7 +361,7 @@ def __call__(self, group, **kwargs):

# repeat importance for group convolutions
if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1:
local_imp = local_imp.repeat(ch_groups)
local_imp = local_imp.repeat(layer.groups)
local_imp = local_imp[idxs]
similar_matrix = torch.cdist(local_imp.unsqueeze(0), local_imp.unsqueeze(0), p=2).squeeze(0)
similar_sum = torch.sum(torch.abs(similar_matrix), dim=0)
Expand Down Expand Up @@ -582,7 +581,7 @@ def adjust_fisher(self, group, idxs):
indices_to_keep = list(range(self.Fisher[layer].shape[1]))
for idx in idxs:
indices_to_keep = [i for i in indices_to_keep if not (idx*kernel_size <= i < (idx+1)*kernel_size)]
self.Fisher[layer] = torch.index_select(self.Fisher[layer], 1, torch.tensor(indices_to_keep).to(self.Fisher[layer].device))
self.Fisher[layer] = torch.index_select(self.Fisher[layer], 1, torch.LongTensor(indices_to_keep).to(self.Fisher[layer].device))


def _rm_hooks(self, model):
Expand Down Expand Up @@ -615,7 +614,7 @@ def _clear_buffer(self):
self.steps = 0

@torch.no_grad()
def __call__(self, group, ch_groups=1):
def __call__(self, group):
group_imp = []
group_idxs = []
for i, (dep, idxs) in enumerate(group):
Expand Down

0 comments on commit 319656b

Please sign in to comment.