-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy patheval.py
72 lines (59 loc) · 2.23 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import sys
import logging
import warnings
from options.train_options1 import TrainOptions
from models import create_model
from cda.config import cfg as detcfg
from cda.utils.logger import setup_logger
from cda.utils import dist_util, mkdir
warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
basedir = os.path.dirname(os.path.realpath(__file__))
if __name__ == '__main__':
    # Evaluation entry point. Parses the training options, overrides them
    # with fixed evaluation settings, fills in the `detcfg` config node for
    # the selected run, builds the model and runs `evaluate_adv` on it.
    command_line = 'python ' + ' '.join(sys.argv)

    opt = TrainOptions().parse()
    opt.run = opt.train_config

    # Checkpoint directory of the run being evaluated. Both the run's YAML
    # config and OUTPUT_DIR are resolved against the script directory so
    # they always point at the same place, whatever the current working
    # directory is when the script is launched.
    run_name = str(opt.train_config)
    run_dir = os.path.join(basedir, "checkpoints", run_name)

    # Some default parameters for evaluating.
    opt.eps = 10  # on the 0-255 pixel scale; divided by 255 further down
    opt.eval_num_classes = 1000
    opt.eval_dataset = "imagenet5k_val"
    opt.gen_dropout = 0.0
    opt.loss_fn, opt.softmax2D = None, False
    opt.act_layer_mean = False
    opt.data_dim = 'high'
    opt.isTrain = False
    opt.classifier_weights = ''
    opt.weightfile = ''

    # Start from the config saved with the run, then apply the evaluation
    # overrides on top of it.
    detcfg.merge_from_file(os.path.join(run_dir, '{}.yaml'.format(run_name)))
    detcfg.TEST.BATCH_SIZE = 5
    detcfg.OUTPUT_DIR = run_dir
    detcfg.EVAL_MODEL.META_ARCHITECTURE = opt.eval_model
    detcfg.EVAL_MODEL.NUM_CLASSES = opt.eval_num_classes
    detcfg.DATASETS.TEST = (opt.eval_dataset,)
    detcfg.MODEL.BACKBONE.PRETRAINED = True
    detcfg.MODEL.NUM_CLASSES = 1000

    # We evaluate on the resolution on which the GAN is trained against:
    # 299 for runs whose name contains 'inception', 224 otherwise.
    if 'inception' in opt.train_config:
        detcfg.INPUT.RESIZE_SIZE = 300
        detcfg.INPUT.IMAGE_SIZE = 299
        opt.train_getG_299 = True
    else:
        detcfg.INPUT.RESIZE_SIZE = 256
        detcfg.INPUT.IMAGE_SIZE = 224
        opt.train_getG_299 = False

    # No further config edits are made after this point.
    detcfg.freeze()

    if detcfg.OUTPUT_DIR:
        mkdir(detcfg.OUTPUT_DIR)

    opt.detcfg = detcfg
    opt.eps = opt.eps / 255.0
    opt.perturbmode = detcfg.NETG.PERTURB_MODE

    # Set up the "CDA" logger (writing log_eval.txt under OUTPUT_DIR) and
    # record the exact invocation before switching to the inference logger.
    root_logger = setup_logger("CDA", dist_util.get_rank(), opt.detcfg.OUTPUT_DIR, logger_name='log_eval.txt')
    root_logger.info("Command line: {}".format(command_line))
    logger = logging.getLogger("CDA.inference")
    # Report epsilon back on the 0-255 scale it was specified in.
    logger.info("Epsilon: {:.1f}".format(opt.eps * 255))

    model = create_model(opt)
    model.setup(opt)
    # NOTE(review): the original author left the plain `evaluate` call
    # disabled; only `evaluate_adv` is run.
    #model.evaluate(0, save_feats=False)
    model.evaluate_adv(0, save_feats=False)