-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added sac early version (gym 0.21) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rtd ? * Moved sac to gymnasium * Updated sac demo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Some docstring updates * Added demo_SAC docstring * Fix logging * Added tests for continus actions agents * Fixed some attribute names, passes rlberry tests * CI * Added test_sac to tests * Removed old experimental folder, updated changelog * Added test with learning * Fixed test params * Clean docstring * Add sac to hp opt test * Add alpha test for sac * Changed variable name (Pascal to snake case) * Update api.rst Added SACAgent to docu * Update api.rst * Update api.rst added next line * Update api.rst * Update api.rst reverted * Revert "Add alpha test for sac" This reverts commit 5ed8121. * Changelog * Changelog * Doc fix attempt * Remove rtd change * Update documentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Timothee Mathieu <[email protected]> Co-authored-by: Hector Kohler <[email protected]>
- Loading branch information
1 parent
0cbbb24
commit 72d88d5
Showing
17 changed files
with
707 additions
and
572 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,43 @@ | ||
""" | ||
============================= | ||
Record reward during training | ||
SAC Soft Actor-Critic | ||
============================= | ||
This script shows how to modify an agent to easily record reward or action | ||
during the fit of the agent. | ||
This script shows how to train a SAC agent on a Pendulum environment. | ||
""" | ||
|
||
import time | ||
|
||
# import numpy as np | ||
# from rlberry.wrappers import WriterWrapper | ||
from rlberry.envs.basewrapper import Wrapper | ||
|
||
# from rlberry.envs import gym_make | ||
from rlberry.manager import plot_writer_data, ExperimentManager | ||
from rlberry.envs.benchmarks.ball_exploration import PBall2D | ||
from rlberry.agents.experimental.torch import SACAgent | ||
import gymnasium as gym | ||
from rlberry.agents.torch.sac import SACAgent | ||
from rlberry.envs import Pendulum | ||
from rlberry.manager import AgentManager | ||
|
||
|
||
# we dont need wrapper actually just 'return env' works | ||
def env_ctor(env, wrap_spaces=True): | ||
return Wrapper(env, wrap_spaces) | ||
return env | ||
|
||
|
||
# Setup agent parameters | ||
env_name = "Pendulum" | ||
fit_budget = int(2e5) | ||
agent_name = f"{env_name}_{fit_budget}_{int(time.time())}" | ||
|
||
env = PBall2D() | ||
env = gym.wrappers.TimeLimit(env, max_episode_steps=100) | ||
# Setup environment parameters | ||
env = Pendulum() | ||
env = gym.wrappers.TimeLimit(env, max_episode_steps=200) | ||
env = gym.wrappers.RecordEpisodeStatistics(env) | ||
env_kwargs = dict(env=env) | ||
agent = ExperimentManager( | ||
|
||
# Create agent instance | ||
agent = AgentManager( | ||
SACAgent, | ||
(env_ctor, env_kwargs), | ||
fit_budget=500, | ||
fit_budget=fit_budget, | ||
n_fit=1, | ||
enable_tensorboard=True, | ||
agent_name=agent_name, | ||
) | ||
|
||
# basic version | ||
# env_kwargs = dict(id = "CartPole-v1") | ||
# agent = ExperimentManager(SACAgent, (gym_make, env_kwargs), fit_budget=200, n_fit=1) | ||
|
||
# # timothe's | ||
# env = gym_make("CartPole-v1") | ||
# agent = ExperimentManager( | ||
# SACAgent, (env.__class__, dict()), fit_budget=200, n_fit=1, enable_tensorboard=True, | ||
# ) | ||
|
||
# Omar's | ||
# env = gym_make("CartPole-v1") | ||
# from copy import deepcopy | ||
# def env_constructor(): | ||
# return deepcopy(env) | ||
# agent = ExperimentManager( | ||
# SACAgent, (env_constructor, dict()), fit_budget=200, n_fit=1, enable_tensorboard=True, | ||
# ) | ||
|
||
|
||
# Start training | ||
agent.fit() | ||
|
||
# Plot of the cumulative reward. | ||
output = plot_writer_data(agent, tag="loss_q1", title="Loss q1") | ||
output = plot_writer_data(agent, tag="loss_q2", title="Loss q2") | ||
output = plot_writer_data(agent, tag="loss_v", title="Loss critic") | ||
output = plot_writer_data(agent, tag="loss_act", title="Loss actor") |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +0,0 @@ | ||
from .sac import SACAgent | ||
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.