-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Test for A2C; may need a direct comparison with SB3.
* Faster runtime error tests for the rlberry A2C agent.
* Formatted the A2C agent and test file with black.
* Fixed the training horizon to 2 for the testing of the A2C agent.
* Fixed the test file by removing plotting.
* Removed the k_epochs param from A2C in the test_actor_critic file.

Co-authored-by: Hector Kohler <[email protected]>
- Loading branch information
1 parent
8168dfc
commit 80679a9
Showing
3 changed files
with
88 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from rlberry.envs import Wrapper | ||
from rlberry.agents.torch import A2CAgent | ||
from rlberry.manager import AgentManager, evaluate_agents | ||
from rlberry.envs.benchmarks.ball_exploration import PBall2D | ||
from gym import make | ||
|
||
|
||
def _run_a2c_smoke_test(env_ctor, env_kwargs, env_name):
    """Fit and evaluate a minimal-budget A2C agent on a single environment.

    This is a runtime-error check only: budgets and horizons are kept at 2
    so the run is fast, and no learning performance is asserted.

    Parameters
    ----------
    env_ctor : callable
        Environment constructor passed to ``AgentManager``.
    env_kwargs : dict
        Keyword arguments forwarded to ``env_ctor``.
    env_name : str
        Suffix appended to ``"A2C_rlberry_"`` to build the agent name.
    """
    manager = AgentManager(
        A2CAgent,
        (env_ctor, env_kwargs),
        fit_budget=2,
        eval_kwargs=dict(eval_horizon=2),
        init_kwargs=dict(horizon=2),
        n_fit=1,
        agent_name="A2C_rlberry_" + env_name,
    )
    try:
        manager.fit()
        evaluate_agents([manager], n_simulations=2, plot=False)
    finally:
        # Clean up even when fitting/evaluation raises, so a failing run
        # does not leave the manager's output directory behind.
        manager.clear_output_dir()


def test_a2c():
    """Check that A2CAgent fits and evaluates without error on three envs.

    Covers two gym environments wrapped with ``Wrapper`` ("CartPole-v0" and
    "Acrobot-v1") and the rlberry ``PBall2D`` benchmark, in that order.
    """
    for env_id in ("CartPole-v0", "Acrobot-v1"):
        _run_a2c_smoke_test(Wrapper, dict(env=make(env_id)), env_id)

    _run_a2c_smoke_test(PBall2D, dict(), "PBall2D")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters