diff --git a/rlberry/agents/torch/a2c/a2c.py b/rlberry/agents/torch/a2c/a2c.py
index 76f900abb..fd31c4a86 100644
--- a/rlberry/agents/torch/a2c/a2c.py
+++ b/rlberry/agents/torch/a2c/a2c.py
@@ -40,8 +40,6 @@ class A2CAgent(AgentWithSimplePolicy):
         Learning rate.
     optimizer_type: str
         Type of optimizer. 'ADAM' by defaut.
-    k_epochs : int
-        Number of epochs per update.
     policy_net_fn : function(env, **kwargs)
         Function that returns an instance of a policy network (pytorch).
         If None, a default net is used.
@@ -79,7 +77,6 @@ def __init__(
         entr_coef=0.01,
         learning_rate=0.01,
         optimizer_type="ADAM",
-        k_epochs=5,
         policy_net_fn=None,
         value_net_fn=None,
         policy_net_kwargs=None,
@@ -103,7 +100,6 @@ def __init__(
         self.gamma = gamma
         self.entr_coef = entr_coef
         self.learning_rate = learning_rate
-        self.k_epochs = k_epochs
         self.device = choose_device(device)

         self.policy_net_kwargs = policy_net_kwargs or {}
@@ -243,33 +239,31 @@ def _update(self):
         old_states = torch.stack(self.memory.states).to(self.device).detach()
         old_actions = torch.stack(self.memory.actions).to(self.device).detach()

-        # optimize policy for K epochs
-        for _ in range(self.k_epochs):
-            # evaluate old actions and values
-            action_dist = self.cat_policy(old_states)
-            logprobs = action_dist.log_prob(old_actions)
-            state_values = torch.squeeze(self.value_net(old_states))
-            dist_entropy = action_dist.entropy()
-
-            # normalize the advantages
-            advantages = rewards - state_values.detach()
-            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
-            # find pg loss
-            pg_loss = -logprobs * advantages
-            loss = (
-                pg_loss
-                + 0.5 * self.MseLoss(state_values, rewards)
-                - self.entr_coef * dist_entropy
-            )
+        # evaluate old actions and values
+        action_dist = self.cat_policy(old_states)
+        logprobs = action_dist.log_prob(old_actions)
+        state_values = torch.squeeze(self.value_net(old_states))
+        dist_entropy = action_dist.entropy()
+
+        # normalize the advantages
+        advantages = rewards - state_values.detach()
+        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+        # find pg loss
+        pg_loss = -logprobs * advantages
+        loss = (
+            pg_loss
+            + 0.5 * self.MseLoss(state_values, rewards)
+            - self.entr_coef * dist_entropy
+        )

-            # take gradient step
-            self.policy_optimizer.zero_grad()
-            self.value_optimizer.zero_grad()
+        # take gradient step
+        self.policy_optimizer.zero_grad()
+        self.value_optimizer.zero_grad()

-            loss.mean().backward()
+        loss.mean().backward()

-            self.policy_optimizer.step()
-            self.value_optimizer.step()
+        self.policy_optimizer.step()
+        self.value_optimizer.step()

         # copy new weights into old policy
         self.cat_policy_old.load_state_dict(self.cat_policy.state_dict())
@@ -285,12 +279,9 @@ def sample_parameters(cls, trial):

         entr_coef = trial.suggest_loguniform("entr_coef", 1e-8, 0.1)

-        k_epochs = trial.suggest_categorical("k_epochs", [1, 5, 10, 20])
-
         return {
             "batch_size": batch_size,
             "gamma": gamma,
             "learning_rate": learning_rate,
             "entr_coef": entr_coef,
-            "k_epochs": k_epochs,
         }
diff --git a/rlberry/agents/torch/tests/test_a2c.py b/rlberry/agents/torch/tests/test_a2c.py
new file mode 100644
index 000000000..da5adfaf0
--- /dev/null
+++ b/rlberry/agents/torch/tests/test_a2c.py
@@ -0,0 +1,66 @@
+from rlberry.envs import Wrapper
+from rlberry.agents.torch import A2CAgent
+from rlberry.manager import AgentManager, evaluate_agents
+from rlberry.envs.benchmarks.ball_exploration import PBall2D
+from gym import make
+
+
+def test_a2c():
+
+    env = "CartPole-v0"
+    mdp = make(env)
+    env_ctor = Wrapper
+    env_kwargs = dict(env=mdp)
+
+    a2crlberry_stats = AgentManager(
+        A2CAgent,
+        (env_ctor, env_kwargs),
+        fit_budget=int(2),
+        eval_kwargs=dict(eval_horizon=2),
+        init_kwargs=dict(horizon=2),
+        n_fit=1,
+        agent_name="A2C_rlberry_" + env,
+    )
+
+    a2crlberry_stats.fit()
+
+    output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
+    a2crlberry_stats.clear_output_dir()
+
+    env = "Acrobot-v1"
+    mdp = make(env)
+    env_ctor = Wrapper
+    env_kwargs = dict(env=mdp)
+
+    a2crlberry_stats = AgentManager(
+        A2CAgent,
+        (env_ctor, env_kwargs),
+        fit_budget=int(2),
+        eval_kwargs=dict(eval_horizon=2),
+        init_kwargs=dict(horizon=2),
+        n_fit=1,
+        agent_name="A2C_rlberry_" + env,
+    )
+
+    a2crlberry_stats.fit()
+
+    output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
+    a2crlberry_stats.clear_output_dir()
+
+    env_ctor = PBall2D
+    env_kwargs = dict()
+
+    a2crlberry_stats = AgentManager(
+        A2CAgent,
+        (env_ctor, env_kwargs),
+        fit_budget=int(2),
+        eval_kwargs=dict(eval_horizon=2),
+        init_kwargs=dict(horizon=2),
+        n_fit=1,
+        agent_name="A2C_rlberry_" + "PBall2D",
+    )
+
+    a2crlberry_stats.fit()
+
+    output = evaluate_agents([a2crlberry_stats], n_simulations=2, plot=False)
+    a2crlberry_stats.clear_output_dir()
diff --git a/rlberry/agents/torch/tests/test_actor_critic_algos.py b/rlberry/agents/torch/tests/test_actor_critic_algos.py
index c76b83728..f1b83f001 100644
--- a/rlberry/agents/torch/tests/test_actor_critic_algos.py
+++ b/rlberry/agents/torch/tests/test_actor_critic_algos.py
@@ -19,7 +19,6 @@ def uncertainty_estimator_fn(observation_space, action_space):
         horizon=horizon,
         gamma=0.99,
         learning_rate=0.001,
-        k_epochs=4,
         use_bonus=True,
         uncertainty_estimator_kwargs=dict(
             uncertainty_estimator_fn=uncertainty_estimator_fn, bonus_scale_factor=1.0
@@ -39,7 +38,6 @@ def test_a2c_agent_partial_fit():
         horizon=horizon,
         gamma=0.99,
         learning_rate=0.001,
-        k_epochs=4,
         use_bonus=False,
     )