[TOC]
A basic and simple training framework for PyTorch that is easy to extend.

## Requirements

- torch==1.7.0+cu110 (>= 1.6 for distributed data parallel)
- torchvision==0.8.0
- easydict==1.9
- tensorboard==2.7
- tensorboardX==2.4
- PyYAML==5.3.1
## Features

- Any module can be easily customized.
- Not abstract: easy to learn, develop and debug.
- Removes a lot of repetitive work, so the training process is easier to control in development and research.
- Friendly to interactive training of multiple networks, such as GANs, transfer learning and knowledge distillation.
- DDP training support.
## Quick Start

```bash
python main.py --model resnet18 --save_dir cifar100_resnet18 --config_path ./configs/cifar100.yml
```

For DDP training, just run the following command in the terminal:

```bash
sh run_ddp.sh
```
## Results

### CIFAR100

All CIFAR100 log files can be downloaded as one package from pan.baidu (code: 3lp2).

Network | Accuracy (%) | Log |
---|---|---|
resnet18 | 76.46 | pan.baidu code: ewnd |
resnet34 | 77.23 | pan.baidu code: dq4r |
resnet50 | 76.82 | pan.baidu code: 1e62 |
resnet101 | 77.32 | pan.baidu code: myfv |
vgg11_bn | 70.52 | pan.baidu code: 2pun |
vgg13_bn | 73.71 | pan.baidu code: 4vmm |
mobilenetV2 | 68.99 | pan.baidu code: e93w |
shufflenet | 71.17 | pan.baidu code: lnvy |
shufflenetV2 | 71.16 | pan.baidu code: vmi6 |
### CIFAR10

All CIFAR10 log files can be downloaded as one package from pan.baidu (code: 3iqz).

Network | Accuracy (%) | Log |
---|---|---|
resnet18 | 94.92 | pan.baidu code: a20j |
resnet34 | 94.80 | pan.baidu code: q8h1 |
resnet50 | 94.81 | pan.baidu code: f3wr |
resnet101 | 95.45 | pan.baidu code: d3i8 |
vgg11_bn | 92.21 | pan.baidu code: di45 |
vgg13_bn | 93.74 | pan.baidu code: su1z |
mobilenetV2 | 90.92 | pan.baidu code: todf |
shufflenet | 92.06 | pan.baidu code: 1xr2 |
shufflenetV2 | 91.61 | pan.baidu code: 8swu |
## Training Loop Walkthrough

- Import modules

```python
from torch.utils.data import DataLoader

from src import (DatasetBuilder, TransformBuilder, ModelBuilder, LossBuilder, LossWrapper,
                 OptimizerBuilder, SchedulerBuilder, MetricBuilder)
from src import Controller
from src.utils import AverageMeter
```
- Load your dataloaders and transforms

```python
transform_name = 'cifar100_transform'  # your transform function name
dataset_name = 'CIFAR100'  # your dataset class name
train_transform, val_transform = TransformBuilder.load(transform_name)
trainset, trainset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=train_transform, train=True)
valset, valset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=val_transform, train=False)
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
```
- Load the model, loss wrapper, metrics, optimizer, scheduler and controller

```python
epochs = 20
model = ModelBuilder.load("resnet18", num_classes=100)
# For class-weighted loss: LossBuilder.load("CrossEntropyLoss", weight=torch.tensor([1.0, 2.0]).float())
loss_func1 = LossBuilder.load("CrossEntropyLoss")
# Only one loss function (cross entropy), so its weight is 1.0
loss_wrapper = LossWrapper([loss_func1], [1.0])
model = model.cuda()
loss_wrapper = loss_wrapper.cuda()
metric_functions = [MetricBuilder.load('Accuracy')]
optimizer, optimizer_param = OptimizerBuilder.load('Adam', model.parameters(), lr=0.1)
scheduler, scheduler_param = SchedulerBuilder.load("cosine_annealing_lr", optimizer, max_epoch=epochs)
controller = Controller(loss_wrapper=loss_wrapper, model=model, optimizer=optimizer)
```
- Train!

```python
for epoch in range(epochs):
    # train
    model.train()
    loss_recorder = AverageMeter(type='scalar', name='total loss')
    loss_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["CrossEntropyLoss"])
    metric_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["Accuracy"])
    for (img, label) in train_loader:
        img = img.cuda()
        label = label.cuda()
        loss, loss_tuple, output_no_grad = controller.train_step(img, label)
        # Weight each batch by its size so the epoch average is per sample
        loss_recorder.update(loss.item(), img.size(0))
        loss_list_recorder.update(loss_tuple, img.size(0))
        metrics = tuple([func(output_no_grad, label) for func in metric_functions])
        metric_list_recorder.update(metrics, img.size(0))
    print(f"total loss:{loss_recorder.get_value()} loss_tuple:{loss_list_recorder} metrics:{metric_list_recorder}")
    # eval
    model.eval()
    # ... (validation loop; the scheduler created above is typically stepped once per epoch)
```
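The recorders above are updated with `img.size(0)` so a smaller last batch does not skew the epoch average. Below is a minimal sketch of that sample-weighted averaging, assuming this is roughly what `AverageMeter(type='scalar')` does internally. The class `RunningAverage` is hypothetical and not part of `src.utils`.

```python
class RunningAverage:
    """Sample-weighted running average (hypothetical stand-in for AverageMeter)."""

    def __init__(self):
        self.total = 0.0  # sum of value * n over all updates
        self.count = 0    # total number of samples seen

    def update(self, value, n=1):
        # `value` is a per-batch mean, so weight it by the batch size `n`
        self.total += value * n
        self.count += n

    def get_value(self):
        # Guard against division by zero before the first update
        return self.total / self.count if self.count else 0.0


meter = RunningAverage()
meter.update(2.0, 128)  # a full batch
meter.update(1.0, 64)   # a smaller last batch
print(meter.get_value())
```

A plain mean of the two batch losses would give 1.5; weighting by batch size gives 320 / 192, which is the true per-sample average.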
## Customization

### Customize Dataset

- In the `src/datasets` directory, define your customized dataset in `mydataset.py`, following `cifar.py`.
- The `CIFAR` class needs some parameters for initialization, such as `root`, `train` and `download`, which can be specified in `src/datasets/dataset_config.yml`. Note that `transform` needs to be set in `transforms.py`; details can be found in Customize Transform.
- In `src/datasets/dataset_builder.py`, import your dataset class. For example, if the `MyDataset` class is defined in `mydataset.py`, add `from .mydataset import *` to `dataset_builder.py`.
- In `configs/xxx.yml`, set `dataset.name` to `MyDataset`.
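A minimal sketch of what a custom dataset and its config merge might look like. It assumes the builder starts from the per-dataset defaults in `dataset_config.yml` and overrides `train` per split, returning `(dataset, config)` like `DatasetBuilder.load` in the walkthrough. `DATASET_CONFIG` and `build_dataset` are hypothetical, and `MyDataset` is kept torch-free so it runs anywhere; a real one would subclass `torch.utils.data.Dataset`.

```python
# Hypothetical stand-in for the parsed src/datasets/dataset_config.yml
DATASET_CONFIG = {
    "MyDataset": {"root": "./data/my_dataset", "download": False},
}


class MyDataset:
    """Map-style dataset: __len__ plus __getitem__.

    Torch-free stand-in; a real version would subclass torch.utils.data.Dataset.
    """

    def __init__(self, root, train=True, download=False, transform=None):
        self.root = root
        self.train = train
        self.download = download
        self.transform = transform
        # Toy (sample, label) pairs; a real dataset would read files under `root`
        num_samples = 10 if train else 4
        self.samples = [(i, i % 2) for i in range(num_samples)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample, label = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


def build_dataset(dataset_cls, transform=None, train=True):
    # Hypothetical builder: start from the YAML defaults, then set `train` per split
    config = dict(DATASET_CONFIG[dataset_cls.__name__])
    config["train"] = train
    return dataset_cls(transform=transform, **config), config


trainset, trainset_config = build_dataset(MyDataset, transform=lambda x: x * 2, train=True)
valset, valset_config = build_dataset(MyDataset, train=False)
print(len(trainset), trainset[3], len(valset))
```

Copying the defaults with `dict(...)` keeps the shared config from being mutated when the train and validation splits are built.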
### Customize Transform

- In `src/datasets/transforms.py`, define your transform function, e.g. `my_transform_func`, which returns `train_transform` and `val_transform`.
- In `configs/xxx.yml`, set `dataset.transform_name` to `my_transform_func`.
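The transform function only has to return a `(train_transform, val_transform)` pair of callables. Here is a torch-free sketch, assuming `TransformBuilder.load` calls the function with no arguments. `Compose`, `random_flip` and `normalize` are hypothetical stand-ins for `torchvision.transforms.Compose`, `RandomHorizontalFlip` and `Normalize`, which you would normally use instead.

```python
import random


class Compose:
    """Apply callables in order (same idea as torchvision.transforms.Compose)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


def random_flip(row, p=0.5):
    # Randomly reverse a 1-D "image" row; stands in for RandomHorizontalFlip
    return row[::-1] if random.random() < p else row


def normalize(row, mean=0.5, std=0.25):
    # Per-value standardization; stands in for transforms.Normalize
    return [(v - mean) / std for v in row]


def my_transform_func():
    # Augment only the training split; validation gets the deterministic part
    train_transform = Compose([random_flip, normalize])
    val_transform = Compose([normalize])
    return train_transform, val_transform


train_transform, val_transform = my_transform_func()
print(val_transform([0.5, 0.75]))
```

Random augmentation belongs only in `train_transform`, so validation accuracy stays comparable across epochs.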
### Customize Model

- In the `src/models` directory, define your customized model, such as `my_model.py`, and define the module class `MyModel`. Please refer to `resnet.py`.
- In `src/models/model_builder.py`, import your model: `from .my_model import *` under the `try` block, and `from my_model import *` under the `except` block. This is just for convenient debugging.
- In `configs/xxx.yml`, set `model['name']` to `MyModel`.
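Once the class is imported, the builder only needs to map the string in `model['name']` to that class and pass keyword arguments such as `num_classes`. A minimal sketch of that lookup follows. The explicit `MODELS` dict and `load_model` are hypothetical: the framework relies on the star-imports in `model_builder.py` rather than a dict, and a real `MyModel` subclasses `torch.nn.Module`.

```python
class MyModel:
    """Torch-free stand-in; in practice this subclasses torch.nn.Module."""

    def __init__(self, num_classes=100):
        self.num_classes = num_classes


# Hypothetical registry; in the framework, star-imports in model_builder.py
# make model classes visible by name instead.
MODELS = {"MyModel": MyModel}


def load_model(name, **kwargs):
    # `name` comes from model['name'] in configs/xxx.yml
    try:
        model_cls = MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name!r}; available: {sorted(MODELS)}") from None
    return model_cls(**kwargs)


model = load_model("MyModel", num_classes=10)
print(type(model).__name__, model.num_classes)
```

Failing with the list of known names makes a typo in `configs/xxx.yml`, or a forgotten import, easy to spot.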
### Customize Loss Function

- In the `src/losses` directory, you can define a customized loss module, such as `CrossEntropyLoss` in `classification.py`.
- Then import your loss module in `loss_builder.py`.
- If your model is supervised by multiple loss functions with different weights, the `LossWrapper` module in `src/losses/loss_wrapper.py` may meet the requirement.
- In `configs/xxx.yml`, add your loss names and weights to `train.criterion.names` and `train.criterion.loss_weights`, respectively.
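A weighted combination of several losses can be sketched as below. It assumes `LossWrapper` returns the weighted total together with the individual losses, matching the `loss, loss_tuple` pair that `controller.train_step` yields in the walkthrough. `WeightedLossSum`, `mse` and `l1` are hypothetical and work on plain lists. The same `sum(...)` also works on PyTorch tensors, since `0 + tensor` is a tensor.

```python
class WeightedLossSum:
    """Hypothetical sketch of combining several losses with fixed weights."""

    def __init__(self, loss_funcs, weights):
        if len(loss_funcs) != len(weights):
            raise ValueError("need exactly one weight per loss function")
        self.loss_funcs = loss_funcs
        self.weights = weights

    def __call__(self, output, target):
        # Evaluate every loss, then form the weighted total used for backprop
        losses = tuple(f(output, target) for f in self.loss_funcs)
        total = sum(w * l for w, l in zip(self.weights, losses))
        return total, losses


def mse(output, target):
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)


def l1(output, target):
    return sum(abs(o - t) for o, t in zip(output, target)) / len(output)


wrapper = WeightedLossSum([mse, l1], [1.0, 0.5])
total, parts = wrapper([1.0, 2.0], [0.0, 2.0])
print(total, parts)
```

Keeping the unweighted parts in the returned tuple lets each loss be logged separately while only the total is backpropagated.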
### Customize Optimizer

- In the `src/optimizer` directory you can find `optimizers.py`; define your customized optimizer there. For example, `SGD` and `Adam` are already defined: `parameters` and `lr` must be specified, and other parameters are passed through `*args, **kwargs`. Please refer to:

```python
# src/optimizer/optimizers.py
from torch import optim


def SGD(parameters, lr, *args, **kwargs) -> optim.Optimizer:
    # Pass lr positionally so any extra positional args map to SGD's later parameters
    optimizer = optim.SGD(parameters, lr, *args, **kwargs)
    return optimizer
```
- Other parameters, such as `weight_decay`, can be set in `src/optimizer/optimizer_config.yml`. Please refer to the YAML below. Writing values in the `5e-4` format is fine, since they are converted in `src/optimizer/optimizer_builder.py`.

```yaml
# src/optimizer/optimizer_config.yml
SGD:
  momentum: 0.9
  weight_decay: 5e-4
  dampening: 0
  nesterov: False
```
- In `configs/xxx.yml`, set `train['lr']`, and set `train['optimizer']` to `SGD`.
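`5e-4` needs converting because PyYAML follows YAML 1.1, whose float rule requires a decimal point. `yaml.safe_load` therefore returns `weight_decay` as the string `'5e-4'` rather than a float. We don't know exactly how `optimizer_builder.py` does the conversion, so the sketch below is only one way to do it. `to_number`, `normalize_params` and the `raw` dict (what the SGD block above loads as) are illustrative.

```python
def to_number(value):
    """Convert numeric strings such as '5e-4' to float; leave other values alone."""
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            return value  # a genuine string option, keep it as-is
    return value


def normalize_params(params):
    return {key: to_number(value) for key, value in params.items()}


# What yaml.safe_load yields for the SGD block: '5e-4' stays a string,
# while 0.9, 0 and False are already typed by the YAML resolver.
raw = {"momentum": 0.9, "weight_decay": "5e-4", "dampening": 0, "nesterov": False}
params = normalize_params(raw)
print(params)
```

Writing `5.0e-4` in the YAML file is another way to get a float directly, because that form matches the YAML 1.1 float rule.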
### Customize Learning Rate Scheme

- In `src/schemes/lr_schemes.py`, define your learning rate scheme function, e.g. `my_scheduler`, which requires some parameters, such as `optimizer`, `epochs` and so on.
- Other parameters can be specified in `src/schemes/scheme_config.yml`.
- In `configs/xxx.yml`, set `train.schedule` to `my_scheduler`.
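A scheme function often reduces to an "epoch to learning-rate multiplier" rule. Below is a sketch of a half-cosine decay written that way. The commented `my_scheduler` shows one hypothetical way to wrap it with PyTorch's `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)`. The signature `my_scheduler(optimizer, max_epoch, **kwargs)` is an assumption modeled on the `max_epoch=epochs` argument in the walkthrough, and `cosine_factor` is illustrative.

```python
import math


def cosine_factor(max_epoch, min_factor=0.0):
    """Return epoch -> lr multiplier: a half-cosine from 1.0 down to min_factor."""

    def factor(epoch):
        # Clamp so epochs past max_epoch stay at the minimum
        progress = min(epoch, max_epoch) / max_epoch
        return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return factor


# Hypothetical wrapper in src/schemes/lr_schemes.py (requires torch):
# def my_scheduler(optimizer, max_epoch, **kwargs):
#     return torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_factor(max_epoch))

f = cosine_factor(max_epoch=20)
print([round(f(e), 3) for e in (0, 10, 20)])
```

`LambdaLR` multiplies each parameter group's initial learning rate by the returned factor, so `train['lr']` stays the single place where the base rate is set.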
### Customize Metrics

- In the `src/metrics/` directory, define your metric, such as `Accuracy` in `accuracy.py`, which computes the metric from predictions and targets and returns a scalar.
- Import your metric in `metric_builder.py`, for example, `from .accuracy import *`.
- Multiple metrics are supported: in `configs/xxx.yml`, add your metrics to `train.metric.names`. During training, the checkpoint-saving strategy refers to `train.metrics.key_metric_name` in `configs/xxx.yml`; more details can be found in Customize Checkpoint Saving Strategy.
### Customize Training and Validation Steps

- In `src/controller.py`, feel free to build your own training and validation steps.
- The training step returns `loss`, `loss_tuple` and `output_no_grad`, where `loss_tuple` and `output_no_grad` are used only for logging; whether `loss` carries a gradient is up to you.
### Customize Checkpoint Saving Strategy

- Each training epoch is generally followed by a validation epoch. Torch-atom's `NetIO` in `src/utils/netio.py` saves the best state dict according to `key_metric_name` and `strategy` in `configs/xxx.yml`.
- A checkpoint can also be saved every `save_freq` epochs, which can be set in `configs/xxx.yml` as well.
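The saving decision described above can be sketched as a small tracker: save when the key metric improves, and save periodically every `save_freq` epochs. The `'max'`/`'min'` values for `strategy` and the class `BestCheckpointTracker` are assumptions and not `NetIO`'s actual interface. In real training, each `True` would trigger something like `torch.save(model.state_dict(), path)`.

```python
class BestCheckpointTracker:
    """Hypothetical sketch: save on a new best key metric and every `save_freq` epochs."""

    def __init__(self, strategy="max", save_freq=10):
        if strategy not in ("max", "min"):
            raise ValueError("strategy must be 'max' or 'min'")
        self.strategy = strategy
        self.save_freq = save_freq
        self.best = None

    def is_better(self, value):
        if self.best is None:
            return True  # the first epoch is always the best so far
        return value > self.best if self.strategy == "max" else value < self.best

    def step(self, epoch, key_metric):
        """Return (save_best, save_periodic) for this 1-based epoch."""
        save_best = self.is_better(key_metric)
        if save_best:
            self.best = key_metric
        save_periodic = self.save_freq > 0 and epoch % self.save_freq == 0
        return save_best, save_periodic


tracker = BestCheckpointTracker(strategy="max", save_freq=2)
val_accuracies = [70.1, 69.8, 72.4, 72.0]
history = [tracker.step(epoch, acc) for epoch, acc in enumerate(val_accuracies, start=1)]
print(history, tracker.best)
```

Use `strategy="min"` for a key metric where lower is better, such as validation loss.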
## TODO

- 2022.6.2: DDP support for training
- DDP training
- More experiment results
- More widely-used datasets and models
- Some visualization code for analysis
  - Bad case analysis
  - Data augmentation visualization
- ...
## Acknowledgement

Torch-atom draws ideas from, and was developed based on, the following projects:
## Citation

If you find this project useful in your research, please consider citing it:

```bibtex
@misc{2022torchatom,
  title={Torch-atom: A basic and simple training framework for pytorch},
  author={Baitan Shao},
  howpublished = {\url{https://github.com/shaoeric/torch-atom}},
  year={2022}
}
```
## License

The MIT License

Please feel free to submit issues :)