-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathaction_space.py
106 lines (79 loc) · 3 KB
/
action_space.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import pyflann
from gym.spaces import Box
import numpy as np
import itertools
class Space:
    """Uniform grid of points with FLANN nearest-neighbour lookup.

    The grid itself is generated inside the unit hypercube (every axis
    runs from 0 to 1).  Callers work in the original ``[low, high]``
    coordinates: :meth:`import_point` maps a point into the unit cube
    before a query and :meth:`export_point` maps grid points back.
    """

    def __init__(self, low, high, points):
        """Create the grid and build its FLANN index.

        Args:
            low: Per-dimension lower bounds (a sequence; its length fixes
                the number of dimensions).
            high: Per-dimension upper bounds.
            points: Requested number of grid points, forwarded to
                ``init_uniform_space``.
        """
        self._low = np.array(low)
        self._high = np.array(high)
        self._range = self._high - self._low
        self._dimensions = len(low)
        # The grid is always built in normalised [0, 1] coordinates.
        self.__space = init_uniform_space([0] * self._dimensions,
                                          [1] * self._dimensions,
                                          points)
        self._flann = pyflann.FLANN()
        self.rebuild_flann()

    def rebuild_flann(self):
        """(Re)build the kd-tree index over the current grid."""
        self._index = self._flann.build_index(self.__space, algorithm='kdtree')

    def search_point(self, point, k):
        """Return the ``k`` grid points nearest to ``point``.

        Args:
            point: Query point in original ``[low, high]`` coordinates
                (must support NumPy arithmetic and ``reshape``).
            k: Number of neighbours to request from FLANN.

        Returns:
            ``np.ndarray`` of the neighbours, mapped back through
            :meth:`export_point`.
        """
        query = self.import_point(point).reshape(1, -1).astype('float64')
        neighbour_ids, _ = self._flann.nn_index(query, k)
        neighbours = self.__space[neighbour_ids]
        exported = [self.export_point(p) for p in neighbours]
        if k == 1:
            # Extra nesting for the single-neighbour case; presumably FLANN
            # drops the neighbour axis when k == 1 and this restores the
            # same rank as for k > 1 -- TODO confirm against pyflann.
            exported = [exported]
        return np.array(exported)

    def import_point(self, point):
        """Map ``point`` from ``[low, high]`` coordinates to the unit cube.

        NOTE(review): a dimension with ``low == high`` has a zero range and
        makes this a division by zero.
        """
        return (point - self._low) / self._range

    def export_point(self, point):
        """Map ``point`` from the unit cube back to ``[low, high]``."""
        return self._low + point * self._range

    def get_space(self):
        """Return the grid array (normalised coordinates)."""
        return self.__space

    def shape(self):
        """Return the shape of the grid array."""
        return self.__space.shape

    def get_number_of_actions(self):
        """Return the number of grid points (first axis of the grid)."""
        return self.shape()[0]

    def _points_to_plot(self, additional_points=None):
        """Return the grid with ``additional_points`` appended exactly once.

        ``None`` or an empty collection leaves the grid unchanged.  The
        stored grid is never modified; ``np.append`` returns a new array.
        """
        space = self.get_space()
        if additional_points is not None and len(additional_points) > 0:
            space = np.append(space, additional_points, axis=0)
        return space

    def plot_space(self, additional_points=None):
        """Plot the grid (1, 2 or 3 dimensions only).

        Args:
            additional_points: Optional extra points to draw together with
                the grid; each must have one coordinate per dimension.

        NOTE(review): ``plt`` and ``plot_3d_points`` are not provided by
        the imports visible at the top of this file -- confirm where they
        are expected to come from.
        """
        dims = self._dimensions
        if dims > 3:
            print(
                'Cannot plot a {}-dimensional space. Max 3 dimensions'.format(dims))
            return

        # Previously the extra points were appended once per element of
        # ``additional_points``, duplicating them in the plot.
        space = self._points_to_plot(additional_points)

        if dims == 1:
            for x in space:
                plt.plot([x], [0], 'o')
            plt.show()
        elif dims == 2:
            for x, y in space:
                plt.plot([x], [y], 'o')
            plt.show()
        else:
            plot_3d_points(space)
class Discrete_space(Space):
    """
    Discrete action space with n actions (the integers in the range [0, n))
    0, 1, 2, ..., n-2, n-1
    """

    def __init__(self, n):  # n: the number of the discrete actions
        """Create a one-dimensional space over ``[0, n - 1]`` with ``n`` points.

        NOTE(review): ``n == 1`` gives a zero-width range, which the parent
        ``import_point`` divides by.
        """
        super().__init__([0], [n - 1], n)

    def export_point(self, point):
        """Map a normalised grid point back to its integer action.

        The parent mapping is computed in floating point, so the result can
        land just below the intended integer (for example
        ``(1 / 49) * 49 == 0.9999999999999999``).  Rounding to the nearest
        integer before the cast avoids truncating such values to the
        previous action.
        """
        return np.rint(super().export_point(point)).astype(int)
def init_uniform_space(low, high, points):
    """Build a regular grid covering the box ``[low, high]``.

    Every axis receives the same number of evenly spaced samples,
    ``round(points ** (1 / dims))``, end points included, so the total
    number of grid points is that count raised to ``dims`` and may differ
    from ``points``.

    Args:
        low: Per-dimension lower bounds (sequence; its length is ``dims``).
        high: Per-dimension upper bounds.
        points: Requested total number of grid points.

    Returns:
        ``np.ndarray`` with one row per grid point, rows ordered as the
        Cartesian product of the axes (last axis varies fastest).
    """
    dims = len(low)
    samples_per_axis = round(points**(1 / dims))

    # One evenly spaced 1-D sample set per dimension.
    axes = [np.linspace(low[d], high[d], samples_per_axis)
            for d in range(dims)]

    # The Cartesian product of the axes enumerates every grid point.
    grid = [list(coords) for coords in itertools.product(*axes)]
    return np.array(grid)