main.py
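"""Training entry point for DCL fine-grained classification on PaddlePaddle.

Parses command-line arguments, builds datasets and data loaders for
CUB-200-2011, Stanford Cars, FGVC-Aircraft, or a tiny CUB subset, wraps
DCLNet in DataParallel, and starts training. DCL is Destruction and
Construction Learning (Chen et al., CVPR 2019)."""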
import os
import argparse
import random
import warnings

import numpy as np
import pandas as pd
import paddle
import paddle.distributed as dist

from models.dcl import DCLNet as MainModel
from utils.dataset import MyDataSet, collate_fn4train, collate_fn4test
from utils.train_model import train
from utils.transforms import load_data_transformers

paddle.device.set_device('gpu')
warnings.filterwarnings('ignore')


def set_random_seed(seed):
    """Seed Paddle, NumPy, and Python's built-in RNG for reproducibility."""
    paddle.seed(seed)
    np.random.seed(seed)
    random.seed(seed)


class LoadConfig:
    def __init__(self, args):
        self.describe = args.describe
        # pretrained backbone weights
        self.pretrained_model = '/data/zhangzichao/models/resnet50.pdparams'
        # dataset-specific paths and number of classes
        if args.dataset == 'CUB':
            self.dataset = args.dataset
            self.rawdata_root = '/data/zhangzichao/datasets/CUB/images'
            self.anno_root = './datasets/CUB'
            self.numcls = 200
        elif args.dataset == 'STCAR':
            self.dataset = args.dataset
            self.rawdata_root = '/data/zhangzichao/datasets/StanfordCars/car_ims/'
            self.anno_root = './datasets/STCAR'
            self.numcls = 196
        elif args.dataset == 'AIR':
            self.dataset = args.dataset
            self.rawdata_root = '/data/zhangzichao/datasets/fgvc-aircraft-2013b/data/images/'
            self.anno_root = './datasets/AIR'
            self.numcls = 100
        elif args.dataset == 'CUB_TINY':
            self.dataset = args.dataset
            self.rawdata_root = './datasets/CUB_TINY'
            self.anno_root = './datasets/CUB_TINY'
            self.numcls = 4
        else:
            raise Exception('dataset not defined')
        # annotation files: one "ImageName,label" pair per line
        self.train_anno = pd.read_csv(os.path.join(self.anno_root, 'train.txt'), sep=",", header=None,
                                      names=['ImageName', 'label'])
        self.test_anno = pd.read_csv(os.path.join(self.anno_root, 'test.txt'), sep=",", header=None,
                                     names=['ImageName', 'label'])
        # output directories for checkpoints and logs
        self.save_dir = f'./outputs/{args.dataset}/checkpoints'
        os.makedirs(self.save_dir, exist_ok=True)
        self.log_folder = f'./outputs/{args.dataset}/logs'
        os.makedirs(self.log_folder, exist_ok=True)


def parse_args():
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset', default='CUB', type=str)
    parser.add_argument('--backbone', dest='backbone', default='resnet50', type=str)
    parser.add_argument('--epoch', dest='epoch', default=360, type=int)
    parser.add_argument('--lr_step', dest='lr_step', default=60, type=int)
    parser.add_argument('--tb', dest='train_batch', default=16, type=int)
    parser.add_argument('--vb', dest='val_batch', default=16, type=int)
    parser.add_argument('--tnw', dest='train_num_workers', default=16, type=int)
    parser.add_argument('--vnw', dest='val_num_workers', default=16, type=int)
    parser.add_argument('--lr', dest='base_lr', default=0.0008, type=float)
    parser.add_argument('--start_epoch', dest='start_epoch', default=0, type=int)
    parser.add_argument('--detail', dest='describe', default='dcl_cub', type=str)
    parser.add_argument('--size', dest='resize_resolution', default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution', default=448, type=int)
    parser.add_argument('--swap_num', dest='swap_num', default=7, type=int,
                        help='patches per image side for the jigsaw swap (N gives an NxN grid)')
    parser.add_argument('--save_model_name', dest='save_model_name', default=None, type=str,
                        help='model weight name')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    set_random_seed(seed=2022)
    print(args, flush=True)
    Config = LoadConfig(args)
    # build datasets and data loaders; the test set uses the "None" transformers
    # (no augmentation or patch swapping)
    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution,
                                          (args.swap_num, args.swap_num))
    train_set = MyDataSet(Config=Config, anno=Config.train_anno, common_aug=transformers["common_aug"],
                          swap=transformers["swap"], totensor=transformers["train_totensor"], train=True)
    test_set = MyDataSet(Config=Config, anno=Config.test_anno, common_aug=transformers["None"],
                         swap=transformers["None"], totensor=transformers["test_totensor"], test=True)
    dataloader = {}
    dataloader['train'] = paddle.io.DataLoader(train_set, batch_size=args.train_batch, shuffle=True,
                                               num_workers=args.train_num_workers, collate_fn=collate_fn4train,
                                               drop_last=False)
    setattr(dataloader['train'], 'total_item_len', len(train_set))
    dataloader['test'] = paddle.io.DataLoader(test_set, batch_size=args.val_batch, shuffle=False,
                                              num_workers=args.val_num_workers, collate_fn=collate_fn4test,
                                              drop_last=False)
    setattr(dataloader['test'], 'total_item_len', len(test_set))
    setattr(dataloader['test'], 'num_cls', Config.numcls)
    # build the model, initialize distributed training, and define the optimizer
    model = MainModel(Config)
    dist.init_parallel_env()
    model = paddle.DataParallel(model)
    scheduler = paddle.optimizer.lr.StepDecay(learning_rate=args.base_lr, step_size=args.lr_step)
    optimizer = paddle.optimizer.Momentum(learning_rate=scheduler, parameters=model.parameters())
    # start training
    train(Config,
          model,
          epoch_num=args.epoch,
          start_epoch=args.start_epoch,
          optimizer=optimizer,
          scheduler=scheduler,
          data_loader=dataloader,
          date_suffix=args.save_model_name)


if __name__ == '__main__':
    main()
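

# Example invocations (a sketch; the dataset and weight paths hard-coded in
# LoadConfig are machine-specific and assumed to exist on your system):
#
#   # single-GPU run
#   python main.py --data CUB --epoch 360 --tb 16 --lr 0.0008
#
#   # multi-GPU run via Paddle's launcher, matching the
#   # dist.init_parallel_env() / DataParallel setup above
#   python -m paddle.distributed.launch --gpus "0,1" main.py --data CUB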