-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathargument.py
56 lines (52 loc) · 4.63 KB
/
argument.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from argparse import ArgumentParser
def parse_args(default=False, argv=None):
    """Build the CCAL command-line parser and parse arguments.

    Args:
        default: If True, ignore any real command line (and ``argv``) and
            return a namespace populated only with each option's default.
        argv: Optional explicit list of argument strings to parse instead of
            ``sys.argv[1:]``. Ignored when ``default`` is True.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = ArgumentParser(description='Pytorch implementation of CCAL')

    ##### Data / Model / Checkpoint Configurations #####
    parser.add_argument('--dataset', help='Dataset',
                        choices=['cifar10', 'cifar100'], type=str)
    parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class',
                        default=None, type=int)
    parser.add_argument('--model', help='Model',
                        choices=['resnet18', 'resnet18_imagenet'], type=str)
    parser.add_argument('--mode', help='Training mode', default='simclr', type=str)
    parser.add_argument('--simclr_dim', help='Dimension of simclr layer', default=128, type=int)
    parser.add_argument('--target', help='The number of queried samples in each class in each AL cycle',
                        default=1500, type=int)
    parser.add_argument('--shift_trans_type', help='shifting transformation type', default='none',
                        choices=['rotation', 'cutperm', 'none'], type=str)
    parser.add_argument("--local_rank", type=int, default=0,
                        help='Local rank for distributed learning')
    parser.add_argument('--resume_path', help='Path to the resume checkpoint',
                        default=None, type=str)
    parser.add_argument('--load_path', help='Path to the loading checkpoint',
                        default='/home/CIFAR10_20/result_model', type=str)
    parser.add_argument('--load_feature_path', help='Path to the loading checkpoint',
                        default='/home/CIFAR10_20/result_model/distinctive', type=str)
    # The original flag name is misspelled ("senmatic"); it is kept for
    # backward compatibility, and the correctly spelled alias writes to the
    # same attribute so callers reading args.load_senmatic_path are unaffected.
    parser.add_argument('--load_senmatic_path', '--load_semantic_path',
                        dest='load_senmatic_path',
                        help='Path to the loading checkpoint',
                        default='/home/CIFAR10_20/result_model/semantic', type=str)
    parser.add_argument('--logdir', help='Path to the log directory',
                        default='/home/CIFAR10_20/result_model/distinctive', type=str)
    parser.add_argument("--no_strict", help='Do not strictly load state_dicts',
                        action='store_true')
    parser.add_argument('--suffix', help='Suffix for the log dir', default=None, type=str)
    parser.add_argument('--error_step', help='Epoch steps to compute errors', default=5, type=int)
    parser.add_argument('--save_step', help='Epoch steps to save models', default=10, type=int)

    ##### Training Configurations #####
    parser.add_argument('--epochs', help='Epochs', default=1500, type=int)
    # NOTE(review): help text is just the option name; meaning of the 0.8
    # default is not visible in this file — confirm with the training code.
    parser.add_argument('--mismatch', help='mismatch', default=0.8, type=float)
    parser.add_argument('--optimizer', help='Optimizer',
                        choices=['sgd', 'adam', 'lars'], default='sgd', type=str)
    parser.add_argument('--lr_scheduler', help='Learning rate scheduler',
                        choices=['step_decay', 'cosine'], default='cosine', type=str)
    parser.add_argument('--warmup', help='Warm-up epochs', default=10, type=int)
    parser.add_argument('--lr_init', help='Initial learning rate', default=1e-1, type=float)
    parser.add_argument('--weight_decay', help='Weight decay', default=1e-6, type=float)
    parser.add_argument('--batch_size', help='Batch size', default=128, type=int)
    parser.add_argument('--test_batch_size', help='Batch size for test loader',
                        default=100, type=int)

    ##### Objective Configurations #####
    parser.add_argument('--sim_lambda', help='Weight for SimCLR loss', default=1.0, type=float)
    parser.add_argument('--temperature', help='Temperature for similarity',
                        default=0.5, type=float)

    ##### Evaluation Configurations #####
    parser.add_argument("--ood_dataset", help='Datasets for OOD detection',
                        default=None, nargs="*", type=str)
    parser.add_argument("--ood_score", help='score function for OOD detection',
                        default=['norm_mean'], nargs="+", type=str)
    parser.add_argument("--ood_layer", help='layer for OOD scores',
                        choices=['penultimate', 'simclr', 'shift'],
                        default=['simclr', 'shift'], nargs="+", type=str)
    parser.add_argument("--ood_samples", help='number of samples to compute OOD score',
                        default=1, type=int)
    parser.add_argument("--ood_batch_size", help='batch size to compute OOD score',
                        default=100, type=int)
    parser.add_argument("--resize_factor", help='resize scale is sampled from [resize_factor, 1.0]',
                        default=0.08, type=float)
    parser.add_argument("--resize_fix",
                        help='resize scale is fixed to resize_factor (not (resize_factor, 1.0])',
                        action='store_true')
    parser.add_argument("--print_score", help='print quantiles of ood score',
                        action='store_true')
    parser.add_argument("--save_score", help='save ood score for plotting histogram',
                        action='store_true')
    # NOTE(review): the help strings of --k, --t and --split below appear to be
    # copy-pasted from --lr_init / --ood_batch_size; their real meaning is not
    # visible in this file. Confirm and replace the help text.
    parser.add_argument('--k', help='Initial learning rate', default=100.0, type=float)
    parser.add_argument('--t', help='Initial learning rate', default=0.9, type=float)
    parser.add_argument("--split", help='batch size to compute OOD score', default=5, type=int)

    if default:
        # An empty argument list yields a namespace of pure defaults.
        argv = []
    return parser.parse_args(argv)