This repo contains the PyTorch implementation of the paper "Scaling Supervised Local Learning with Augmented Auxiliary Networks" (ICLR 2024). [OpenReview]
Requirements:
- Python 3.8.15
- PyTorch 1.13.1
- torchvision 0.14.1
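Before training, it can help to confirm that the active environment matches the versions listed above. The snippet below is a minimal sketch and is not part of this repo. `check_environment` and `PINNED` are hypothetical names. It uses only the standard library (`importlib.metadata`, Python 3.8+) and checks the Python major/minor version rather than the exact patch release.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Hypothetical helper: versions copied from the requirements list above.
PINNED = {"torch": "1.13.1", "torchvision": "0.14.1"}
PINNED_PYTHON = (3, 8)  # only major.minor is checked, not the 3.8.15 patch level


def check_environment(pinned=PINNED):
    """Return a list of mismatch descriptions (empty if everything matches)."""
    problems = []
    if sys.version_info[:2] != PINNED_PYTHON:
        problems.append(
            f"python {sys.version_info.major}.{sys.version_info.minor} (expected 3.8.x)"
        )
    for pkg, expected in pinned.items():
        try:
            found = version(pkg)  # installed distribution version string
        except PackageNotFoundError:
            problems.append(f"{pkg} not installed (expected {expected})")
            continue
        # Drop local build tags such as "+cu117" before comparing.
        if found.split("+")[0] != expected:
            problems.append(f"{pkg} {found} (expected {expected})")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment matches the pinned versions"]:
        print(line)
```

A mismatch is reported rather than raised, so you can decide whether a newer PyTorch is acceptable for your use.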
- Train ResNet32 with AugLocal on CIFAR10:
python train.py --dataset cifar10 --model resnet --layers 32 --cos_lr --local_module_num 16 --epochs 400 --batch_size 1024 --rule AugLocal --aux_net_depth 2 --pyramid --pyramid_coeff 0.5
- Train ResNet110 with AugLocal on CIFAR10:
python train.py --dataset cifar10 --model resnet --layers 110 --cos_lr --local_module_num 55 --epochs 400 --batch_size 1024 --rule AugLocal --aux_net_depth 3 --pyramid --pyramid_coeff 0.5
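In both commands, `--local_module_num` equals the number of residual blocks plus one. A CIFAR-style ResNet of depth 6n + 2 has 3n basic blocks (15 for ResNet32, 54 for ResNet110), which gives 16 and 55. This is consistent with treating the stem convolution and each residual block as a separate local module, but that reading is our assumption and is not stated in this README. The sketch below is not part of the repo, and its helper names are hypothetical. It reconstructs the two commands from the depth and auxiliary-network depth, using only the flags that appear above.

```python
import shlex


def cifar_resnet_blocks(layers: int) -> int:
    """Number of basic residual blocks in a CIFAR-style ResNet of depth 6n + 2."""
    if layers < 8 or (layers - 2) % 6 != 0:
        raise ValueError(f"CIFAR ResNet depth must be 6n + 2 with n >= 1, got {layers}")
    return 3 * (layers - 2) // 6


def max_local_modules(layers: int) -> int:
    # Assumption: the stem conv is one local module and every residual block is another.
    return cifar_resnet_blocks(layers) + 1


def auglocal_command(layers: int, aux_net_depth: int) -> str:
    """Hypothetical helper: rebuild the AugLocal CIFAR-10 command shown above."""
    args = [
        "python", "train.py",
        "--dataset", "cifar10", "--model", "resnet",
        "--layers", str(layers), "--cos_lr",
        "--local_module_num", str(max_local_modules(layers)),
        "--epochs", "400", "--batch_size", "1024",
        "--rule", "AugLocal", "--aux_net_depth", str(aux_net_depth),
        "--pyramid", "--pyramid_coeff", "0.5",
    ]
    return shlex.join(args)  # shell-quotes arguments only when needed (Python 3.8+)


if __name__ == "__main__":
    print(auglocal_command(32, aux_net_depth=2))
    print(auglocal_command(110, aux_net_depth=3))
```

Whether `train.py` accepts other `--local_module_num` values for a given depth is not covered here. Check `run.sh` or the argument parser before changing it.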
Please refer to run.sh for the scripts of all supervised local learning rules.
Our implementation is adapted from InfoPro. We thank the authors for releasing their code.
If you have any questions, please contact [email protected].