Commit

Typo
VainF committed Aug 9, 2024
1 parent cd9a43e commit 4c2f113
Showing 6 changed files with 66 additions and 48 deletions.
103 changes: 60 additions & 43 deletions examples/timm_models/prune_timm_models.py
@@ -37,53 +37,70 @@ def main():
test_output = model(example_inputs)
ignored_layers = []
num_heads = {}
pruning_ratio_dict = {}
import random

for m in model.modules():
#if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
ignored_layers.append(m)
print("Ignore classifier layer: ", m)

# Attention layers
if hasattr(m, 'num_heads'):
if hasattr(m, 'qkv'):
num_heads[m.qkv] = m.num_heads
print("Attention layer: ", m.qkv, m.num_heads)
elif hasattr(m, 'qkv_proj'):
num_heads[m.qkv_proj] = m.num_heads
population = [
[0.265625,0.234375,0.265625,0.265625,0.93359375,0.328125,0.2265625,0.58984375,0.54296875,0.701171875,0.919921875,0.04296875,0.796875,0.240966796875,0.07763671875],
[0.96875,0.578125,0.3515625,0.6328125,0.7578125,0.7109375,0.8984375,0.533203125,0.0703125,0.697265625,0.451171875,0.626953125,0.935546875,0.294921875,0.5244140625],
[0.25,0.421875,0.171875,0.4921875,0.71875,0.51953125,0.71875,0.876953125,0.896484375,0.626953125,0.646484375,0.490234375,0.65234375,0.599609375,0.0341796875],
[0.015625,0.015625,0.078125,0.4375,0.59375,0.6953125,0.73828125,0.611328125,0.787109375,0.76171875,0.25,0.427734375,0.154296875,0.592529296875,0.298583984375]
]

print("========Before pruning========")
print(model)
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
pruner = tp.pruner.MetaPruner(
model,
example_inputs,
global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
importance=imp, # importance criterion for parameter selection
iterative_steps=1, # the number of iterations to achieve target pruning ratio
pruning_ratio=args.pruning_ratio, # target pruning ratio
num_heads=num_heads,
ignored_layers=ignored_layers,
)
for g in pruner.step(interactive=True):
g.prune()
for ratios in population:
k = 0
for m in model.modules():
#if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
ignored_layers.append(m)
print("Ignore classifier layer: ", m)

# Attention layers
if hasattr(m, 'num_heads'):
if hasattr(m, 'qkv'):
num_heads[m.qkv] = m.num_heads
print("Attention layer: ", m.qkv, m.num_heads)
elif hasattr(m, 'qkv_proj'):
num_heads[m.qkv_proj] = m.num_heads

elif isinstance(m, nn.Conv2d):
pruning_ratio_dict[m] = ratios[k]
k+=1

for m in model.modules():
# Attention layers
if hasattr(m, 'num_heads'):
if hasattr(m, 'qkv'):
m.num_heads = num_heads[m.qkv]
m.head_dim = m.qkv.out_features // (3 * m.num_heads)
elif hasattr(m, 'qkv_proj'):
m.num_heads = num_heads[m.qkv_proj]
m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

print("========After pruning========")
print(model)
test_output = model(example_inputs)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))
print("========Before pruning========")
print(model)
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
pruner = tp.pruner.MetaPruner(
model,
example_inputs,
global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
importance=imp, # importance criterion for parameter selection
iterative_steps=1, # the number of iterations to achieve target pruning ratio
pruning_ratio=args.pruning_ratio, # target pruning ratio
pruning_ratio_dict=pruning_ratio_dict,
num_heads=num_heads,
ignored_layers=ignored_layers,
)
for g in pruner.step(interactive=True):
g.prune()

for m in model.modules():
# Attention layers
if hasattr(m, 'num_heads'):
if hasattr(m, 'qkv'):
m.num_heads = num_heads[m.qkv]
m.head_dim = m.qkv.out_features // (3 * m.num_heads)
elif hasattr(m, 'qkv_proj'):
m.num_heads = num_heads[m.qkv_proj]
m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

print("========After pruning========")
print(model)
test_output = model(example_inputs)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))

if __name__=='__main__':
main()
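The larger change above sweeps a population of candidate per-layer ratios and feeds them to the pruner through the pruning_ratio_dict argument. A minimal, self-contained sketch of how that argument is used follows; the timm resnet18, the 0.5/0.2 ratios, and the printout are illustrative stand-ins and not part of the commit.

# Sketch only: per-layer pruning ratios via pruning_ratio_dict.
# Assumes torch, timm, and torch_pruning are installed; the model choice
# and the concrete ratios are made up for illustration.
import timm
import torch
import torch.nn as nn
import torch_pruning as tp

model = timm.create_model('resnet18', pretrained=False)
example_inputs = torch.randn(1, 3, 224, 224)

# Give each Conv2d its own ratio; layers not listed fall back to pruning_ratio.
pruning_ratio_dict = {}
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        pruning_ratio_dict[m] = 0.5 if m.out_channels >= 256 else 0.2

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.25,                    # default for layers not in the dict
    pruning_ratio_dict=pruning_ratio_dict,
    ignored_layers=[model.fc],             # keep the classifier intact
)
pruner.step()

macs, params = tp.utils.count_ops_and_params(model, example_inputs)
print("MACs: %.4f G, Params: %.4f M" % (macs / 1e9, params / 1e6))

Keys in pruning_ratio_dict are module instances, mirroring how the example assigns one ratio per Conv2d from each candidate vector in the population.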
2 changes: 1 addition & 1 deletion examples/torchvision_models/torchvision_pruning.py
@@ -226,7 +226,7 @@ def my_prune(model, example_inputs, output_transform, model_name):
# pruner.step()

if isinstance(pruner, (tp.pruner.BNScalePruner, tp.pruner.GroupNormPruner, tp.pruner.GrowingRegPruner)):
pruner.update_regularizor() # if the model has been pruned, we need to update the regularizor
pruner.update_regularizer() # if the model has been pruned, we need to update the regularizer
pruner.regularize(model)

if isinstance(
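The regularizor-to-regularizer rename above belongs to the sparse-training path: regularize() adds a sparsity penalty to the gradients, and once a structural pruning step has run, the dependency groups held by the regularizer are stale and must be rebuilt with update_regularizer() before regularize() is used again. A rough sketch of that call order, with a tiny made-up CNN and random data instead of the torchvision example, follows.

# Sketch only: update_regularizer() after pruning. The model, data, and
# hyperparameters are illustrative, not taken from torchvision_pruning.py.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10),
)
example_inputs = torch.randn(4, 3, 32, 32)

pruner = tp.pruner.BNScalePruner(
    model,
    example_inputs,
    importance=tp.importance.BNScaleImportance(),
    pruning_ratio=0.5,
    ignored_layers=[model[-1]],
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Sparse training: regularize() adds the BN-scale penalty to the gradients.
loss = F.cross_entropy(model(example_inputs), torch.randint(0, 10, (4,)))
loss.backward()
pruner.regularize(model)
optimizer.step()

# After a pruning step the tracked groups are stale; rebuild them first.
pruner.step()
pruner.update_regularizer()
loss = F.cross_entropy(model(example_inputs), torch.randint(0, 10, (4,)))
loss.backward()
pruner.regularize(model)   # safe to regularize the pruned model again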
3 changes: 2 additions & 1 deletion examples/transformers/prune_timm_vit.py
@@ -173,7 +173,8 @@ def main():
elif isinstance(imp, tp.importance.GroupTaylorImportance):
loss = torch.nn.functional.cross_entropy(output, lbls)
loss.backward()



for i, g in enumerate(pruner.step(interactive=True)):
g.prune()

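The whitespace change above sits just before the pruning loop in prune_timm_vit.py, where a forward/backward pass is run first whenever GroupTaylorImportance is selected: Taylor importance scores groups from weights and gradients, so the gradients must exist before pruner.step() ranks anything. A simplified sketch of that ordering, using a timm resnet18 instead of the ViT (which additionally needs the num_heads bookkeeping shown in the first file), follows.

# Sketch only: gradients must be available before Taylor-based pruning.
# The resnet18 stand-in and the random batch are illustrative.
import timm
import torch
import torch.nn.functional as F
import torch_pruning as tp

model = timm.create_model('resnet18', pretrained=False, num_classes=10)
example_inputs = torch.randn(2, 3, 224, 224)
imp = tp.importance.GroupTaylorImportance()

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    ignored_layers=[model.fc],   # keep the classifier head
)

# Taylor importance uses weight * grad, so populate gradients first.
if isinstance(imp, tp.importance.GroupTaylorImportance):
    output = model(example_inputs)
    loss = F.cross_entropy(output, torch.randint(0, 10, (2,)))
    loss.backward()

for i, g in enumerate(pruner.step(interactive=True)):
    g.prune()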
2 changes: 1 addition & 1 deletion tests/test_regularization.py
@@ -56,7 +56,7 @@ def test_pruner():
grad_dict[p] = p.grad.clone()
else:
grad_dict[p] = None
pruner.update_regularizor()
pruner.update_regularizer()
pruner.regularize(model)
for name, p in model.named_parameters():
if p.grad is not None and grad_dict[p] is not None:
2 changes: 1 addition & 1 deletion torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py
@@ -123,7 +123,7 @@ def __init__(
if self.group_lasso:
self._l2_imp = MagnitudeImportance(p=2, group_reduction='mean', normalizer=None, target_types=[nn.modules.batchnorm._BatchNorm])

def update_regularizor(self):
def update_regularizer(self):
self._groups = list(self.DG.get_all_groups(root_module_types=self.root_module_types, ignored_layers=self.ignored_layers))

def regularize(self, model, reg=None, bias=False):
2 changes: 1 addition & 1 deletion torch_pruning/pruner/algorithms/metapruner.py
@@ -300,7 +300,7 @@ def update_regularizer(self) -> None:
pass

def regularize(self, model, loss) -> typing.Any:
""" Model regularizor for sparse training
""" Model regularizer for sparse training
"""
pass

