# train_main.py
# -*- coding: utf-8 -*-
import sys
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse as sp
import argparse
import os
import random
from tqdm import *
import logging
from utils.utils import get_dataset, AverageMeter, set_log
from utils.evaluate import accuracy, eval
from utils.config import DefaultConfig
from models.pad_model import PA_Detector, Face_Related_Work, Cross_Modal_Adapter
from models.networks import PAD_Classifier
# Global configuration object shared by every function in this script.
config = DefaultConfig()
# Command-line options: which dataset pair to train/test on, which
# face-related downstream task to attach, and the adapter's graph topology.
parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument("--train_data", type=str, default='om', help='Training data (om/ci)')
parser.add_argument("--test_data", type=str, default='ci', help='Testing data (ci/om)')
parser.add_argument("--downstream", type=str, default='FR', help='FR/FE/FA')
parser.add_argument("--graph_type",type=str, default='direct', help='direct/dense')
args = parser.parse_args()
# Per-downstream-task log directory under the configured root.
log_dir = config.root + 'face_log/'+ args.downstream+'/'
logger = set_log(log_dir, args.train_data, args.test_data)
# Record the full training protocol up front so every run is reproducible
# from its log file alone.
logger.info("Log path:" + log_dir)
logger.info("Training Protocol")
logger.info("Epoch Total number:{}".format(config.Epoch_num))
logger.info("Batch Size is {:^.2f}".format(config.batch_size))
logger.info("Shuffle Data for Training is {}".format(config.shuffle_train))
logger.info("Training set is {}".format(config.dataset[args.train_data]))
logger.info("Test set is {}".format(config.dataset[args.test_data]))
logger.info("Face related work is {}".format(config.face_related_work[args.downstream]))
logger.info("Graph type is {}".format(config.graph[args.graph_type]))
logger.info("savedir:{}".format(config.savedir))
def load_net_datasets():
    """Build the PAD network and the train/test data loaders.

    Assembles the PA detector, the downstream face-task branch, and the
    cross-modal adapter into a single `PAD_Classifier`, moves it to GPU,
    and constructs the data loaders from the configured datasets.

    Returns:
        (net, train_data_loader, test_data_loader)
    """
    net_pad = PA_Detector()
    net_downstream = Face_Related_Work(config.face_related_work[args.downstream])
    net_adapter = Cross_Modal_Adapter(config.graph[args.graph_type], config.batch_size)
    net = PAD_Classifier(net_pad, net_downstream, net_adapter, args.downstream)
    # Fix: original referenced the undefined name `pargs` here, which raised
    # a NameError as soon as this function ran; the test split comes from `args`.
    train_data_loader, test_data_loader = get_dataset(
        './labels',
        config.dataset[args.train_data], config.sample_frame,
        config.dataset[args.test_data], config.sample_frame,
        config.batch_size)
    net.cuda()
    return net, train_data_loader, test_data_loader
def train():
    """Train the PAD classifier and periodically evaluate it.

    Runs `config.Epoch_num` epochs of cross-entropy training, evaluates on
    the test loader every `config.eval_epoch` epochs, and saves a checkpoint
    whenever a new best HTER is achieved. Model selection is by lowest HTER.
    """
    net, train_loader, test_loader = load_net_datasets()
    # Best-so-far validation metrics (HTER drives checkpointing).
    best_model_HTER = 1.0
    best_model_AUC = 0.0
    best_model_TDR = 0.0
    # loss,top1 accuracy, hter, auc, tdr
    valid_args = [np.inf, 0, 0, 0, 0]
    logger.info('**************************** start training target model! ******************************\n')
    logger.info(
        '---------|-------------- VALID ---------------|---- Training ----|-------- Current Best -------|\n')
    logger.info(
        '  epoch  |   loss   HTER    AUC      TDR      |   loss   top-1   |   HTER    AUC      TDR      |\n')
    logger.info(
        '-----------------------------------------------------------------------------------------------|\n')
    loss_classifier = AverageMeter()
    classifer_top1 = AverageMeter()
    criterion = nn.CrossEntropyLoss().cuda()
    # Build the optimizer; fail fast on an unknown setting instead of the
    # original behavior of hitting a NameError at optimizer.step() below.
    if config.opt == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=config.learning_rate,
                                     weight_decay=config.weight_decay, amsgrad=True)
    elif config.opt == 'SGD':
        optimizer = torch.optim.SGD(net.parameters(), lr=config.learning_rate,
                                    momentum=config.momentum, weight_decay=config.weight_decay)
    else:
        raise ValueError("Unsupported optimizer: {}".format(config.opt))
    for e in range(config.Epoch_num):
        t = tqdm(train_loader)
        t.set_description("Epoch [{}/{}]".format(e + 1, config.Epoch_num))
        for imgs, labels, _ in t:
            imgs = imgs.cuda()
            labels = labels.cuda().view(-1)
            net.train()
            # Clear gradients before backward so a step never mixes batches
            # (original zeroed after step, leaving the very first backward
            # to accumulate onto whatever grads the net started with).
            optimizer.zero_grad()
            out = net(imgs)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            loss_classifier.update(loss.item())
            acc = accuracy(out.narrow(0, 0, imgs.size(0)), labels, topk=(1,))
            classifer_top1.update(acc[0])
        if (e + 1) % config.eval_epoch == 0:
            # valid_args layout: [loss, HTER, AUC, TDR] from eval().
            valid_args = eval(test_loader, net)
            is_best = valid_args[1] <= best_model_HTER
            if is_best:
                best_model_HTER = valid_args[1]
                best_model_AUC = valid_args[2]
                best_model_TDR = valid_args[3]
            logger.info(
                '  %3d   |  %5.3f  %6.3f  %6.3f  %6.3f  |  %6.3f  %6.3f  |  %6.3f  %6.3f  %6.3f  |'
                % (
                    e+1,
                    valid_args[0], valid_args[1] * 100, valid_args[2] * 100, valid_args[3] *100,
                    loss_classifier.avg, classifer_top1.avg,
                    float(best_model_HTER * 100), float(best_model_AUC * 100), float(best_model_TDR * 100)))
            if is_best:
                # Checkpoint the new best model, tagging the file with the
                # epoch and the HTER that earned it.
                save_dir = config.savedir+args.downstream+'_'+args.graph_type+'_Graph'+'/'
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                save_path = save_dir + 'Train_' + args.train_data + '_test_' + args.test_data + '_' + str(e+1)+'_HTER_'+str(round(best_model_HTER * 100, 3)) + '.pth'
                torch.save(net.state_dict(), save_path)
if __name__ == "__main__":
train()