
Add Soft Actor-Critic (SAC) #326

Merged — 38 commits merged on Jul 24, 2023
Commits (38)
f63704d
Added sac early version (gym 0.21)
brahimdriss Jun 29, 2023
95f8b53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2023
49e574a
fix rtd ?
TimotheeMathieu Jun 30, 2023
ea87a05
Merge pull request #1 from TimotheeMathieu/doc
brahimdriss Jun 30, 2023
31e597a
Moved sac to gymnasium
brahimdriss Jun 30, 2023
9997f15
Merge branch 'sac' of https://github.com/brahimdriss/rlberry into sac
brahimdriss Jun 30, 2023
ed3214c
Updated sac demo
brahimdriss Jun 30, 2023
4692f05
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
98cb7ec
Some docstring updates
brahimdriss Jun 30, 2023
d6e0637
Added demo_SAC docstring
brahimdriss Jul 3, 2023
e45de74
Fix logging
brahimdriss Jul 3, 2023
aad4791
Merge branch 'sac' of https://github.com/brahimdriss/rlberry into sac
brahimdriss Jul 3, 2023
9a14784
Added tests for continuous actions agents
brahimdriss Jul 3, 2023
7657390
Fixed some attribute names, passes rlberry tests
brahimdriss Jul 3, 2023
3c3c1d6
CI
brahimdriss Jul 4, 2023
a726e2d
Added test_sac to tests
brahimdriss Jul 7, 2023
be14eb0
Removed old experimental folder, updated changelog
brahimdriss Jul 7, 2023
604a0e1
Added test with learning
brahimdriss Jul 7, 2023
7667f8e
Fixed test params
brahimdriss Jul 7, 2023
21b4d5c
Clean docstring
brahimdriss Jul 7, 2023
18d8eed
Add sac to hp opt test
brahimdriss Jul 7, 2023
5ed8121
Add alpha test for sac
brahimdriss Jul 7, 2023
9dfd91a
Changed variable name (Pascal to snake case)
brahimdriss Jul 11, 2023
456c6f5
Update api.rst
KohlerHECTOR Jul 24, 2023
a83b831
Update api.rst
KohlerHECTOR Jul 24, 2023
bb6b479
Merge branch 'main' into sac
KohlerHECTOR Jul 24, 2023
1516183
Update api.rst
KohlerHECTOR Jul 24, 2023
e9c9f60
Update api.rst
KohlerHECTOR Jul 24, 2023
47c9865
Update api.rst
KohlerHECTOR Jul 24, 2023
f812aec
Revert "Add alpha test for sac"
brahimdriss Jul 24, 2023
d12cdd8
Merge branch 'sac' of https://github.com/brahimdriss/rlberry into sac
brahimdriss Jul 24, 2023
030e7b6
Changelog
brahimdriss Jul 24, 2023
d6efcbe
Changelog
brahimdriss Jul 24, 2023
78cb8ee
Doc fix attempt
brahimdriss Jul 24, 2023
8aa654a
Remove rtd change
brahimdriss Jul 24, 2023
99e2ccb
Update documentation
brahimdriss Jul 24, 2023
a70e492
Merge branch 'main' into sac
brahimdriss Jul 24, 2023
16ec4b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2023
9 changes: 1 addition & 8 deletions docs/api.rst
@@ -81,20 +81,13 @@ Torch Agents
:toctree: generated/
:template: class.rst

agents.torch.SACAgent
agents.torch.A2CAgent
agents.torch.PPOAgent
agents.torch.DQNAgent
agents.torch.MunchausenDQNAgent
agents.torch.REINFORCEAgent

Experimental torch agents
-------------------------

.. autosummary::
:toctree: generated/
:template: class.rst

agents.experimental.torch.SACAgent

Environments
============
5 changes: 4 additions & 1 deletion docs/changelog.rst
@@ -5,11 +5,14 @@ Changelog

Dev version
-----------
*PR #326*

* Moved SAC from experimental to torch agents. Tested and benchmarked.

*PR #335*

* Upgrade from Python3.9 -> python3.10


Version 0.5.0
-------------

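The changelog entry above records the user-visible change: SACAgent moved out of the experimental namespace. A minimal sketch of coping with both layouts — the `import_sac_agent` helper is an illustration for this note, not part of the PR, and it simply returns `None` when rlberry is not installed:

```python
import importlib


def import_sac_agent():
    """Return the SACAgent class from whichever namespace this install provides.

    After PR #326 the agent lives under rlberry.agents.torch.sac; before the
    PR it lived under rlberry.agents.experimental.torch. Tries the new path
    first, then falls back; returns None if neither import succeeds.
    """
    for path in ("rlberry.agents.torch.sac", "rlberry.agents.experimental.torch"):
        try:
            return getattr(importlib.import_module(path), "SACAgent")
        except (ImportError, AttributeError):
            continue
    return None
```

On a post-PR install the first path resolves; on a pre-PR install the fallback keeps old scripts working.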
65 changes: 22 additions & 43 deletions examples/demo_agents/demo_SAC.py
@@ -1,64 +1,43 @@
"""
=============================
Record reward during training
SAC Soft Actor-Critic
=============================

This script shows how to modify an agent to easily record reward or action
during the fit of the agent.
This script shows how to train a SAC agent on a Pendulum environment.
"""

import time

# import numpy as np
# from rlberry.wrappers import WriterWrapper
from rlberry.envs.basewrapper import Wrapper

# from rlberry.envs import gym_make
from rlberry.manager import plot_writer_data, ExperimentManager
from rlberry.envs.benchmarks.ball_exploration import PBall2D
from rlberry.agents.experimental.torch import SACAgent
import gymnasium as gym
from rlberry.agents.torch.sac import SACAgent
from rlberry.envs import Pendulum
from rlberry.manager import AgentManager


# We don't need a wrapper here; simply returning the env works.
def env_ctor(env, wrap_spaces=True):
return Wrapper(env, wrap_spaces)
return env


# Setup agent parameters
env_name = "Pendulum"
fit_budget = int(2e5)
agent_name = f"{env_name}_{fit_budget}_{int(time.time())}"

env = PBall2D()
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
# Setup environment parameters
env = Pendulum()
env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
env = gym.wrappers.RecordEpisodeStatistics(env)
env_kwargs = dict(env=env)
agent = ExperimentManager(

# Create agent instance
agent = AgentManager(
SACAgent,
(env_ctor, env_kwargs),
fit_budget=500,
fit_budget=fit_budget,
n_fit=1,
enable_tensorboard=True,
agent_name=agent_name,
)

# basic version
# env_kwargs = dict(id = "CartPole-v1")
# agent = ExperimentManager(SACAgent, (gym_make, env_kwargs), fit_budget=200, n_fit=1)

# # timothe's
# env = gym_make("CartPole-v1")
# agent = ExperimentManager(
# SACAgent, (env.__class__, dict()), fit_budget=200, n_fit=1, enable_tensorboard=True,
# )

# Omar's
# env = gym_make("CartPole-v1")
# from copy import deepcopy
# def env_constructor():
# return deepcopy(env)
# agent = ExperimentManager(
# SACAgent, (env_constructor, dict()), fit_budget=200, n_fit=1, enable_tensorboard=True,
# )


# Start training
agent.fit()

# Plot of the cumulative reward.
output = plot_writer_data(agent, tag="loss_q1", title="Loss q1")
output = plot_writer_data(agent, tag="loss_q2", title="Loss q2")
output = plot_writer_data(agent, tag="loss_v", title="Loss critic")
output = plot_writer_data(agent, tag="loss_act", title="Loss actor")
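Pieced together from the added lines of the diff above, the updated demo amounts to the following sketch. The `build_demo` helper and its deferred imports are conveniences of this sketch rather than part of the PR, so rlberry and gymnasium are only required when `make()` is actually called:

```python
import time


def build_demo(fit_budget=int(2e5), max_episode_steps=200):
    """Assemble the configuration used by the updated demo_SAC.py.

    Returns the agent name, the fit budget, and a deferred `make` callable
    that builds the AgentManager exactly as the diff does.
    """
    env_name = "Pendulum"
    agent_name = f"{env_name}_{fit_budget}_{int(time.time())}"

    def make():
        # Deferred imports: only needed when the demo is actually run.
        import gymnasium as gym
        from rlberry.agents.torch.sac import SACAgent
        from rlberry.envs import Pendulum
        from rlberry.manager import AgentManager

        env = Pendulum()
        env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
        env = gym.wrappers.RecordEpisodeStatistics(env)

        def env_ctor(env, wrap_spaces=True):
            # The wrapper is unnecessary; returning the env is enough.
            return env

        return AgentManager(
            SACAgent,
            (env_ctor, dict(env=env)),
            fit_budget=fit_budget,
            n_fit=1,
            enable_tensorboard=True,
            agent_name=agent_name,
        )

    return {"agent_name": agent_name, "fit_budget": fit_budget, "make": make}
```

Calling `build_demo()["make"]().fit()` then reproduces the demo's training run, with losses logged to TensorBoard under the timestamped agent name.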
48 changes: 0 additions & 48 deletions rlberry/agents/experimental/tests/test_sac.py

This file was deleted.

1 change: 0 additions & 1 deletion rlberry/agents/experimental/torch/__init__.py
@@ -1 +0,0 @@
from .sac import SACAgent
8 changes: 0 additions & 8 deletions rlberry/agents/experimental/torch/sac/improvements.txt

This file was deleted.
