A2c matches sb3 (#161)
* test for a2c; a direct comparison with sb3 may still be needed (a sketch of such a comparison appears after the new test file below)

* faster runtime-error tests for the rlberry A2C agent

* ran black on the A2C agent and test file

* set the training horizon to 2 for the A2C agent tests

* Fixed test file by removing plotting

* removed the k_epochs param from A2C in the test_actor_critic file

Co-authored-by: Hector Kohler <[email protected]>
KohlerHECTOR and Hector Kohler authored Apr 13, 2022
1 parent 8168dfc commit 80679a9
Showing 3 changed files with 88 additions and 33 deletions.
53 changes: 22 additions & 31 deletions rlberry/agents/torch/a2c/a2c.py
@@ -40,8 +40,6 @@ class A2CAgent(AgentWithSimplePolicy):
         Learning rate.
     optimizer_type: str
         Type of optimizer. 'ADAM' by defaut.
-    k_epochs : int
-        Number of epochs per update.
     policy_net_fn : function(env, **kwargs)
         Function that returns an instance of a policy network (pytorch).
         If None, a default net is used.
@@ -79,7 +77,6 @@ def __init__(
         entr_coef=0.01,
         learning_rate=0.01,
         optimizer_type="ADAM",
-        k_epochs=5,
         policy_net_fn=None,
         value_net_fn=None,
         policy_net_kwargs=None,
@@ -103,7 +100,6 @@ def __init__(
         self.gamma = gamma
         self.entr_coef = entr_coef
         self.learning_rate = learning_rate
-        self.k_epochs = k_epochs
         self.device = choose_device(device)
 
         self.policy_net_kwargs = policy_net_kwargs or {}
@@ -243,33 +239,31 @@ def _update(self):
         old_states = torch.stack(self.memory.states).to(self.device).detach()
         old_actions = torch.stack(self.memory.actions).to(self.device).detach()
 
-        # optimize policy for K epochs
-        for _ in range(self.k_epochs):
-            # evaluate old actions and values
-            action_dist = self.cat_policy(old_states)
-            logprobs = action_dist.log_prob(old_actions)
-            state_values = torch.squeeze(self.value_net(old_states))
-            dist_entropy = action_dist.entropy()
-
-            # normalize the advantages
-            advantages = rewards - state_values.detach()
-            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
-            # find pg loss
-            pg_loss = -logprobs * advantages
-            loss = (
-                pg_loss
-                + 0.5 * self.MseLoss(state_values, rewards)
-                - self.entr_coef * dist_entropy
-            )
+        # evaluate old actions and values
+        action_dist = self.cat_policy(old_states)
+        logprobs = action_dist.log_prob(old_actions)
+        state_values = torch.squeeze(self.value_net(old_states))
+        dist_entropy = action_dist.entropy()
+
+        # normalize the advantages
+        advantages = rewards - state_values.detach()
+        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+        # find pg loss
+        pg_loss = -logprobs * advantages
+        loss = (
+            pg_loss
+            + 0.5 * self.MseLoss(state_values, rewards)
+            - self.entr_coef * dist_entropy
+        )
 
-            # take gradient step
-            self.policy_optimizer.zero_grad()
-            self.value_optimizer.zero_grad()
+        # take gradient step
+        self.policy_optimizer.zero_grad()
+        self.value_optimizer.zero_grad()
 
-            loss.mean().backward()
+        loss.mean().backward()
 
-            self.policy_optimizer.step()
-            self.value_optimizer.step()
+        self.policy_optimizer.step()
+        self.value_optimizer.step()
 
         # copy new weights into old policy
         self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())
@@ -285,12 +279,9 @@ def sample_parameters(cls, trial):
 
         entr_coef = trial.suggest_loguniform("entr_coef", 1e-8, 0.1)
 
-        k_epochs = trial.suggest_categorical("k_epochs", [1, 5, 10, 20])
-
         return {
             "batch_size": batch_size,
             "gamma": gamma,
             "learning_rate": learning_rate,
             "entr_coef": entr_coef,
-            "k_epochs": k_epochs,
         }
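
Note on the change above: removing the k_epochs loop means each collected rollout now triggers exactly one policy/value gradient step, which is how Stable-Baselines3 implements A2C (the multi-epoch update over a fixed batch is a PPO feature, not an A2C one). The snippet below is a minimal, self-contained sketch of that single-update step on toy tensors; the network sizes, rollout data, and hyperparameters are illustrative assumptions, not rlberry code.

# Sketch: one A2C gradient step per rollout (no k_epochs loop), mirroring the
# updated _update. All shapes, data, and hyperparameters are made up.
import torch
import torch.nn as nn
from torch.distributions import Categorical

obs_dim, n_actions, horizon = 4, 2, 8
policy = nn.Sequential(nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, n_actions))
value = nn.Sequential(nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, 1))
policy_opt = torch.optim.Adam(policy.parameters(), lr=1e-2)
value_opt = torch.optim.Adam(value.parameters(), lr=1e-2)
mse = nn.MSELoss()
entr_coef = 0.01

# fake rollout (in rlberry these come from self.memory and the discounted returns)
states = torch.randn(horizon, obs_dim)
actions = torch.randint(n_actions, (horizon,))
returns = torch.randn(horizon)

# single update per rollout, as in SB3's A2C
dist = Categorical(logits=policy(states))
logprobs = dist.log_prob(actions)
state_values = value(states).squeeze(-1)
advantages = returns - state_values.detach()
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
loss = (
    -logprobs * advantages
    + 0.5 * mse(state_values, returns)
    - entr_coef * dist.entropy()
)

policy_opt.zero_grad()
value_opt.zero_grad()
loss.mean().backward()
policy_opt.step()
value_opt.step()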
66 changes: 66 additions & 0 deletions rlberry/agents/torch/tests/test_a2c.py
@@ -0,0 +1,66 @@
+from rlberry.envs import Wrapper
+from rlberry.agents.torch import A2CAgent
+from rlberry.manager import AgentManager, evaluate_agents
+from rlberry.envs.benchmarks.ball_exploration import PBall2D
+from gym import make
+
+
+def test_a2c():
+
+    env = "CartPole-v0"
+    mdp = make(env)
+    env_ctor = Wrapper
+    env_kwargs = dict(env=mdp)
+
+    a2crlberry_stats = AgentManager(
+        A2CAgent,
+        (env_ctor, env_kwargs),
+        fit_budget=int(2),
+        eval_kwargs=dict(eval_horizon=2),
+        init_kwargs=dict(horizon=2),
+        n_fit=1,
+        agent_name="A2C_rlberry_" + env,
+    )
+
+    a2crlberry_stats.fit()
+
+    output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
+    a2crlberry_stats.clear_output_dir()
+
+    env = "Acrobot-v1"
+    mdp = make(env)
+    env_ctor = Wrapper
+    env_kwargs = dict(env=mdp)
+
+    a2crlberry_stats = AgentManager(
+        A2CAgent,
+        (env_ctor, env_kwargs),
+        fit_budget=int(2),
+        eval_kwargs=dict(eval_horizon=2),
+        init_kwargs=dict(horizon=2),
+        n_fit=1,
+        agent_name="A2C_rlberry_" + env,
+    )
+
+    a2crlberry_stats.fit()
+
+    output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
+    a2crlberry_stats.clear_output_dir()
+
+    env_ctor = PBall2D
+    env_kwargs = dict()
+
+    a2crlberry_stats = AgentManager(
+        A2CAgent,
+        (env_ctor, env_kwargs),
+        fit_budget=int(2),
+        eval_kwargs=dict(eval_horizon=2),
+        init_kwargs=dict(horizon=2),
+        n_fit=1,
+        agent_name="A2C_rlberry_" + "PBall2D",
+    )
+
+    a2crlberry_stats.fit()
+
+    output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
+    a2crlberry_stats.clear_output_dir()
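
The first commit-message bullet notes that a direct comparison with SB3 may still be needed. Below is a minimal sketch of what such a comparison could look like, assuming stable_baselines3 is installed; the budgets, horizons, and episode counts are illustrative placeholders (a real comparison would use realistic budgets and several seeds), and the rlberry side simply reuses the AgentManager pattern from the test above.

# Hypothetical follow-up: compare rlberry's A2C against Stable-Baselines3's A2C
# on the same environment. Budgets and evaluation settings are illustrative.
import gym
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy

from rlberry.envs import Wrapper
from rlberry.agents.torch import A2CAgent
from rlberry.manager import AgentManager, evaluate_agents

env_id = "CartPole-v0"

# SB3 baseline
sb3_model = A2C("MlpPolicy", env_id, verbose=0)
sb3_model.learn(total_timesteps=10_000)
sb3_mean, sb3_std = evaluate_policy(sb3_model, gym.make(env_id), n_eval_episodes=10)

# rlberry agent trained with a comparable budget
manager = AgentManager(
    A2CAgent,
    (Wrapper, dict(env=gym.make(env_id))),
    fit_budget=10_000,
    eval_kwargs=dict(eval_horizon=200),
    init_kwargs=dict(horizon=200),
    n_fit=1,
    agent_name="A2C_rlberry_" + env_id,
)
manager.fit()
rlberry_eval = evaluate_agents([manager], n_simulations=10, plot=False)

print("SB3 A2C:", sb3_mean, "+/-", sb3_std)
print("rlberry A2C:", rlberry_eval.mean())
manager.clear_output_dir()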
2 changes: 0 additions & 2 deletions rlberry/agents/torch/tests/test_actor_critic_algos.py
@@ -19,7 +19,6 @@ def uncertainty_estimator_fn(observation_space, action_space):
         horizon=horizon,
         gamma=0.99,
         learning_rate=0.001,
-        k_epochs=4,
         use_bonus=True,
         uncertainty_estimator_kwargs=dict(
             uncertainty_estimator_fn=uncertainty_estimator_fn, bonus_scale_factor=1.0
@@ -39,7 +38,6 @@ def test_a2c_agent_partial_fit():
         horizon=horizon,
         gamma=0.99,
         learning_rate=0.001,
-        k_epochs=4,
         use_bonus=False,
     )
 