-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathtest.py
67 lines (51 loc) · 2.27 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import argparse
import torch
import mixed_precision
from stats import AverageMeterSet
from datasets import Dataset, build_dataset, get_dataset, get_encoder_size
from model import Model
from checkpoint import Checkpointer
from utils import test_model
# Command-line interface for evaluating a saved checkpoint. Arguments are
# parsed at import time so that main() can read them from the module-level
# `args` namespace.
parser = argparse.ArgumentParser(description='Infomax Representations - Testing Script')
# general evaluation parameters
parser.add_argument('checkpoint_path', type=str,
                    help='path from which to load checkpoint')
parser.add_argument('--dataset', type=str, default='STL10',
                    help='name of the dataset to evaluate on (default: STL10)')
parser.add_argument('--batch_size', type=int, default=200,
                    help='input batch size (default: 200)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Enables automatic mixed precision')
parser.add_argument('--input_dir', type=str, default='/mnt/imagenet',
                    help="Input directory for the dataset. Not needed for C10,"
                         " C100 or STL10 as the data will be automatically downloaded.")
# NOTE(review): --run_name is accepted for CLI compatibility but is not read
# by main() in this script.
parser.add_argument('--run_name', type=str, default='default_run',
                    help='name to use for the tensorboard summary for this run')
args = parser.parse_args()
def test(model, test_loader, device, stats):
    """Evaluate ``model`` on ``test_loader`` by delegating to ``utils.test_model``.

    ``device`` and ``stats`` are forwarded unchanged; ``stats`` is presumably
    populated by the helper, since main() reads it afterwards. This wrapper
    itself always returns None.
    """
    test_model(model, test_loader, device, stats)
    return None
def main():
    """Restore a checkpointed model, evaluate it on the test split and print stats.

    All configuration comes from the module-level ``args`` namespace.
    """
    # Mixed precision has to be switched on before the model is passed
    # through mixed_precision.initialize further down.
    if args.amp:
        mixed_precision.enable_mixed_precision()

    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # build_dataset yields three values; only the middle one (the test
    # loader) is used here.
    (_, eval_loader, _) = build_dataset(dataset=get_dataset(args.dataset),
                                        batch_size=args.batch_size,
                                        input_dir=args.input_dir)

    device = torch.device('cuda')
    ckpt = Checkpointer()
    restored = ckpt.restore_model_from_checkpoint(args.checkpoint_path)
    # initialize() is given no optimizer, so its second return value is unused.
    model, _ = mixed_precision.initialize(restored.to(device), None)

    meters = AverageMeterSet()
    test(model, eval_loader, device, meters)
    print(meters.pretty_string(ignore=model.tasks))
if __name__ == "__main__":
    # Echo the parsed command-line configuration before running the evaluation.
    print(args)
    main()