update module, buffer and driver for off-policy algorithm #74

Merged
merged 1 commit on May 19, 2023
155 changes: 94 additions & 61 deletions openrl/algorithms/dqn.py
@@ -21,7 +21,7 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.nn.functional as F

from openrl.algorithms.base_algorithm import BaseAlgorithm
from openrl.modules.networks.utils.distributed_utils import reduce_tensor
@@ -37,102 +37,72 @@ def __init__(
agent_num: int = 1,
device: Union[str, torch.device] = "cpu",
) -> None:
self._use_share_model = cfg.use_share_model
self.use_joint_action_loss = cfg.use_joint_action_loss
super(DQNAlgorithm, self).__init__(cfg, init_module, agent_num, device)

self.gamma = cfg.gamma

def dqn_update(self, sample, turn_on=True):
for optimizer in self.algo_module.optimizers.values():
optimizer.zero_grad()

(
obs_batch,
_,
next_obs_batch,
_,
rnn_states_batch,
rnn_states_critic_batch,
actions_batch,
value_preds_batch,
return_batch,
rewards_batch,
masks_batch,
active_masks_batch,
old_action_log_probs_batch,
adv_targ,
available_actions_batch,
) = sample

value_preds_batch = check(value_preds_batch).to(**self.tpdv)
return_batch = check(return_batch).to(**self.tpdv)
rewards_batch = check(rewards_batch).to(**self.tpdv)
active_masks_batch = check(active_masks_batch).to(**self.tpdv)

if self.use_amp:
with torch.cuda.amp.autocast():
(
loss_list,
value_loss,
policy_loss,
dist_entropy,
ratio,
) = self.prepare_loss(
loss_list = self.prepare_loss(
obs_batch,
next_obs_batch,
rnn_states_batch,
actions_batch,
masks_batch,
available_actions_batch,
value_preds_batch,
return_batch,
rewards_batch,
active_masks_batch,
turn_on,
)
for loss in loss_list:
self.algo_module.scaler.scale(loss).backward()
else:
loss_list, value_loss, policy_loss, dist_entropy, ratio = self.prepare_loss(
loss_list = self.prepare_loss(
obs_batch,
next_obs_batch,
rnn_states_batch,
actions_batch,
masks_batch,
available_actions_batch,
value_preds_batch,
return_batch,
rewards_batch,
active_masks_batch,
turn_on,
)
for loss in loss_list:
loss.backward()

if "transformer" in self.algo_module.models:
if self._use_max_grad_norm:
grad_norm = nn.utils.clip_grad_norm_(
self.algo_module.models["transformer"].parameters(),
self.max_grad_norm,
)
else:
grad_norm = get_gard_norm(
self.algo_module.models["transformer"].parameters()
)
critic_grad_norm = grad_norm
actor_grad_norm = grad_norm

raise NotImplementedError
else:
if self._use_share_model:
actor_para = self.algo_module.models["model"].get_actor_para()
else:
actor_para = self.algo_module.models["policy"].parameters()

if self._use_max_grad_norm:
actor_grad_norm = nn.utils.clip_grad_norm_(
actor_para, self.max_grad_norm
)
else:
actor_grad_norm = get_gard_norm(actor_para)

if self._use_share_model:
critic_para = self.algo_module.models["model"].get_critic_para()
else:
critic_para = self.algo_module.models["critic"].parameters()

if self._use_max_grad_norm:
critic_grad_norm = nn.utils.clip_grad_norm_(
critic_para, self.max_grad_norm
)
else:
critic_grad_norm = get_gard_norm(critic_para)
actor_para = self.algo_module.models["q_net"].parameters()
actor_grad_norm = get_gard_norm(actor_para)

if self.use_amp:
for optimizer in self.algo_module.optimizers.values():
@@ -149,14 +119,7 @@ def dqn_update(self, sample, turn_on=True):
if self.world_size > 1:
torch.cuda.synchronize()

return (
value_loss,
critic_grad_norm,
policy_loss,
dist_entropy,
actor_grad_norm,
ratio,
)
return loss

def cal_value_loss(
self,
@@ -208,16 +171,86 @@ def to_single_np(self, input):
def prepare_loss(
self,
obs_batch,
next_obs_batch,
rnn_states_batch,
actions_batch,
masks_batch,
available_actions_batch,
value_preds_batch,
return_batch,
rewards_batch,
active_masks_batch,
turn_on,
):
raise NotImplementedError
loss_list = []
critic_masks_batch = masks_batch

(
q_values,
max_next_q_values
) = self.algo_module.evaluate_actions(
obs_batch,
next_obs_batch,
rnn_states_batch,
rewards_batch,
actions_batch,
masks_batch,
available_actions_batch,
active_masks_batch,
critic_masks_batch=critic_masks_batch,
)

q_targets = rewards_batch + self.gamma * max_next_q_values
q_loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean squared error loss

loss_list.append(q_loss)
return loss_list
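For orientation, here is a minimal standalone sketch of what the loss above computes, assuming evaluate_actions returns the Q-value of the taken action and the largest Q-value at the next observation (its implementation is not part of this diff); the stand-in network, tensor shapes, and data below are illustrative assumptions only.

import torch
import torch.nn as nn
import torch.nn.functional as F

q_net = nn.Linear(4, 2)                      # stand-in Q-network: 4-dim obs, 2 discrete actions
obs = torch.randn(32, 4)
next_obs = torch.randn(32, 4)
actions = torch.randint(0, 2, (32, 1))       # indices of the taken actions
rewards = torch.randn(32, 1)
gamma = 0.99

q_values = q_net(obs).gather(1, actions)                                  # Q(s, a)
max_next_q_values = q_net(next_obs).max(dim=1, keepdim=True)[0].detach()  # max_a' Q(s', a')
q_targets = rewards + gamma * max_next_q_values                           # one-step TD target
q_loss = F.mse_loss(q_values, q_targets)                                  # same MSE objective as above

Note that F.mse_loss already reduces to a mean by default, so the extra torch.mean in the diff does not change the value.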

def train(self, buffer, turn_on=True):
raise NotImplementedError
train_info = {}

train_info["q_loss"] = 0

if self.world_size > 1:
train_info["reduced_q_loss"] = 0

# todo add rnn and transformer
# update once
for _ in range(1):
if "transformer" in self.algo_module.models:
raise NotImplementedError
elif self._use_recurrent_policy:
raise NotImplementedError
elif self._use_naive_recurrent:
raise NotImplementedError
else:
data_generator = buffer.feed_forward_generator(
_, self.num_mini_batch
)

for sample in data_generator:
q_loss = self.dqn_update(sample, turn_on)

if self.world_size > 1:
train_info["reduced_q_loss"] += reduce_tensor(
q_loss.data, self.world_size
)

train_info["q_loss"] += q_loss.item()

num_updates = 1 * self.num_mini_batch

for k in train_info.keys():
train_info[k] /= num_updates

for optimizer in self.algo_module.optimizers.values():
if hasattr(optimizer, "sync_lookahead"):
optimizer.sync_lookahead()

return train_info
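As a complement, a minimal sketch of the surrounding update step that dqn_update performs (zero the optimizers, backpropagate each loss in loss_list, clip gradients, then step); the stand-in network, optimizer, and batch below are assumptions, and the AMP, distributed-reduction, and target-network details of the real module are omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

q_net = nn.Linear(4, 2)                                 # stand-in for the module's "q_net"
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
max_grad_norm = 10.0                                    # assumed clipping threshold

obs = torch.randn(32, 4)
actions = torch.randint(0, 2, (32, 1))
q_targets = torch.randn(32, 1)                          # would come from prepare_loss above

optimizer.zero_grad()
loss_list = [F.mse_loss(q_net(obs).gather(1, actions), q_targets)]
for loss in loss_list:                                  # mirrors the loop over prepared losses
    loss.backward()
nn.utils.clip_grad_norm_(q_net.parameters(), max_grad_norm)
optimizer.step()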
30 changes: 30 additions & 0 deletions openrl/buffers/offpolicy_buffer.py
@@ -36,6 +36,36 @@ def __init__(
episode_length,
)

def insert(
self,
raw_obs,
next_raw_obs,
rnn_states,
rnn_states_critic,
actions,
action_log_probs,
value_preds,
rewards,
masks,
bad_masks=None,
active_masks=None,
available_actions=None,
):
self.data.insert(
raw_obs,
next_raw_obs,
rnn_states,
rnn_states_critic,
actions,
action_log_probs,
value_preds,
rewards,
masks,
bad_masks,
active_masks,
available_actions,
)

def get_buffer_size(self):
if self.data.first_insert_flag:
return self.data.step
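Finally, a toy sketch of the driver-side cycle the PR title refers to: transitions are stored through insert, and once enough data has accumulated, mini-batches are drawn and fed to the update. The class below only mimics the call pattern of the buffer's insert / get_buffer_size and the generator-based sampling used in train; it is not OpenRL's implementation, and the warm-up threshold is an assumption.

import random
from collections import deque

class ToyReplayBuffer:                                   # stand-in, not the class in this file
    def __init__(self, capacity=1000):
        self.storage = deque(maxlen=capacity)

    def insert(self, obs, next_obs, action, reward, mask):
        self.storage.append((obs, next_obs, action, reward, mask))

    def get_buffer_size(self):
        return len(self.storage)

    def feed_forward_generator(self, num_mini_batch, mini_batch_size=32):
        for _ in range(num_mini_batch):
            yield random.sample(self.storage, min(mini_batch_size, len(self.storage)))

buffer = ToyReplayBuffer()
for t in range(100):                                     # stand-in rollout collection
    buffer.insert(obs=t, next_obs=t + 1, action=0, reward=1.0, mask=1.0)

if buffer.get_buffer_size() > 64:                        # warm-up threshold (assumption)
    for sample in buffer.feed_forward_generator(num_mini_batch=4):
        pass                                             # each sample would be passed to dqn_update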