rl_trainer_pytorch.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import logging
from typing import List, Optional

import torch
import torch.nn.functional as F
from reagent.core.parameters import EvaluationParameters, RLParameters
from reagent.core.torch_utils import masked_softmax
from reagent.optimizer.union import Optimizer__Union
from reagent.training.loss_reporter import LossReporter
from reagent.training.trainer import Trainer


logger = logging.getLogger(__name__)

# pyre-fixme[13]: Attribute `rl_parameters` is never initialized.
class RLTrainerMixin:
    # TODO: potential inconsistencies
    _use_seq_num_diff_as_time_diff = None
    _maxq_learning = None
    _multi_steps = None
    rl_parameters: RLParameters

    @property
    def gamma(self) -> float:
        return self.rl_parameters.gamma

    @property
    def tau(self) -> float:
        return self.rl_parameters.target_update_rate

    @property
    def multi_steps(self) -> Optional[int]:
        return (
            self.rl_parameters.multi_steps
            if self._multi_steps is None
            else self._multi_steps
        )

    @multi_steps.setter
    def multi_steps(self, multi_steps):
        self._multi_steps = multi_steps

    @property
    def maxq_learning(self) -> bool:
        return (
            self.rl_parameters.maxq_learning
            if self._maxq_learning is None
            else self._maxq_learning
        )

    @maxq_learning.setter
    def maxq_learning(self, maxq_learning):
        self._maxq_learning = maxq_learning

    @property
    def use_seq_num_diff_as_time_diff(self) -> bool:
        return (
            self.rl_parameters.use_seq_num_diff_as_time_diff
            if self._use_seq_num_diff_as_time_diff is None
            else self._use_seq_num_diff_as_time_diff
        )

    @use_seq_num_diff_as_time_diff.setter
    def use_seq_num_diff_as_time_diff(self, use_seq_num_diff_as_time_diff):
        self._use_seq_num_diff_as_time_diff = use_seq_num_diff_as_time_diff

    @property
    def rl_temperature(self) -> float:
        return self.rl_parameters.temperature
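
# The properties on RLTrainerMixin fall back to the values in `rl_parameters`
# unless a per-instance override has been set through the matching setter.
# An illustrative use, assuming a concrete trainer instance `trainer`
# (hypothetical, not defined in this module):
#
#     trainer.maxq_learning          # -> trainer.rl_parameters.maxq_learning
#     trainer.maxq_learning = False  # per-instance override
#     trainer.maxq_learning          # -> False, regardless of rl_parameters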


class RLTrainer(RLTrainerMixin, Trainer):
    # Q-value for an action that is not possible. Guaranteed to be worse than any
    # legitimate action.
    ACTION_NOT_POSSIBLE_VAL = -1e9

    # Hack to mark legitimate 0-valued q-values before pytorch sparse -> dense.
    FINGERPRINT = 12345

    def __init__(
        self,
        rl_parameters: RLParameters,
        use_gpu: bool,
        metrics_to_score=None,
        actions: Optional[List[str]] = None,
        evaluation_parameters: Optional[EvaluationParameters] = None,
        loss_reporter=None,
    ) -> None:
        super().__init__()
        self.minibatch = 0
        self.minibatch_size: Optional[int] = None
        self.minibatches_per_step: Optional[int] = None
        self.rl_parameters = rl_parameters
        self.time_diff_unit_length = rl_parameters.time_diff_unit_length
        self.tensorboard_logging_freq = rl_parameters.tensorboard_logging_freq
        self.calc_cpe_in_training = (
            evaluation_parameters and evaluation_parameters.calc_cpe_in_training
        )

        if rl_parameters.q_network_loss == "mse":
            self.q_network_loss = F.mse_loss
        elif rl_parameters.q_network_loss == "huber":
            self.q_network_loss = F.smooth_l1_loss
        else:
            raise Exception(
                "Q-Network loss type {} is not a valid loss.".format(
                    rl_parameters.q_network_loss
                )
            )

        if metrics_to_score:
            self.metrics_to_score = metrics_to_score + ["reward"]
        else:
            self.metrics_to_score = ["reward"]

        cuda_available = torch.cuda.is_available()
        logger.info("CUDA availability: {}".format(cuda_available))
        if use_gpu and cuda_available:
            logger.info("Using GPU: GPU requested and available.")
            self.use_gpu = True
            self.device = torch.device("cuda")
        else:
            logger.info("NOT Using GPU: GPU not requested or not available.")
            self.use_gpu = False
            self.device = torch.device("cpu")

        self.loss_reporter = loss_reporter or LossReporter(actions)
        self._actions = actions

    @property
    def num_actions(self) -> int:
        assert self._actions is not None, "Not a discrete action DQN"
        # pyre-fixme[6]: Expected `Sized` for 1st param but got `Optional[List[str]]`.
        return len(self._actions)
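
    # CPE (counterfactual policy evaluation) heads: the reward network and the CPE
    # Q-network each output len(self.metrics_to_score) * self.num_actions values,
    # laid out as one block of num_actions columns per metric. `reward_idx_offsets`
    # holds the first column of each metric block, so
    # `reward_idx_offsets + logged_action_idxs` selects the logged action's column
    # in every block. A quick illustrative check, assuming 3 metrics and 4 actions
    # (hypothetical values):
    #
    #     >>> torch.arange(0, 3 * 4, 4)
    #     tensor([0, 4, 8])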

    def _initialize_cpe(
        self,
        reward_network,
        q_network_cpe,
        q_network_cpe_target,
        optimizer: Optimizer__Union,
    ) -> None:
        if self.calc_cpe_in_training:
            assert reward_network is not None, "reward_network is required for CPE"
            # pyre-fixme[16]: `RLTrainer` has no attribute `reward_network`.
            self.reward_network = reward_network
            # pyre-fixme[16]: `RLTrainer` has no attribute `reward_network_optimizer`.
            self.reward_network_optimizer = optimizer.make_optimizer(
                self.reward_network.parameters()
            )
            assert (
                q_network_cpe is not None and q_network_cpe_target is not None
            ), "q_network_cpe and q_network_cpe_target are required for CPE"
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network_cpe`.
            self.q_network_cpe = q_network_cpe
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network_cpe_target`.
            self.q_network_cpe_target = q_network_cpe_target
            # pyre-fixme[16]: `RLTrainer` has no attribute `q_network_cpe_optimizer`.
            self.q_network_cpe_optimizer = optimizer.make_optimizer(
                self.q_network_cpe.parameters()
            )
            num_output_nodes = len(self.metrics_to_score) * self.num_actions
            # pyre-fixme[16]: `RLTrainer` has no attribute `reward_idx_offsets`.
            self.reward_idx_offsets = torch.arange(
                0,
                num_output_nodes,
                self.num_actions,
                device=self.device,
                dtype=torch.long,
            )
        else:
            self.reward_network = None
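
    # With a small tau (e.g. 0.005, illustrative only), each soft update moves the
    # target parameters a fraction tau of the way toward the online network, so the
    # targets used for bootstrapping change slowly between updates.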

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    #  its type `no_grad` is not callable.
    @torch.no_grad()
    def _soft_update(self, network, target_network, tau) -> None:
        """Target network update logic as defined in the DDPG paper:

            updated_params = tau * network_params + (1 - tau) * target_network_params

        :param network: network with parameters to include in the soft update
        :param target_network: target network whose parameters are soft-updated
        :param tau: hyperparameter controlling how quickly the target tracks the network
        """
        for t_param, param in zip(target_network.parameters(), network.parameters()):
            if t_param is param:
                # Skip soft-updating when the target network shares the parameter with
                # the network being trained.
                continue
            new_param = tau * param.data + (1.0 - tau) * t_param.data
            t_param.data.copy_(new_param)

    # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because
    #  its type `no_grad` is not callable.
    @torch.no_grad()
    def _maybe_soft_update(
        self, network, target_network, tau, minibatches_per_step
    ) -> None:
        if self.minibatch % minibatches_per_step != 0:
            return
        self._soft_update(network, target_network, tau)
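
    # Gradient accumulation: callers run loss.backward() on every minibatch, but the
    # optimizer only steps once every `minibatches_per_step` minibatches, after the
    # accumulated gradients are divided by minibatches_per_step (i.e. averaged).
    # Rough sketch of the intended call pattern (hypothetical caller code; `loss`
    # and `optimizer` are not defined in this module):
    #
    #     loss.backward()
    #     self._maybe_run_optimizer(optimizer, self.minibatches_per_step)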

    def _maybe_run_optimizer(self, optimizer, minibatches_per_step) -> None:
        if self.minibatch % minibatches_per_step != 0:
            return
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad /= minibatches_per_step
        optimizer.step()
        optimizer.zero_grad()
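
    # _calculate_cpes trains the auxiliary CPE models: the whole method runs under
    # torch.no_grad(), and gradients are re-enabled only inside the
    # torch.enable_grad() block that updates the reward network and the CPE
    # Q-network. The returned (reward_loss, model_rewards, model_propensities) are
    # consumed by downstream evaluation.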

    @torch.no_grad()
    def _calculate_cpes(
        self,
        training_batch,
        states,
        next_states,
        all_action_scores,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    ):
        if not self.calc_cpe_in_training:
            return None, None, None

        if training_batch.extras.metrics is None:
            metrics_reward_concat_real_vals = training_batch.reward
        else:
            metrics_reward_concat_real_vals = torch.cat(
                (training_batch.reward, training_batch.extras.metrics), dim=1
            )

        model_propensities_next_states = masked_softmax(
            all_next_action_scores,
            training_batch.possible_next_actions_mask
            if self.maxq_learning
            else training_batch.next_action,
            self.rl_temperature,
        )

        with torch.enable_grad():
            ######### Train separate reward network for CPE evaluation #############
            reward_estimates = self.reward_network(states)
            reward_estimates_for_logged_actions = reward_estimates.gather(
                1, self.reward_idx_offsets + logged_action_idxs
            )
            reward_loss = F.mse_loss(
                reward_estimates_for_logged_actions, metrics_reward_concat_real_vals
            )
            reward_loss.backward()
            self._maybe_run_optimizer(
                self.reward_network_optimizer, self.minibatches_per_step
            )

            ######### Train separate q-network for CPE evaluation #############
            metric_q_values = self.q_network_cpe(states).gather(
                1, self.reward_idx_offsets + logged_action_idxs
            )
            all_metrics_target_q_values = torch.chunk(
                self.q_network_cpe_target(next_states).detach(),
                len(self.metrics_to_score),
                dim=1,
            )
            target_metric_q_values = []
            for i, per_metric_target_q_values in enumerate(
                all_metrics_target_q_values
            ):
                per_metric_next_q_values = torch.sum(
                    per_metric_target_q_values * model_propensities_next_states,
                    1,
                    keepdim=True,
                )
                per_metric_next_q_values = per_metric_next_q_values * not_done_mask
                per_metric_target_q_values = metrics_reward_concat_real_vals[
                    :, i : i + 1
                ] + (discount_tensor * per_metric_next_q_values)
                target_metric_q_values.append(per_metric_target_q_values)

            target_metric_q_values = torch.cat(target_metric_q_values, dim=1)
            metric_q_value_loss = self.q_network_loss(
                metric_q_values, target_metric_q_values
            )
            metric_q_value_loss.backward()
            self._maybe_run_optimizer(
                self.q_network_cpe_optimizer, self.minibatches_per_step
            )

        # Use the soft update rule to update target network
        self._maybe_soft_update(
            self.q_network_cpe,
            self.q_network_cpe_target,
            self.tau,
            self.minibatches_per_step,
        )

        model_propensities = masked_softmax(
            all_action_scores,
            training_batch.possible_actions_mask
            if self.maxq_learning
            else training_batch.action,
            self.rl_temperature,
        )
        model_rewards = reward_estimates[
            :,
            torch.arange(
                self.reward_idx_offsets[0],
                self.reward_idx_offsets[0] + self.num_actions,
            ),
        ]
        return reward_loss, model_rewards, model_propensities
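

# Sketch of how a concrete value-based trainer typically combines these helpers in
# its train step (illustrative only; `q_values`, `target_q_values`, `q_network`,
# `q_network_target`, `q_network_optimizer`, and the batch tensors are assumptions,
# not part of this module):
#
#     td_loss = self.q_network_loss(q_values, target_q_values)
#     td_loss.backward()
#     self._maybe_run_optimizer(self.q_network_optimizer, self.minibatches_per_step)
#     self._maybe_soft_update(
#         self.q_network, self.q_network_target, self.tau, self.minibatches_per_step
#     )
#     reward_loss, model_rewards, model_propensities = self._calculate_cpes(
#         training_batch,
#         states,
#         next_states,
#         all_action_scores,
#         all_next_action_scores,
#         logged_action_idxs,
#         discount_tensor,
#         not_done_mask,
#     )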