-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy patheval.py
70 lines (58 loc) · 2.39 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from model import EventDetector
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import GolfDB, ToTensor, Normalize
import torch.nn.functional as F
import numpy as np
from util import correct_preds
def eval(model, split, seq_length, n_cpu, disp):
    """Evaluate ``model`` on one validation split and return the mean score.

    Each sample is loaded whole (``batch_size=1``) and pushed through the
    model in consecutive chunks of at most ``seq_length`` along dimension 1
    of ``sample['images']``, because full samples do not fit into GPU memory.
    The per-chunk softmax outputs are stitched back together along axis 0
    and scored against the sample's labels with ``correct_preds``.

    Args:
        model: Callable applied to each image chunk. Inputs are moved to the
            GPU with ``.cuda()``, so the model is expected to already live on
            a CUDA device and to be in evaluation mode.
        split: Value formatted into ``'data/val_split_{}.pkl'`` to select the
            validation data file.
        seq_length: Maximum chunk length per forward pass; also forwarded to
            ``GolfDB``. Must be a positive integer.
        n_cpu: Number of ``DataLoader`` worker processes.
        disp: When truthy, print the sample index and its score.

    Returns:
        ``np.mean`` over the per-sample values taken from the fifth element
        returned by ``correct_preds`` (named PCE in this file; presumably
        "percentage of correct events" -- confirm against ``util``).

    Raises:
        ValueError: If ``seq_length`` is zero (from ``range``) or a sample
            yields no chunks (from ``np.concatenate``).
    """
    dataset = GolfDB(data_file='data/val_split_{}.pkl'.format(split),
                     vid_dir='data/videos_160/',
                     seq_length=seq_length,
                     transform=transforms.Compose([
                         ToTensor(),
                         Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                     ]),
                     train=False)

    data_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=n_cpu,
                             drop_last=False)

    correct = []

    for i, sample in enumerate(data_loader):
        images, labels = sample['images'], sample['labels']
        n_frames = images.shape[1]

        # full samples do not fit into GPU memory so evaluate sample in
        # 'seq_length' batches. Slicing clamps at the end of the tensor, so
        # the final (possibly shorter) chunk needs no special case.
        chunk_probs = []
        # Inference only: disabling autograd avoids building a graph for
        # every forward pass, which would otherwise waste GPU memory.
        with torch.no_grad():
            for start in range(0, n_frames, seq_length):
                image_batch = images[:, start:start + seq_length, :, :, :]
                logits = model(image_batch.cuda())
                chunk_probs.append(F.softmax(logits, dim=1).cpu().numpy())

        # One concatenation per sample instead of re-copying the growing
        # array on every chunk. The list is rebuilt per sample, so results
        # from a previous sample can never leak into this one.
        probs = np.concatenate(chunk_probs, axis=0)

        _, _, _, _, c = correct_preds(probs, labels.squeeze())
        if disp:
            print(i, c)
        correct.append(c)

    PCE = np.mean(correct)
    return PCE
if __name__ == '__main__':
    # Settings forwarded to eval(): validation split selector, chunk length
    # per forward pass, and number of DataLoader workers.
    split, seq_length, n_cpu = 1, 64, 6

    # Build the network with the same configuration used for the checkpoint
    # loaded below.
    model = EventDetector(
        pretrain=True,
        width_mult=1.,
        lstm_layers=1,
        lstm_hidden=256,
        bidirectional=True,
        dropout=False,
    )

    # Restore the saved weights from the checkpoint's 'model_state_dict' entry.
    checkpoint = torch.load('models/swingnet_1800.pth.tar')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Move to the GPU and switch the module to evaluation mode before scoring.
    model.cuda()
    model.eval()

    average_pce = eval(model, split, seq_length, n_cpu, disp=True)
    print('Average PCE: {}'.format(average_pce))