# eval_tv_synth.py
#%%
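# Folders for the synthetic event data, the saved model checkpoints, and the TV evaluation output.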
ROOT_DATA_FOLDER = 'data/synth'
ROOT_SAVE_FOLDER = 'save'
ROOT_TV_FOLDER = 'eval/tv'
#%%
# Imports
import os
import csv
import torch
import numpy as np
import itertools
from models.traditional.self_correcting import SelfCorrecting
from models.traditional.exp_hawkes import ExpHawkes
from models.traditional.sinexp_hawkes import SinExpHawkes
from models.basis_sum.exp_basis import ExpSum
from models.basis_sum.powerlaw_basis import PowerlawSum
from models.basis_sum.relu_basis import ReLUSum
from models.basis_sum.sin_basis import SinSum
from models.basis_sum.sigmoid_basis import SigmoidSum
from models.basis_sum.mixed_basis import MixedSum
from models.data import GeneralDataset, collate_no_marks
from plot.utils import total_variation
#%%
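# Architecture hyperparameters for the fitted basis-sum models and the sampling frequency passed to total_variation.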
RNN_SIZE = 48
RNN_LAYER = 1
NUM_BASIS = 64
NUM_OUTPUT_LAYER = 2
SAMPLE_FREQ = 10
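# True generating models for each synthetic dataset (used as the reference intensity).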
synth_data_list = [
    ('selfcorrecting', lambda x: SelfCorrecting(1, 1)),
    ('exphawkes', lambda x: ExpHawkes(0.5, 0.8, 1)),
    ('sinexphawkes', lambda x: SinExpHawkes(0.5, 0.4, 2, 1.0)),
]
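# Constructors for the fitted basis-sum intensity models to evaluate.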
make_model = [
    ('expsum', lambda x: ExpSum(RNN_SIZE, RNN_LAYER, NUM_BASIS)),
    ('powerlaw', lambda x: PowerlawSum(RNN_SIZE, RNN_LAYER, NUM_BASIS)),
    ('relusum', lambda x: ReLUSum(RNN_SIZE, RNN_LAYER, NUM_BASIS)),
    ('sinsum', lambda x: SinSum(RNN_SIZE, RNN_LAYER, NUM_BASIS)),
    ('sigmoidsum', lambda x: SigmoidSum(RNN_SIZE, RNN_LAYER, NUM_BASIS)),
    ('mixedsum', lambda x: MixedSum(RNN_SIZE, RNN_LAYER, NUM_BASIS // 2)),
]
#%%
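# Evaluate every (dataset, model) pair: load the fitted checkpoint, rebuild the test split,
# and record the total-variation distance to the true model for each test sequence.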
data_model_iter = itertools.product(synth_data_list, make_model)
for dataset_pair, model_pair in data_model_iter:
    model_name, model_gen = model_pair
    dataset_name, dataset_gen = dataset_pair
    fitted_model = model_gen(True)
    true_model = dataset_gen(True)
    fitted_name = dataset_name + '_' + model_name
    true_name = dataset_name + '_' + dataset_name
    print('Start Evaluating:', fitted_name)
    fitted_folder = os.path.join(ROOT_SAVE_FOLDER, fitted_name)
    fitted_model.load(fitted_folder)
    data = []
    save_file = os.path.join(ROOT_DATA_FOLDER, dataset_name + '.csv')
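    # Load the raw event sequences; each CSV row is one sequence of event times.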
    with open(save_file, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            data.append(torch.Tensor(list(map(float, row))))
    # Preprocess
    whole_dataset = GeneralDataset(data)
    whole_dataset.log_transform_rnn()
    train_dataset, validation_dataset, test_dataset = whole_dataset.train_val_test_split(0.6, 0.2, 0.2, seed=1)
    rnn_mean, rnn_std = train_dataset.rnn_statistics()
    test_dataset.normalise_data(rnn_mean=rnn_mean, rnn_std=rnn_std)
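    # Total-variation distance between the fitted and true intensities on each test sequence.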
    tvs = []
    for d in test_dataset:
        events = collate_no_marks([d])
        cur_tv = total_variation(fitted_model, true_model, events, SAMPLE_FREQ)
        tvs.append(cur_tv.item())
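    # Save the per-sequence TV values as a single CSV row under eval/tv/.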
    tv_save_file = os.path.join(ROOT_TV_FOLDER, fitted_name + '.csv')
    with open(tv_save_file, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(tvs)
    print('Average TV:', np.mean(tvs))
    print('Finish Evaluating:', fitted_name)
# %%