Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TRPO #40

Merged
merged 33 commits into from
Dec 29, 2021
Merged
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f779a9f
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 7, 2021
98bc5b2
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 9, 2021
97ece67
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 17, 2021
799b140
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 17, 2021
dc73462
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 19, 2021
9b8a222
feat: TRPO - addressing PR comments
cyprienc Sep 11, 2021
869dce9
refactor: TRPO - policier
cyprienc Sep 11, 2021
347dcc0
feat: using updated ActorCriticPolicy from SB3
cyprienc Sep 11, 2021
35d7256
Bump version for `get_distribution` support
araffin Sep 13, 2021
9cfcb54
Add basic test
araffin Sep 13, 2021
974174a
Reformat
araffin Sep 13, 2021
b6bd449
[ci skip] Fix changelog
araffin Sep 13, 2021
c88951c
fix: setting train mode for trpo
cyprienc Sep 13, 2021
1f7e99d
fix: batch_size type hint in trpo.py
cyprienc Sep 13, 2021
6540371
style: renaming variables + docstring in trpo.py
cyprienc Sep 15, 2021
3a26c05
Merge branch 'master' into master
araffin Sep 23, 2021
f003e88
Merge branch 'master' into master
araffin Sep 27, 2021
a33409e
Merge branch 'master' into master
araffin Sep 29, 2021
8ecf40e
Rename + cleanup
araffin Sep 29, 2021
45f4ea6
Move grad computation to separate method
araffin Sep 29, 2021
cc4b5ab
Remove grad norm clipping
araffin Sep 29, 2021
fc7a6c7
Remove n epochs and add sub-sampling
araffin Sep 29, 2021
66723ff
Update defaults
araffin Sep 29, 2021
63a263f
Merge branch 'master' into master
araffin Dec 1, 2021
bf583de
Merge branch 'master' into cyprienc/master
araffin Dec 10, 2021
e983348
Add Doc
araffin Dec 27, 2021
439d79b
Add more test and fixes for CNN
araffin Dec 27, 2021
d9483dc
Update doc + add benchmark
araffin Dec 28, 2021
fff84e4
Add tests + update doc
araffin Dec 28, 2021
95dddf4
Fix doc
araffin Dec 28, 2021
661fe15
Improve names for conjugate gradient
araffin Dec 29, 2021
a24e7c0
Update comments
araffin Dec 29, 2021
342fe53
Update changelog
araffin Dec 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add basic test
araffin committed Sep 13, 2021
commit 9cfcb540c9525c57789c17b2174c56c6384a1ba5
8 changes: 7 additions & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from sb3_contrib import QRDQN, TQC
from sb3_contrib import QRDQN, TQC, TRPO


@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
@@ -56,3 +56,9 @@ def test_qrdqn():
create_eval_env=True,
)
model.learn(total_timesteps=500, eval_freq=250)


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
def test_trpo(env_id):
model = TRPO("MlpPolicy", env_id, n_steps=64, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
model.learn(total_timesteps=500)