-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathresults_plotter.py
146 lines (119 loc) · 5.15 KB
/
results_plotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from stable_baselines.bench.monitor import load_results
# matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
# Module-level side effect: importing this file changes matplotlib's global SVG font setting.
# NOTE(review): per the matplotlib rcParams docs, 'none' should keep text as real text in
# SVG output instead of converting glyphs to paths — confirm this is the intent.
plt.rcParams['svg.fonttype'] = 'none'
# Accepted values for the ``xaxis`` argument of ts2xy / plot_curves / plot_results.
X_TIMESTEPS = 'timesteps'
X_EPISODES = 'episodes'
X_WALLTIME = 'walltime_hrs'
POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]

# Number of episodes in the rolling-mean window used to smooth the plotted curves.
EPISODES_WINDOW = 100

# Palette for the smoothed curves: the i-th curve uses the i-th entry.
# Every entry must be a colour name matplotlib can resolve; the former 'lightpurple' entry
# is not a matplotlib named colour, so the xkcd equivalent is used instead.
COLORS = [
    'blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
    'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
    'darkgreen', 'tan', 'salmon', 'gold', 'xkcd:light purple', 'darkred', 'darkblue',
]
def rolling_window(array, window):
    """
    Expose every length-``window`` slice along the last axis of an array.

    The result is built with ``np.lib.stride_tricks.as_strided``: the last axis of
    length ``n`` is replaced by two axes of lengths ``n - window + 1`` and ``window``.

    :param array: (np.ndarray) the input Array
    :param window: (int) length of the rolling window
    :return: (np.ndarray) rolling window on the input array
    """
    leading_dims = array.shape[:-1]
    n_windows = array.shape[-1] - window + 1
    view_shape = leading_dims + (n_windows, window)
    # Stepping inside a window reuses the stride of the last axis.
    last_stride = array.strides[-1]
    view_strides = array.strides + (last_stride,)
    return np.lib.stride_tricks.as_strided(array, shape=view_shape, strides=view_strides)
def window_func(var_1, var_2, window, func):
    """
    Reduce each rolling window of ``var_2`` with ``func`` and trim ``var_1`` to match.

    The first ``window - 1`` entries of ``var_1`` are dropped so that both outputs
    have the same length along the last axis.

    :param var_1: (np.ndarray) variable 1
    :param var_2: (np.ndarray) variable 2
    :param window: (int) length of the rolling window
    :param func: (numpy function) function to apply on the rolling window on variable 2 (such as np.mean)
    :return: (np.ndarray, np.ndarray) the rolling output with applied function
    """
    reduced = func(rolling_window(var_2, window), axis=-1)
    trimmed = var_1[window - 1:]
    return trimmed, reduced
def ts2xy(timesteps, xaxis):
    """
    Split monitor results into x and y arrays for plotting.

    The y values are always the ``r`` column. The x values depend on ``xaxis``:
    the cumulative sum of the ``l`` column, the row index, or the ``t`` column
    divided by 3600.
    NOTE(review): presumably ``l`` is episode length, ``r`` episode reward and ``t``
    elapsed seconds — confirm against the monitor file format.

    :param timesteps: (Pandas DataFrame) the input data
    :param xaxis: (str) the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :return: (np.ndarray, np.ndarray) the x and y output
    :raises NotImplementedError: if ``xaxis`` is not one of the supported values
    """
    if xaxis == X_TIMESTEPS:
        horizontal = np.cumsum(timesteps.l.values)
    elif xaxis == X_EPISODES:
        horizontal = np.arange(len(timesteps))
    elif xaxis == X_WALLTIME:
        horizontal = timesteps.t.values / 3600.
    else:
        raise NotImplementedError
    # Every supported axis plots the same quantity on y.
    return horizontal, timesteps.r.values
def plot_curves(xy_list, xaxis, title):
    """
    Plot one curve per (x, y) pair on a new figure.

    Every series is drawn as a scatter of its raw points. Series with at least
    EPISODES_WINDOW points also get a rolling-mean line, coloured from COLORS
    (the palette wraps around when there are more curves than colours).

    :param xy_list: ([(np.ndarray, np.ndarray)]) the x and y coordinates to plot
    :param xaxis: (str) the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param title: (str) the title of the plot
    """
    plt.figure(figsize=(8, 2))
    # An empty series has no last x value; leave it out of the axis-limit computation
    # instead of failing with an IndexError.
    last_x_values = [x[-1] for x, _ in xy_list if len(x) > 0]
    for i, (x, y) in enumerate(xy_list):
        # Wrap around the palette so more than len(COLORS) curves can be plotted.
        color = COLORS[i % len(COLORS)]
        plt.scatter(x, y, s=2)
        # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
        if x.shape[0] >= EPISODES_WINDOW:
            # Compute and plot rolling mean with window of size EPISODES_WINDOW
            x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
            plt.plot(x, y_mean, color=color)
    # Only pin the x range when there is data; otherwise keep matplotlib's default limits.
    if last_x_values:
        plt.xlim(0, max(last_x_values))
    plt.title(title)
    plt.xlabel(xaxis)
    plt.ylabel("Episode Rewards")
    plt.tight_layout()
def plot_results(dirs, num_timesteps, xaxis, task_name):
    """
    Load the results stored in each folder and plot them together.

    :param dirs: ([str]) the save location of the results to plot
    :param num_timesteps: (int or None) only plot the points below this value
    :param xaxis: (str) the axis for the x and y output
        (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
    :param task_name: (str) the title of the task to plot
    """
    def _load_truncated(folder):
        # Keep only the rows whose cumulative ``l`` stays within the requested budget.
        frame = load_results(folder)
        if num_timesteps is None:
            return frame
        return frame[frame.l.cumsum() <= num_timesteps]

    # Load every folder first, then convert, matching the two-phase order of processing.
    frames = [_load_truncated(folder) for folder in dirs]
    curves = [ts2xy(frame, xaxis) for frame in frames]
    plot_curves(curves, xaxis, task_name)
def main():
    """
    Command-line entry point: parse the options, plot the results and show the figure.

    Example usage in jupyter-notebook

    .. code-block:: python

        from stable_baselines import results_plotter
        %matplotlib inline
        results_plotter.plot_results(["./log"], 10e6, results_plotter.X_TIMESTEPS, "Breakout")

    Here ./log is a directory containing the monitor.csv files
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dirs', help='List of log directories', nargs='*', default=['./log'])
    # A help string is required for ArgumentDefaultsHelpFormatter to display the default.
    parser.add_argument('--num_timesteps', help='Only plot the points below this number of timesteps',
                        type=int, default=int(10e6))
    # Reject unsupported axes at parse time instead of failing later with NotImplementedError.
    parser.add_argument('--xaxis', help='Variable on X-axis', choices=POSSIBLE_X_AXES, default=X_TIMESTEPS)
    parser.add_argument('--task_name', help='Title of plot', default='Breakout')
    args = parser.parse_args()
    args.dirs = [os.path.abspath(folder) for folder in args.dirs]
    plot_results(args.dirs, args.num_timesteps, args.xaxis, args.task_name)
    plt.show()


if __name__ == '__main__':
    main()