-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
32 lines (24 loc) · 849 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import warnings

import torch

from dataset import *
import evaluation
from bmvos import BMVOS
# Ignore warnings of every category for the rest of the process (presumably
# to keep the evaluation output uncluttered).
warnings.simplefilter('ignore')
def test(model, db_root='../DB', output_root='outputs'):
    """Run the evaluator on every configured benchmark split.

    For each split, an ``evaluation.Evaluator`` is built around the dataset
    and its ``evaluate`` method is called with ``model`` and the directory
    ``<output_root>/<split name>``.

    Args:
        model: Network to evaluate; passed unchanged to
            ``Evaluator.evaluate``.
        db_root: Directory holding the benchmark datasets. Defaults to
            ``'../DB'``, the path this function previously hard-coded.
        output_root: Directory under which per-split output paths are
            formed. Defaults to ``'outputs'``, the previous hard-coded value.
    """
    davis_root = os.path.join(db_root, 'DAVIS')
    # Keys double as the per-split output folder names.
    datasets = {
        'DAVIS16_val': TestDAVIS(davis_root, '2016', 'val'),
        'DAVIS17_val': TestDAVIS(davis_root, '2017', 'val'),
        'DAVIS17_test-dev': TestDAVIS(davis_root, '2017', 'test-dev'),
        # Disabled split; re-enable if the YouTube-VOS 2018 data is present:
        # 'YTVOS18_val': TestYTVOS(os.path.join(db_root, 'YTVOS18'), 'val'),
    }
    for key, dataset in datasets.items():
        evaluator = evaluation.Evaluator(dataset)
        evaluator.evaluate(model, os.path.join(output_root, key))
if __name__ == '__main__':
    # Make CUDA device 0 the current device before anything else runs.
    torch.cuda.set_device(0)

    # Construct the network and put it in eval mode
    # (presumably a torch nn.Module — TODO confirm in bmvos).
    model = BMVOS().eval()

    # Restore the DAVIS checkpoint. Storages are mapped to CPU while loading,
    # so the file can be read regardless of the device it was saved from.
    state_dict = torch.load('weights/BMVOS_davis.pth', map_location='cpu')
    model.load_state_dict(state_dict)

    # Evaluate every split with autograd disabled.
    with torch.no_grad():
        test(model)