Add Soft Actor-Critic (SAC) (#326)
* Added sac early version (gym 0.21)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix rtd?

* Moved sac to gymnasium

* Updated sac demo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Some docstring updates

* Added demo_SAC docstring

* Fix logging

* Added tests for continuous-action agents

* Fixed some attribute names, passes rlberry tests

* CI

* Added test_sac to tests

* Removed old experimental folder, updated changelog

* Added test with learning

* Fixed test params

* Clean docstring

* Add sac to hp opt test

* Add alpha test for sac

* Changed variable name (Pascal to snake case)

* Update api.rst

Added SACAgent to documentation

* Update api.rst

* Update api.rst

added next line

* Update api.rst

* Update api.rst

reverted

* Revert "Add alpha test for sac"

This reverts commit 5ed8121.

* Changelog

* Changelog

* Doc fix attempt

* Remove rtd change

* Update documentation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Timothee Mathieu <[email protected]>
Co-authored-by: Hector Kohler <[email protected]>
4 people authored Jul 24, 2023
1 parent 0cbbb24 commit 72d88d5
Showing 17 changed files with 707 additions and 572 deletions.
9 changes: 1 addition & 8 deletions docs/api.rst
@@ -81,20 +81,13 @@ Torch Agents
    :toctree: generated/
    :template: class.rst

+   agents.torch.SACAgent
    agents.torch.A2CAgent
    agents.torch.PPOAgent
    agents.torch.DQNAgent
    agents.torch.MunchausenDQNAgent
    agents.torch.REINFORCEAgent

-Experimental torch agents
--------------------------
-
-.. autosummary::
-   :toctree: generated/
-   :template: class.rst
-
-   agents.experimental.torch.SACAgent

 Environments
 ============
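In user code, this api.rst change corresponds to importing SAC from the stable torch agents instead of the removed experimental package. A minimal before/after sketch (the exact submodule path is taken from the updated demo_SAC.py below; the top-level agents.torch.SACAgent listed in api.rst would presumably also resolve):

# Old location, deleted by this PR:
# from rlberry.agents.experimental.torch import SACAgent

# New location after this PR:
from rlberry.agents.torch.sac import SACAgent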
5 changes: 4 additions & 1 deletion docs/changelog.rst
@@ -5,11 +5,14 @@ Changelog

 Dev version
 -----------
-
+*PR #326*
+
+* Moved SAC from experimental to torch agents. Tested and benchmarked.
+
 *PR #335*

 * Upgrade from Python3.9 -> python3.10


 Version 0.5.0
 -------------
65 changes: 22 additions & 43 deletions examples/demo_agents/demo_SAC.py
@@ -1,64 +1,43 @@
"""
=============================
Record reward during training
SAC Soft Actor-Critic
=============================
This script shows how to modify an agent to easily record reward or action
during the fit of the agent.
This script shows how to train a SAC agent on a Pendulum environment.
"""

import time

# import numpy as np
# from rlberry.wrappers import WriterWrapper
from rlberry.envs.basewrapper import Wrapper

# from rlberry.envs import gym_make
from rlberry.manager import plot_writer_data, ExperimentManager
from rlberry.envs.benchmarks.ball_exploration import PBall2D
from rlberry.agents.experimental.torch import SACAgent
import gymnasium as gym
from rlberry.agents.torch.sac import SACAgent
from rlberry.envs import Pendulum
from rlberry.manager import AgentManager


# we dont need wrapper actually just 'return env' works
def env_ctor(env, wrap_spaces=True):
return Wrapper(env, wrap_spaces)
return env


# Setup agent parameters
env_name = "Pendulum"
fit_budget = int(2e5)
agent_name = f"{env_name}_{fit_budget}_{int(time.time())}"

env = PBall2D()
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
# Setup environment parameters
env = Pendulum()
env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
env = gym.wrappers.RecordEpisodeStatistics(env)
env_kwargs = dict(env=env)
agent = ExperimentManager(

# Create agent instance
agent = AgentManager(
SACAgent,
(env_ctor, env_kwargs),
fit_budget=500,
fit_budget=fit_budget,
n_fit=1,
enable_tensorboard=True,
agent_name=agent_name,
)

# basic version
# env_kwargs = dict(id = "CartPole-v1")
# agent = ExperimentManager(SACAgent, (gym_make, env_kwargs), fit_budget=200, n_fit=1)

# # timothe's
# env = gym_make("CartPole-v1")
# agent = ExperimentManager(
# SACAgent, (env.__class__, dict()), fit_budget=200, n_fit=1, enable_tensorboard=True,
# )

# Omar's
# env = gym_make("CartPole-v1")
# from copy import deepcopy
# def env_constructor():
# return deepcopy(env)
# agent = ExperimentManager(
# SACAgent, (env_constructor, dict()), fit_budget=200, n_fit=1, enable_tensorboard=True,
# )


# Start training
agent.fit()

# Plot of the cumulative reward.
output = plot_writer_data(agent, tag="loss_q1", title="Loss q1")
output = plot_writer_data(agent, tag="loss_q2", title="Loss q2")
output = plot_writer_data(agent, tag="loss_v", title="Loss critic")
output = plot_writer_data(agent, tag="loss_act", title="Loss actor")
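For readability, here is the updated demo assembled from the added lines above into a standalone script. This is a sketch, not the verbatim file: blank lines and comments may differ, and it assumes the rlberry and gymnasium APIs exactly as they appear in the diff (Pendulum, AgentManager, SACAgent, TimeLimit, RecordEpisodeStatistics).

"""
=============================
SAC Soft Actor-Critic
=============================
This script shows how to train a SAC agent on a Pendulum environment.
"""

import time

import gymnasium as gym
from rlberry.agents.torch.sac import SACAgent
from rlberry.envs import Pendulum
from rlberry.manager import AgentManager


# The environment constructor no longer needs a wrapper; returning env as-is works.
def env_ctor(env, wrap_spaces=True):
    return env


# Agent parameters
env_name = "Pendulum"
fit_budget = int(2e5)
agent_name = f"{env_name}_{fit_budget}_{int(time.time())}"

# Environment: Pendulum with a 200-step time limit and episode-statistics recording
env = Pendulum()
env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
env = gym.wrappers.RecordEpisodeStatistics(env)
env_kwargs = dict(env=env)

# Manager that fits a single SAC instance and logs to TensorBoard
agent = AgentManager(
    SACAgent,
    (env_ctor, env_kwargs),
    fit_budget=fit_budget,
    n_fit=1,
    enable_tensorboard=True,
    agent_name=agent_name,
)

# Train
agent.fit()

With enable_tensorboard=True the training curves go to TensorBoard; the old demo's plot_writer_data calls (tags loss_q1, loss_q2, loss_v, loss_act) could presumably still be applied to the manager after fit(), if the rewritten SAC keeps logging those writer tags.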
48 changes: 0 additions & 48 deletions rlberry/agents/experimental/tests/test_sac.py

This file was deleted.

1 change: 0 additions & 1 deletion rlberry/agents/experimental/torch/__init__.py
@@ -1 +0,0 @@
-from .sac import SACAgent
8 changes: 0 additions & 8 deletions rlberry/agents/experimental/torch/sac/improvements.txt

This file was deleted.
