-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathload_data.py
104 lines (95 loc) · 4.51 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import os
from tqdm import tqdm
def load_EOD_data(data_path, market_name, tickers, steps=1):
    """Load per-ticker end-of-day CSV tables into dense float32 arrays.

    For each ticker the file ``<data_path>/<market_name>_<ticker>_1.csv`` is
    parsed as a header-less, comma-separated float32 table. The first column
    is dropped from the returned features and the last column is used as the
    price. Cells equal to -1234 are treated as missing.

    Args:
        data_path: Directory that contains the CSV files.
        market_name: File-name prefix; when it equals ``'NASDAQ'`` the last
            row of every table is discarded.
        tickers: Sequence of ticker symbols, one CSV file per ticker. Every
            file is expected to have the same shape as the first one.
        steps: Number of rows between the two prices used for the return in
            ``ground_truth``. Must be non-negative.

    Returns:
        A tuple ``(eod_data, masks, ground_truth, base_price)``:

        * ``eod_data`` -- ``(n_tickers, n_rows, n_cols - 1)``; every column
          but the first, with missing cells replaced by 1.1.
        * ``masks`` -- ``(n_tickers, n_rows)``; 0.0 where the last column is
          missing, 1.0 elsewhere.
        * ``ground_truth`` -- ``(n_tickers, n_rows)``; relative change of the
          last column over ``steps`` rows. It stays 0.0 for the first
          ``steps`` rows and wherever the current or the earlier price is
          missing.
        * ``base_price`` -- ``(n_tickers, n_rows)``; the last column with
          missing cells replaced by 1.1.

        When ``tickers`` is empty, four empty lists are returned.

    Raises:
        ValueError: If ``steps`` is negative.
    """
    if steps < 0:
        raise ValueError('steps must be non-negative, got {!r}'.format(steps))

    eod_data = []
    masks = []
    ground_truth = []
    base_price = []
    for index, ticker in enumerate(tqdm(tickers)):
        single_EOD = np.genfromtxt(
            os.path.join(data_path, market_name + '_' + ticker + '_1.csv'),
            dtype=np.float32, delimiter=',', skip_header=False
        )
        if market_name == 'NASDAQ':
            # remove the last day since lots of missing data
            single_EOD = single_EOD[:-1, :]
        n_rows = single_EOD.shape[0]
        if index == 0:
            # The first file fixes the shape of all output arrays.
            print('single EOD data shape:', single_EOD.shape)
            eod_data = np.zeros([len(tickers), n_rows,
                                 single_EOD.shape[1] - 1], dtype=np.float32)
            masks = np.ones([len(tickers), n_rows], dtype=np.float32)
            ground_truth = np.zeros([len(tickers), n_rows],
                                    dtype=np.float32)
            base_price = np.zeros([len(tickers), n_rows], dtype=np.float32)

        # Evaluate every sentinel test BEFORE the sentinels are overwritten
        # below. Overwriting first would hide a missing earlier price and
        # yield a bogus return computed against the 1.1 placeholder.
        missing = np.abs(single_EOD + 1234) < 1e-8
        close = single_EOD[:, -1]
        close_gap = np.abs(close + 1234)
        close_missing = close_gap < 1e-8
        close_present = close_gap > 1e-8

        masks[index, close_missing] = 0.0

        if steps < n_rows:
            current = close[steps:]
            earlier = close[:n_rows - steps]
            usable = ~close_missing[steps:] & close_present[:n_rows - steps]
            returns = np.zeros(current.shape, dtype=np.float32)
            # `where` leaves the zero default wherever an endpoint is missing.
            np.divide(current - earlier, earlier, out=returns, where=usable)
            ground_truth[index, steps:] = returns

        # Replace the missing-value sentinel with the 1.1 placeholder.
        single_EOD[missing] = 1.1
        eod_data[index, :, :] = single_EOD[:, 1:]
        base_price[index, :] = single_EOD[:, -1]
    return eod_data, masks, ground_truth, base_price
def load_graph_relation_data(relation_file, lap=False):
    """Build a symmetrically normalised adjacency (or Laplacian) matrix.

    The relation tensor is loaded with ``np.load`` and is expected to have
    shape ``(N, N, K)``. Entry ``(i, j)`` of the adjacency matrix is 1 when
    the relation vector at ``[i, j, :]`` does not sum to zero, and 0
    otherwise. With ``D`` the diagonal matrix of column sums of the
    adjacency ``A``, the function computes ``D^-1/2 A D^-1/2``.

    Nodes with zero degree get a scaling factor of 0 instead of ``1 / 0``,
    so their rows and columns are zero rather than NaN.

    Args:
        relation_file: Path or file object accepted by ``np.load``.
        lap: When true, return ``I - D^-1/2 A D^-1/2`` instead.

    Returns:
        A float ``(N, N)`` array.
    """
    relation_encoding = np.load(relation_file)
    print('relation encoding shape:', relation_encoding.shape)

    # Two nodes are adjacent when their relation vector has a non-zero sum.
    ajacent = (np.sum(relation_encoding, axis=2) != 0).astype(float)
    degree = np.sum(ajacent, axis=0)

    # 1 / degree, leaving 0 for isolated nodes to avoid inf (and later NaN).
    inv_degree = np.zeros_like(degree)
    np.divide(1.0, degree, out=inv_degree, where=degree > 0)
    deg_neg_half_power = np.sqrt(inv_degree)

    # Scaling rows and columns by the vector is equivalent to multiplying by
    # diag(deg_neg_half_power) on both sides, without the dense matmuls.
    normalized = (deg_neg_half_power[:, np.newaxis] * ajacent
                  * deg_neg_half_power[np.newaxis, :])
    if lap:
        return np.identity(ajacent.shape[0], dtype=float) - normalized
    return normalized
def load_relation_data(relation_file):
    """Load a relation tensor together with an additive "no relation" mask.

    Args:
        relation_file: Path or file object accepted by ``np.load``; the
            loaded array is summed over axis 2, so it needs at least three
            dimensions.

    Returns:
        A tuple ``(relation_encoding, mask)``. ``relation_encoding`` is the
        array exactly as loaded. ``mask`` is a float array over the first two
        axes holding -1e9 where the relation vector sums to zero and 0.0
        everywhere else.
    """
    relation_encoding = np.load(relation_file)
    print('relation encoding shape:', relation_encoding.shape)

    # Pairs whose relation vector sums to zero are treated as unrelated.
    unrelated = np.sum(relation_encoding, axis=2) == 0
    mask = np.where(unrelated, -1e9, 0.0)
    return relation_encoding, mask
def build_SFM_data(data_path, market_name, tickers):
    """Build a gap-filled matrix of last-column values and save it to disk.

    For each ticker the file ``<data_path>/<market_name>_<ticker>_1.csv`` is
    read as a header-less, comma-separated float32 table and its last column
    is copied into one row of the output. Cells equal to -1234 are treated
    as missing and filled as follows:

    * in the first three rows, with the next later value that is not
      missing (the cell stays 0.0 if there is none);
    * from the fourth row on, with the mean of the three preceding output
      values, which may themselves be filled values.

    The ``(n_tickers, n_rows)`` float32 result is written with ``np.save`` to
    ``market_name + '_sfm_data'``, relative to the current working
    directory. Nothing is returned.

    Args:
        data_path: Directory that contains the CSV files.
        market_name: Prefix of both the input file names and the output file.
        tickers: Sequence of ticker symbols, one CSV file per ticker.
    """
    eod_data = []
    ticker_count = len(tickers)
    for index, ticker in enumerate(tickers):
        csv_path = os.path.join(
            data_path, market_name + '_' + ticker + '_1.csv')
        single_EOD = np.genfromtxt(
            csv_path, dtype=np.float32, delimiter=',', skip_header=False
        )
        row_count = single_EOD.shape[0]
        if index == 0:
            # The first file fixes the number of rows of the output matrix.
            print('single EOD data shape:', single_EOD.shape)
            eod_data = np.zeros([ticker_count, row_count], dtype=np.float32)

        last_col = single_EOD[:, -1]
        # Distance of every value from the -1234 missing-value sentinel.
        sentinel_gap = np.abs(last_col + 1234)

        for row in range(row_count):
            if not sentinel_gap[row] < 1e-8:
                # Value is present: copy it unchanged.
                eod_data[index][row] = last_col[row]
                continue
            if row >= 3:
                # Average the three previously written output values.
                eod_data[index][row] = eod_data[index, row - 3:row].sum() / 3
                continue
            # Too early for a trailing average: borrow the next present value.
            later_present = np.flatnonzero(sentinel_gap[row + 1:] > 1e-8)
            if later_present.size:
                eod_data[index][row] = last_col[row + 1 + later_present[0]]
    np.save(market_name + '_sfm_data', eod_data)