/**
 * @file async_learning.hpp
 * @author Shangtong Zhang
 *
 * This file is the definition of the AsyncLearning class, which is a wrapper
 * for various asynchronous learning algorithms.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_HPP
#define MLPACK_METHODS_RL_ASYNC_LEARNING_HPP

#include <mlpack/prereqs.hpp>

#include "worker/one_step_q_learning_worker.hpp"
#include "worker/one_step_continuous_q_learning_worker.hpp"
#include "worker/one_step_sarsa_worker.hpp"
#include "worker/n_step_q_learning_worker.hpp"

#include "training_config.hpp"

namespace mlpack {
namespace rl {

/**
 * Wrapper of various asynchronous learning algorithms,
 * e.g. async one-step Q-learning, async one-step Sarsa,
 * async n-step Q-learning and async advantage actor-critic.
 *
 * For more details, see the following:
 * @code
 * @inproceedings{mnih2016asynchronous,
 *   title     = {Asynchronous methods for deep reinforcement learning},
 *   author    = {Mnih, Volodymyr and Badia, Adria Puigdomenech and Mirza,
 *                Mehdi and Graves, Alex and Lillicrap, Timothy and Harley,
 *                Tim and Silver, David and Kavukcuoglu, Koray},
 *   booktitle = {International Conference on Machine Learning},
 *   pages     = {1928--1937},
 *   year      = {2016}
 * }
 * @endcode
 *
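 * A minimal usage sketch follows. It is illustrative only: CartPole,
 * GreedyPolicy, FFN, and the TrainingConfig fields below follow mlpack's RL
 * tests, but exact type names and signatures may differ between versions,
 * and the layer sizes and reward threshold are arbitrary.
 * @code
 * TrainingConfig config;
 * config.NumWorkers() = 4;
 *
 * // A small Q-network for CartPole (4 state dimensions, 2 actions).
 * FFN<MeanSquaredError<>, GaussianInitialization> model;
 * model.Add<Linear<>>(4, 64);
 * model.Add<ReLULayer<>>();
 * model.Add<Linear<>>(64, 2);
 *
 * // Epsilon-greedy behavior policy, annealed from 1.0 towards 0.1.
 * GreedyPolicy<CartPole> policy(1.0, 1000, 0.1);
 *
 * OneStepQLearning<CartPole, decltype(model), ens::VanillaUpdate,
 *     GreedyPolicy<CartPole>> agent(config, model, policy);
 *
 * // Train until a deterministic test episode reaches the target reward.
 * auto measure = [](double reward) { return reward >= 195.0; };
 * agent.Train(measure);
 * @endcode
 *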
 * @tparam WorkerType The type of the worker.
 * @tparam EnvironmentType The type of reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class AsyncLearning
{
 public:
  /**
   * Construct an instance of the given async learning algorithm.
   *
   * @param config Hyper-parameters for training.
   * @param network The network model.
   * @param policy The behavior policy.
   * @param updater The optimizer.
   * @param environment The reinforcement learning task.
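   *
   * A sketch of filling the config argument (the field accessors follow the
   * TrainingConfig declared in training_config.hpp, included above; the
   * values are arbitrary):
   * @code
   * TrainingConfig config;
   * config.NumWorkers() = 4;   // Number of asynchronous worker threads.
   * config.Discount() = 0.99;  // Reward discount factor.
   * config.StepLimit() = 200;  // Maximum steps per episode (0 = no limit).
   * @endcode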
   */
  AsyncLearning(TrainingConfig config,
                NetworkType network,
                PolicyType policy,
                UpdaterType updater = UpdaterType(),
                EnvironmentType environment = EnvironmentType());

  /**
   * Start async training.
   *
   * @tparam Measure The type of the measurement. It should be a
   *     callable object like
   *     @code
   *     bool foo(double reward);
   *     @endcode
   *     where reward is the total reward of a deterministic test episode,
   *     and the return value indicates whether the training process is
   *     complete.
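   *     For instance, a measure that logs each test reward and stops once a
   *     hypothetical target of 195 is reached (assuming <iostream> is
   *     available) could be:
   *     @code
   *     bool measure(double reward)
   *     {
   *       std::cout << "test episode reward: " << reward << std::endl;
   *       return reward >= 195.0;
   *     }
   *     @endcode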
   * @param measure The measurement instance.
   */
  template <typename Measure>
  void Train(Measure& measure);

  //! Modify training config.
  TrainingConfig& Config() { return config; }
  //! Get training config.
  const TrainingConfig& Config() const { return config; }

  //! Modify learning network.
  NetworkType& Network() { return learningNetwork; }
  //! Get learning network.
  const NetworkType& Network() const { return learningNetwork; }

  //! Modify behavior policy.
  PolicyType& Policy() { return policy; }
  //! Get behavior policy.
  const PolicyType& Policy() const { return policy; }

  //! Modify optimizer.
  UpdaterType& Updater() { return updater; }
  //! Get optimizer.
  const UpdaterType& Updater() const { return updater; }

  //! Modify the environment.
  EnvironmentType& Environment() { return environment; }
  //! Get the environment.
  const EnvironmentType& Environment() const { return environment; }

 private:
  //! Locally-stored hyper-parameters.
  TrainingConfig config;

  //! Locally-stored global learning network.
  NetworkType learningNetwork;

  //! Locally-stored policy.
  PolicyType policy;

  //! Locally-stored optimizer.
  UpdaterType updater;

  //! Locally-stored task.
  EnvironmentType environment;
};

/**
 * Forward declaration of OneStepQLearningWorker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepQLearningWorker;

/**
 * Forward declaration of OneStepSarsaWorker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker;

/**
 * Forward declaration of NStepQLearningWorker.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class NStepQLearningWorker;

/**
 * Convenient typedef for async one-step Q-learning.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
using OneStepQLearning = AsyncLearning<OneStepQLearningWorker<EnvironmentType,
    NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
    UpdaterType, PolicyType>;

/**
 * Convenient typedef for async one-step Q-learning with continuous actions.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
using OneStepContinuousQLearning = AsyncLearning<
    OneStepContinuousQLearningWorker<EnvironmentType, NetworkType,
    UpdaterType, PolicyType>, EnvironmentType, NetworkType, UpdaterType,
    PolicyType>;

/**
 * Convenient typedef for async one-step Sarsa.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
using OneStepSarsa = AsyncLearning<OneStepSarsaWorker<EnvironmentType,
    NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
    UpdaterType, PolicyType>;

/**
 * Convenient typedef for async n-step Q-learning.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
using NStepQLearning = AsyncLearning<NStepQLearningWorker<EnvironmentType,
    NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
    UpdaterType, PolicyType>;

} // namespace rl
} // namespace mlpack

// Include implementation.
#include "async_learning_impl.hpp"

#endif