Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release v1.0 #69

Merged
merged 31 commits into from
Mar 11, 2021
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ac890a2
Update sac hyperparams
araffin Mar 1, 2021
d033c0a
Bump version
araffin Mar 1, 2021
fa335df
Move real robot hyperparams
araffin Mar 1, 2021
27edfcb
Update HER params
araffin Mar 2, 2021
1e0c8bd
Fix for HER action noise
araffin Mar 2, 2021
84bd43c
Update benchmark file
araffin Mar 2, 2021
1d3b8a7
Update formatting
araffin Mar 2, 2021
599ba3f
Catch errors when benchmarking
araffin Mar 3, 2021
cb509ac
Use subprocess only if needed
araffin Mar 3, 2021
df76707
Change default number of threads for bench
araffin Mar 3, 2021
994cc30
Add pre-trained agents
araffin Mar 4, 2021
6172d4a
Catch keyboard interrupt for enjoy
araffin Mar 5, 2021
4a912b4
Update benchmark
araffin Mar 5, 2021
2395593
Update README and changelog
araffin Mar 5, 2021
8a41755
Merge branch 'feat/release-v1.0rc0' of github.com:DLR-RM/rl-baselines…
araffin Mar 5, 2021
1b9611c
Tuned DDPG hyperparam
araffin Mar 5, 2021
ef04597
Update TD3 hyperparams
araffin Mar 5, 2021
597a304
Minor edit
araffin Mar 5, 2021
4c2acb1
Add Reacher
araffin Mar 5, 2021
164c2d1
Update table
araffin Mar 5, 2021
7441e47
Ugrade SB3
araffin Mar 6, 2021
4cadd46
Add support for loading saved models with python 3.8
araffin Mar 6, 2021
52693c6
Upgrade SB3
araffin Mar 6, 2021
9debc19
Add BipedalWalkerHardcore
araffin Mar 8, 2021
190adf4
Merge branch 'feat/release-v1.0rc0' of github.com:DLR-RM/rl-baselines…
araffin Mar 8, 2021
d5f75ff
Changed pybullet version in CI
araffin Mar 8, 2021
a02f4fb
Add more Atari games
araffin Mar 9, 2021
ae953f4
Update README
araffin Mar 9, 2021
40f3b5b
Add benchmark files
araffin Mar 9, 2021
1c903c6
Add QR-DQN Enduro
araffin Mar 11, 2021
11f6266
Update README + bug fix for HER enjoy
araffin Mar 11, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add support for loading saved models with python 3.8
araffin committed Mar 6, 2021
commit 4cadd46d5dc76cb2ff4d1954d281cf36cfd8f126
3 changes: 1 addition & 2 deletions .github/workflows/trained_agents.yml
Original file line number Diff line number Diff line change
@@ -16,8 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7] # 3.8 not supported yet due to cloudpickle errors

python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
with:
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@

### New Features
- Added 90+ trained agents + benchmark file
- Add support for loading saved model under python 3.8+ (no retraining possible)

### Bug fixes
- Bug fixes for `HER` handling action noise
17 changes: 15 additions & 2 deletions enjoy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import importlib
import os
import sys

import numpy as np
import torch as th
@@ -140,7 +141,19 @@ def main(): # noqa: C901
# Dummy buffer size as we don't need memory to enjoy the trained agent
kwargs.update(dict(buffer_size=1))

model = ALGOS[algo].load(model_path, env=env, **kwargs)
# Check if we are running python 3.8+
# we need to patch saved model under python 3.6/3.7 to load them
newer_python_version = sys.version_info.major == 3 and sys.version_info.minor >= 8

custom_objects = {}
if newer_python_version:
custom_objects = {
"learning_rate": 0.0,
"lr_schedule": lambda _: 0.0,
"clip_range": lambda _: 0.0,
}

model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs)

obs = env.reset()

@@ -197,7 +210,7 @@ def main(): # noqa: C901

except KeyboardInterrupt:
pass

if args.verbose > 0 and len(successes) > 0:
print(f"Success rate: {100 * np.mean(successes):.2f}%")