-
Notifications
You must be signed in to change notification settings - Fork 37
/
eval.py
107 lines (97 loc) · 3.08 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import torch as th
from torch.utils.data import DataLoader
from args import get_args
from model import Net
from metrics import compute_metrics, print_computed_metrics
from gensim.models.keyedvectors import KeyedVectors
import pickle
import glob
from lsmdc_dataloader import LSMDC_DataLoader
from msrvtt_dataloader import MSRVTT_DataLoader
from youcook_dataloader import Youcook_DataLoader
# ---------------------------------------------------------------------------
# Script setup: parse the options, load the word vectors, build a loader for
# every dataset selected on the command line and instantiate the network.
# ---------------------------------------------------------------------------
args = get_args()
if args.verbose:
    print(args)

# Raise explicitly rather than assert: assert statements are stripped under
# `python -O`, which would let an empty path through and leave the checkpoint
# glob below with nothing to match.
if args.pretrain_path == '':
    raise ValueError('Need to specify pretrain_path argument')

print('Loading word vectors: {}'.format(args.word2vec_path))
we = KeyedVectors.load_word2vec_format(args.word2vec_path, binary=True)
print('done')

# A dataset/loader pair is created only for the benchmarks that were
# requested; the checkpoint loop tests the same flags before using them.
if args.eval_youcook:
    dataset_val = Youcook_DataLoader(
        data=args.youcook_val_path,
        we=we,
        max_words=args.max_words,
        we_dim=args.we_dim,
    )
    dataloader_val = DataLoader(
        dataset_val,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
    )

if args.eval_lsmdc:
    dataset_lsmdc = LSMDC_DataLoader(
        csv_path=args.lsmdc_test_csv_path,
        features_path=args.lsmdc_test_features_path,
        we=we,
        max_words=args.max_words,
        we_dim=args.we_dim,
    )
    dataloader_lsmdc = DataLoader(
        dataset_lsmdc,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
    )

if args.eval_msrvtt:
    msrvtt_testset = MSRVTT_DataLoader(
        csv_path=args.msrvtt_test_csv_path,
        features_path=args.msrvtt_test_features_path,
        we=we,
        max_words=args.max_words,
        we_dim=args.we_dim,
    )
    # Unlike the other loaders, the batch size here is fixed at 3000 instead
    # of args.batch_size_val.
    # NOTE(review): presumably chosen so the whole test split fits in a
    # single batch - confirm against the size of the test CSV.
    dataloader_msrvtt = DataLoader(
        msrvtt_testset,
        batch_size=3000,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )

net = Net(
    video_dim=args.feature_dim,
    embd_dim=args.embd_dim,
    we_dim=args.we_dim,
    max_words=args.max_words,
)
# Weights are loaded later, once per checkpoint, by the checkpoint loop.
net.eval()
net.cuda()

if args.verbose:
    print('Starting evaluation loop ...')
def Eval_retrieval(model, eval_dataloader, dataset_name, device=None):
    """Evaluate text-video retrieval and print the metrics of every batch.

    The model is switched to evaluation mode and each batch is scored with
    gradient tracking disabled.  ``compute_metrics`` and
    ``print_computed_metrics`` are invoked once per batch, so a loader that
    yields several batches produces several independent reports.

    Args:
        model: object callable as ``model(video, text)`` that also provides
            an ``eval()`` method.
        eval_dataloader: iterable of dict batches holding tensors under the
            ``'text'`` and ``'video'`` keys.
        dataset_name: label inserted in the progress message.
        device: optional destination for the batch tensors (anything
            accepted by ``Tensor.to``).  ``None`` keeps the historical
            behaviour of calling ``.cuda()`` on each tensor.

    Returns:
        None.  Results are only printed.
    """
    def _place(tensor):
        # Default path is unchanged; an explicit device allows e.g. CPU runs.
        return tensor.cuda() if device is None else tensor.to(device)

    model.eval()
    print('Evaluating Text-Video retrieval on {} data'.format(dataset_name))
    with th.no_grad():
        for data in eval_dataloader:
            text = _place(data['text'])
            video = _place(data['video'])
            # NOTE(review): presumably a text/video similarity matrix -
            # confirm against the model's forward pass.
            scores = model(video, text)
            # The metric helpers receive a NumPy array on the host.
            scores = scores.cpu().detach().numpy()
            print_computed_metrics(compute_metrics(scores))
# pretrain_path is treated as a glob pattern, so one run can evaluate several
# checkpoints.  glob returns matches in arbitrary order; sorting makes the
# evaluation order reproducible.
all_checkpoints = sorted(glob.glob(args.pretrain_path))
if not all_checkpoints:
    # Without this message a mistyped pattern would end the script silently.
    print('No checkpoint matches: {}'.format(args.pretrain_path))

for c in all_checkpoints:
    print('Eval checkpoint: {}'.format(c))
    print('Loading checkpoint: {}'.format(c))
    net.load_checkpoint(c)
    # Each loader exists only when its flag was set during setup, hence the
    # same flags guard the calls here.
    if args.eval_youcook:
        Eval_retrieval(net, dataloader_val, 'YouCook2')
    if args.eval_msrvtt:
        Eval_retrieval(net, dataloader_msrvtt, 'MSR-VTT')
    if args.eval_lsmdc:
        Eval_retrieval(net, dataloader_lsmdc, 'LSMDC')