ATARI environments - part1 #277

Merged: 88 commits, Apr 12, 2023

Commits
3e874be
add requirements to use atari game
JulienT01 Feb 15, 2023
74686f6
update atari_make, and the wrapper scalarize, to manage env from atar…
JulienT01 Feb 15, 2023
b6739d5
update training and models to manage cnn (mandatory for atari games)
JulienT01 Feb 15, 2023
32a5703
add tests on atari games (test the cnn part in dqn and ppo too)
JulienT01 Feb 15, 2023
54fc405
add example with video for the documentation
JulienT01 Feb 15, 2023
469a86c
black
JulienT01 Feb 15, 2023
92286b9
black
JulienT01 Feb 15, 2023
ae0ed61
update setup.py
JulienT01 Feb 15, 2023
b701b29
add pytest-xprocess to run test_server.py
JulienT01 Feb 16, 2023
ae587f7
add configfiles to .gitignore
JulienT01 Feb 16, 2023
dac452f
change to fixed version image azure
TimotheeMathieu Feb 28, 2023
c9c0248
Merge branch 'rlberry-py:main' into Atari_part1
JulienT01 Feb 28, 2023
9cc2efe
remove accelerate
TimotheeMathieu Feb 28, 2023
8206d53
Update README.md
JulienT01 Mar 15, 2023
daae996
Merge remote-tracking branch 'origin/main' into Atari_part1
JulienT01 Mar 15, 2023
a21f5cd
temporary correction until main branch update
JulienT01 Mar 15, 2023
3d242ae
Merge remote-tracking branch 'origin/main' into Atari_part1
JulienT01 Mar 15, 2023
91ea0d1
xfail on tests that failed on Mac and windows
JulienT01 Mar 17, 2023
40ddbab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2023
314f532
Update README.md
JulienT01 Mar 17, 2023
7b3a692
optuna more graceful cleaning
TimotheeMathieu Mar 17, 2023
dbe3e33
shutils rmtree to os.rmdir
TimotheeMathieu Mar 20, 2023
4bf876b
fix optuna ?
TimotheeMathieu Mar 20, 2023
4e297da
trigger ci
TimotheeMathieu Mar 20, 2023
1d15711
add xfail for windows CI
JulienT01 Mar 30, 2023
238ca0b
Merge branch 'xfail_tests_mac_windows' of github.com:JulienT01/rlberr…
JulienT01 Mar 30, 2023
585739d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2023
cdd5e2e
Merge remote-tracking branch 'origin/xfail_tests_mac_windows' into At…
JulienT01 Mar 30, 2023
5014539
add xfail for windows CI
JulienT01 Mar 30, 2023
a4c3d53
Merge remote-tracking branch 'origin/xfail_tests_mac_windows' into At…
JulienT01 Mar 30, 2023
8a793cb
Merge branch 'rlberry-py:main' into Atari_part1
JulienT01 Mar 30, 2023
1f7c57a
empty commit trigger ci
JulienT01 Mar 30, 2023
79654ab
remove main from tests file
JulienT01 Mar 30, 2023
15ab700
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2023
7bd3a17
add xfail for windows CI
JulienT01 Mar 30, 2023
96bf506
Merge branch 'xfail_tests_mac_windows' of github.com:JulienT01/rlberr…
JulienT01 Mar 30, 2023
10cff9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2023
e5dd32a
add xfail for windows CI
JulienT01 Mar 30, 2023
894e265
Merge branch 'xfail_tests_mac_windows' of github.com:JulienT01/rlberr…
JulienT01 Mar 30, 2023
bd366f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2023
944efc3
Empty-Commit
JulienT01 Mar 31, 2023
f2eebe7
Merge branch 'xfail_tests_mac_windows' of github.com:JulienT01/rlberr…
JulienT01 Mar 31, 2023
e84b1c0
Merge remote-tracking branch 'origin/xfail_tests_mac_windows' into At…
JulienT01 Mar 31, 2023
afcc8c7
remove main from tests
JulienT01 Mar 31, 2023
9c025c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2023
7990228
Empty-Commit
JulienT01 Mar 31, 2023
8119065
empty commit trigger ci
JulienT01 Mar 31, 2023
812f3d4
xfail on windows
JulienT01 Mar 31, 2023
6c7d5d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2023
7e28010
allow to continue the buffer after a first 'fit()'
JulienT01 Mar 31, 2023
ae27dcc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2023
518abfa
Merge remote-tracking branch 'origin/xfail_tests_mac_windows' into PP…
JulienT01 Mar 31, 2023
587681e
Merge branch 'PPO_buffer' of github.com:JulienT01/rlberry into PPO_bu…
JulienT01 Mar 31, 2023
8f71b96
xfail for windows...
JulienT01 Mar 31, 2023
cb391a3
Empty-Commit
JulienT01 Mar 31, 2023
a36fadb
Empty-Commit
JulienT01 Mar 31, 2023
d0b9120
remove 'sleep' (added for debug)
JulienT01 Apr 3, 2023
1d6fd61
remove sleep (old debug)
JulienT01 Apr 3, 2023
7af5d3a
remove sleep (old debug)
JulienT01 Apr 3, 2023
98357b7
remove 'sleep' (added for debug)
JulienT01 Apr 3, 2023
2917e69
Merge remote-tracking branch 'origin/xfail_tests_mac_windows' into PP…
JulienT01 Apr 3, 2023
35aa58b
use temporary folder instead
JulienT01 Apr 3, 2023
02e69b1
generalize PPO tests to 'check_agent'
JulienT01 Apr 4, 2023
9073353
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2023
799ceda
flake
JulienT01 Apr 4, 2023
9af6369
Merge branch 'PPO_buffer' of github.com:JulienT01/rlberry into PPO_bu…
JulienT01 Apr 4, 2023
a49cbab
patch : stableBaselines don't have get_params()
JulienT01 Apr 4, 2023
4377431
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2023
ef0d917
Empty-Commit
JulienT01 Apr 4, 2023
e8f0b29
update doc
JulienT01 Apr 5, 2023
3789caf
Empty-Commit
JulienT01 Apr 5, 2023
d6a63b7
Merge branch 'fix_ci' of https://github.com/TimotheeMathieu/rlberry i…
JulienT01 Apr 5, 2023
d310dd2
don't remove PyOpenGL_accelerate
JulienT01 Apr 5, 2023
90faf20
Merge remote-tracking branch 'origin/xfail_tests_mac_windows' into PP…
JulienT01 Apr 5, 2023
3993823
Merge remote-tracking branch 'origin/PPO_buffer' into Atari_part1
JulienT01 Apr 5, 2023
9f5f217
Merge branch 'main' into Atari_part1
JulienT01 Apr 5, 2023
03afc5c
add tests for atari empty input dim
JulienT01 Apr 5, 2023
5ae295b
Merge branch 'Atari_part1' of github.com:JulienT01/rlberry into Atari…
JulienT01 Apr 5, 2023
d321a99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2023
fad2801
remove test (already exist in "check_agent.py")
JulienT01 Apr 5, 2023
80f1220
Merge branch 'Atari_part1' of github.com:JulienT01/rlberry into Atari…
JulienT01 Apr 5, 2023
0c26794
updades following Matheus review
JulienT01 Apr 11, 2023
46ad222
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2023
0cbd974
add docstring for atari_make
JulienT01 Apr 12, 2023
6242b5a
Merge branch 'Atari_part1' of github.com:JulienT01/rlberry into Atari…
JulienT01 Apr 12, 2023
5439abf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
47e7c09
Merge branch 'rlberry-py:main' into Atari_part1
JulienT01 Apr 12, 2023
ca981ac
update changelog
JulienT01 Apr 12, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -162,3 +162,6 @@ dmypy.json

# PyCharm
.idea
.project
.pydevproject
profile.prof
Binary file added docs/_video/video_plot_atari_freeway.mp4
Binary file not shown.
4 changes: 4 additions & 0 deletions docs/changelog.rst
@@ -9,6 +9,10 @@ Dev version

* Move old scripts (jax agents, attention networks, old examples...) that we won't maintain from the main branch to an archive branch.

*PR #277*

* Add and update code to use "Atari games" environments


Version 0.4.0 (latest stable version)
--------------------------------------
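
For context on the changelog entry above, here is a minimal usage sketch of the new Atari support. `atari_make` and the "ALE/Freeway-v5" id are taken from this PR's example and tests; the expected observation shape is an assumption based on the CNN config used in the example further down.

from rlberry.envs.gym_make import atari_make

# Build a preprocessed Atari environment (preprocessing is handled inside atari_make)
env = atari_make("ALE/Freeway-v5")
state = env.reset()
print(env.observation_space.shape)  # assumed CHW after the Atari wrappers, e.g. (4, 84, 84)
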
Binary file added docs/thumbnails/video_plot_atari_freeway.jpg
120 changes: 120 additions & 0 deletions examples/demo_env/video_plot_atari_freeway.py
@@ -0,0 +1,120 @@
"""
===============================================
A demo of the ATARI Freeway environment with DQNAgent
===============================================
Illustration of training and video rendering of a DQN agent in the
ATARI Freeway environment.

The agent is only lightly tuned and not optimal; this is just for illustration purposes.

.. video:: ../../video_plot_atari_freeway.mp4
:width: 600

"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_atari_freeway.jpg'


from rlberry.manager.agent_manager import AgentManager
from datetime import datetime
from rlberry.agents.torch.dqn.dqn import DQNAgent
from gym.wrappers.record_video import RecordVideo
import shutil
import os
from rlberry.envs.gym_make import atari_make


initial_time = datetime.now()
print("-------- init agent --------")

mlp_configs = {
"type": "MultiLayerPerceptron", # A network architecture
"layer_sizes": [512], # Network dimensions
"reshape": False,
"is_policy": False, # The network should output a distribution
# over actions
}

cnn_configs = {
"type": "ConvolutionalNetwork", # A network architecture
"activation": "RELU",
"in_channels": 4,
"in_height": 84,
"in_width": 84,
"head_mlp_kwargs": mlp_configs,
"transpose_obs": False,
"is_policy": False, # The network should output a distribution
}

tuned_agent = AgentManager(
DQNAgent, # The Agent class.
(
atari_make,
dict(
id="ALE/Freeway-v5",
),
), # The Environment to solve.
init_kwargs=dict( # Where to put the agent's hyperparameters
q_net_constructor="rlberry.agents.torch.utils.training.model_factory_from_env",
q_net_kwargs=cnn_configs,
max_replay_size=50000,
batch_size=32,
learning_starts=25000,
gradient_steps=1,
epsilon_final=0.01,
learning_rate=1e-4, # Size of the gradient descent steps.
chunk_size=1,
),
fit_budget=90000, # The number of interactions between the agent and the environment during training.
eval_kwargs=dict(
eval_horizon=500
), # The number of interactions between the agent and the environment during evaluations.
n_fit=1, # The number of agents to train. Usually, it is good to do more than 1 because the training is stochastic.
agent_name="DQN_tuned", # The agent's name.
output_dir="DQN_for_freeway",
)

print("-------- init agent : done!--------")
print("-------- train agent --------")

tuned_agent.fit()

print("-------- train agent : done!--------")

final_train_time = datetime.now()

print("-------- test agent with video--------")

env = atari_make(
"ALE/Freeway-v5",
)
env = RecordVideo(env, "docs/_video/temp")

if "render_modes" in env.metadata:
env.metadata["render.modes"] = env.metadata[
"render_modes"
] # bug with some 'gym' version

state = env.reset()
for tt in range(30000):
action = tuned_agent.get_agent_instances()[0].policy(state)
next_s, _, done, test = env.step(action)
if done:
break
state = next_s

env.close()

print("-------- test agent with video : done!--------")
final_test_time = datetime.now()
tuned_agent.save()

os.rename("docs/_video/temp/rl-video-episode-0.mp4", "docs/_video/video_plot_atari_freeway.mp4")
shutil.rmtree("docs/_video/temp/")


print("Done!!!")
print("-------------")
print("begin run at :" + str(initial_time))
print("end training at :" + str(final_train_time))
print("end run at :" + str(final_test_time))
print("-------------")
6 changes: 5 additions & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ pygame
matplotlib
seaborn
pandas
gym==0.21
gym[accept-rom-license]==0.21.0
dill
docopt
pyyaml
@@ -18,3 +18,7 @@ torch>=1.6.0
stable-baselines3
protobuf==3.20.1
tensorboard
opencv-python
ale-py==0.7.4
pytest
pytest-xprocess
100 changes: 100 additions & 0 deletions rlberry/agents/torch/tests/test_torch_atari.py
@@ -0,0 +1,100 @@
from rlberry.manager.agent_manager import AgentManager
from rlberry.agents.torch.dqn.dqn import DQNAgent
from rlberry.envs.gym_make import atari_make


def test_forward_dqn():
mlp_configs = {
"type": "MultiLayerPerceptron", # A network architecture
"layer_sizes": [512], # Network dimensions
"reshape": False,
"is_policy": False, # The network should output a distribution
# over actions
}

cnn_configs = {
"type": "ConvolutionalNetwork", # A network architecture
"activation": "RELU",
"in_channels": 4,
"in_height": 84,
"in_width": 84,
"head_mlp_kwargs": mlp_configs,
"transpose_obs": False,
"is_policy": False, # The network should output a distribution
}

tuned_agent = AgentManager(
DQNAgent, # The Agent class.
(
atari_make,
# uncomment when rlberry manages vectorized envs
# dict(id="ALE/Breakout-v5", n_envs=3),
dict(id="ALE/Breakout-v5", n_envs=1),
), # The Environment to solve.
init_kwargs=dict( # Where to put the agent's hyperparameters
q_net_constructor="rlberry.agents.torch.utils.training.model_factory_from_env",
q_net_kwargs=cnn_configs,
max_replay_size=100,
batch_size=32,
learning_starts=100,
gradient_steps=1,
epsilon_final=0.01,
learning_rate=1e-4, # Size of the gradient descent steps.
chunk_size=5,
),
fit_budget=200, # The number of interactions between the agent and the environment during training.
eval_kwargs=dict(
eval_horizon=10
), # The number of interactions between the agent and the environment during evaluations.
n_fit=1, # The number of agents to train. Usually, it is good to do more than 1 because the training is stochastic.
agent_name="DQN_test", # The agent's name.
)

tuned_agent.fit()


def test_forward_empty_input_dim():
mlp_configs = {
"type": "MultiLayerPerceptron", # A network architecture
"layer_sizes": [512], # Network dimensions
"reshape": False,
"is_policy": False, # The network should output a distribution
# over actions
}

cnn_configs = {
"type": "ConvolutionalNetwork", # A network architecture
"activation": "RELU",
"head_mlp_kwargs": mlp_configs,
"transpose_obs": False,
"is_policy": False, # The network should output a distribution
}

tuned_agent = AgentManager(
DQNAgent, # The Agent class.
(
atari_make,
# uncomment when rlberry manages vectorized envs
# dict(id="ALE/Breakout-v5", n_envs=3),
dict(id="ALE/Breakout-v5", n_envs=1),
), # The Environment to solve.
init_kwargs=dict( # Where to put the agent's hyperparameters
q_net_constructor="rlberry.agents.torch.utils.training.model_factory_from_env",
q_net_kwargs=cnn_configs,
max_replay_size=100,
batch_size=32,
learning_starts=100,
gradient_steps=1,
epsilon_final=0.01,
learning_rate=1e-4, # Size of the gradient descent steps.
chunk_size=5,
),
fit_budget=10, # The number of interactions between the agent and the environment during training.
eval_kwargs=dict(
eval_horizon=10
), # The number of interactions between the agent and the environment during evaluations.
n_fit=1, # The number of agents to train. Usually, it is good to do more than 1 because the training is stochastic.
agent_name="DQN_test", # The agent's name.
)

tuned_agent.fit()
50 changes: 38 additions & 12 deletions rlberry/agents/torch/utils/models.py
@@ -61,7 +61,7 @@ def default_policy_net_fn(env):
)

if len(obs_shape) == 3:
if obs_shape[0] < obs_shape[1] and obs_shape[0] < obs_shape[1]:
if obs_shape[0] < obs_shape[1] and obs_shape[0] < obs_shape[2]:
# Assume CHW observation space
model_config = {
"type": "ConvolutionalNetwork",
@@ -397,6 +397,8 @@ class ConvolutionalNetwork(nn.Module):
H = height;
W = width.

For the forward pass, if the input tensor has more than 4 dimensions (i.e. not BCHW), the last 3 dimensions are kept as CHW and all leading dimensions are merged into a single batch dimension. The tensor then goes through the CNN + MLP head, and the leading dimensions are restored on the output.

Parameters
----------
activation: {"RELU", "TANH", "ELU"}
@@ -434,25 +436,30 @@ def __init__(
self.conv3 = nn.Conv2d(32, 64, kernel_size=2, stride=2)

# MLP Head
# Number of Linear input connections depends on output of conv2d layers
# and therefore the input image size, so compute it.
def conv2d_size_out(size, kernel_size=2, stride=2):
return (size - (kernel_size - 1) - 1) // stride + 1

convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(in_width)))
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(in_height)))
assert convh > 0 and convw > 0
self.head_mlp_kwargs = head_mlp_kwargs or {}
self.head_mlp_kwargs["in_size"] = convw * convh * 64
self.head_mlp_kwargs["in_size"] = self._get_conv_out_size(
[in_channels, in_height, in_width]
) # Number of Linear input connections depends on output of conv layers
self.head_mlp_kwargs["out_size"] = out_size
self.head_mlp_kwargs["is_policy"] = is_policy
self.head = model_factory(**self.head_mlp_kwargs)

self.is_policy = is_policy
self.transpose_obs = transpose_obs

def _get_conv_out_size(self, shape):
"""
Compute the flattened output size of the convolutional layers.
shape: input dimensions (CHW) of the CNN.
"""
conv_result = self.activation((self.conv1(torch.zeros(1, *shape))))
conv_result = self.activation((self.conv2(conv_result)))
conv_result = self.activation((self.conv3(conv_result)))
return int(np.prod(conv_result.size()))

def convolutions(self, x):
x = x.float()
# if there is no batch (CHW), add one dimension to specify batch of 1 (and get format BCHW)
if len(x.shape) == 3:
x = x.unsqueeze(0)
if self.transpose_obs:
@@ -470,9 +477,28 @@ def forward(self, x):
Parameters
----------
x: torch.tensor
Tensor of shape BCHW
Tensor of shape BCHW (Batch, Channel, Height, Width); if it has more than 4 dimensions, all leading dimensions are merged into the batch dimension.
"""
return self.head(self.convolutions(x))
flag_view_to_change = False

if len(x.shape) > 4:
flag_view_to_change = True
dim_to_restore = x.shape[:-3]
inputview_size = tuple((-1,)) + tuple(x.shape[-3:])
outputview_size = tuple(dim_to_restore) + tuple(
(self.head_mlp_kwargs["out_size"],)
)
x = x.view(inputview_size)

conv_result = self.convolutions(x)
output_result = self.head(
conv_result.view(conv_result.size()[0], -1)
) # pass 'conv_result', flattened to 2 dimensions (batch, features), to the MLP head

if flag_view_to_change:
output_result = output_result.view(outputview_size)

return output_result

def action_scores(self, x):
return self.head.action_scores(self.convolutions(x))
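
To make the new multi-dimensional forward logic above concrete, here is a standalone sketch (shapes and out_size are assumed, and a zero tensor stands in for the convolutions + MLP head): leading dimensions are merged into one batch dimension before the CNN and restored on the output.

import torch

x = torch.zeros(3, 7, 4, 84, 84)  # assumed input with extra leading dims: (n_envs, time, C, H, W)
lead_dims = x.shape[:-3]  # (3, 7), kept to restore the output shape
x_flat = x.view((-1,) + tuple(x.shape[-3:]))  # (21, 4, 84, 84): plain BCHW batch
out_size = 18  # assumed action-space size (the head's "out_size")
y_flat = torch.zeros(x_flat.shape[0], out_size)  # stand-in for convolutions(x_flat) + MLP head
y = y_flat.view(tuple(lead_dims) + (out_size,))  # restored to (3, 7, 18)
print(y.shape)
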
21 changes: 17 additions & 4 deletions rlberry/agents/torch/utils/training.py
Original file line number Diff line number Diff line change
@@ -105,10 +105,23 @@ def size_model_config(env, **model_config):
return model_config

# Assume CHW observation space
if model_config["type"] == "ConvolutionalNetwork":
model_config["in_channels"] = int(obs_shape[0])
model_config["in_height"] = int(obs_shape[1])
model_config["in_width"] = int(obs_shape[2])
if "type" in model_config and model_config["type"] == "ConvolutionalNetwork":
if "transpose_obs" in model_config and not model_config["transpose_obs"]:
# Assume CHW observation space
if "in_channels" not in model_config:
model_config["in_channels"] = int(obs_shape[0])
if "in_height" not in model_config:
model_config["in_height"] = int(obs_shape[1])
if "in_width" not in model_config:
model_config["in_width"] = int(obs_shape[2])
else:
# Assume WHC observation space to transpose
if "in_channels" not in model_config:
model_config["in_channels"] = int(obs_shape[2])
if "in_height" not in model_config:
model_config["in_height"] = int(obs_shape[1])
if "in_width" not in model_config:
model_config["in_width"] = int(obs_shape[0])
else:
model_config["in_size"] = int(np.prod(obs_shape))
