
Commit

Merge branch 'OpenRL-Lab:main' into main
YiwenAI authored Jun 20, 2023
2 parents 9c3a2cb + cbf110e commit 94a8122
Showing 26 changed files with 1,229 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -9,4 +9,4 @@
## Checklist
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
- [ ] I have ensured `make test` pass (**required**).
- [ ] I have checked the code using `make commit-checks` (**required**).
- [ ] I have checked the code using `make format` (**required**).
4 changes: 1 addition & 3 deletions CONTRIBUTING.md
@@ -38,9 +38,7 @@ Firstly, you should install the test-related packages: `pip install -e ".[test]"`

Then, ensure that unit tests pass by executing `make test`.

Next, format your code by running `make format`.

Lastly, run `make commit-checks` to check if your code complies with OpenRL's coding style.
Lastly, format your code by running `make format`.

> Tip: OpenRL uses [black](https://github.com/psf/black) coding style.
You can install black plugins in your editor as shown in the [official website](https://black.readthedocs.io/en/stable/integrations/editors.html)
3 changes: 3 additions & 0 deletions Makefile
@@ -21,6 +21,9 @@ format:
isort ${PYTHON_FILES}
# Reformat using black
black ${PYTHON_FILES} --preview
# do format agent
isort ${PYTHON_FILES}
black ${PYTHON_FILES} --preview

commit-checks: format lint

4 changes: 1 addition & 3 deletions docs/CONTRIBUTING_zh.md
@@ -41,9 +41,7 @@ The OpenRL community welcomes anyone to take part in building OpenRL, whether you are a devel…

Then, make sure the unit tests pass by running `make test`.

Then, run `make format` to format your code.

Lastly, run `make commit-checks` to check whether your code complies with OpenRL's coding style.
Lastly, run `make format` to format your code.

> Tip: OpenRL uses the [black](https://github.com/psf/black) coding style.
You can install the black [plugin](https://black.readthedocs.io/en/stable/integrations/editors.html) in your editor
6 changes: 6 additions & 0 deletions examples/cartpole/README.md
@@ -11,4 +11,10 @@ To train with [Dual-clip PPO](https://arxiv.org/abs/1912.09729):

```shell
python train_ppo.py --config dual_clip_ppo.yaml
```

If you want to save checkpoints during training, use callbacks:

```shell
python train_ppo.py --config callbacks.yaml
```
8 changes: 8 additions & 0 deletions examples/cartpole/callbacks.yaml
@@ -0,0 +1,8 @@
callbacks:
  - id: "CheckpointCallback"
    args: {
      "save_freq": 500,
      "save_path": "./checkpoints/",
      "name_prefix": "ppo",
      "save_replay_buffer": True
    }
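For orientation, here is a minimal sketch of what a checkpoint callback driven by these arguments typically does. It is an illustration only, not OpenRL's actual `CheckpointCallback`; the `agent.save(...)` call and the way the agent gets attached are assumptions.

```python
import os


class SimpleCheckpointCallback:
    """Illustrative callback: save the agent every `save_freq` calls to on_step()."""

    def __init__(self, save_freq, save_path, name_prefix="ppo", save_replay_buffer=False):
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix
        self.save_replay_buffer = save_replay_buffer  # only meaningful for off-policy algorithms
        self.num_calls = 0
        self.agent = None  # assumed to be attached by the training loop

    def on_step(self) -> bool:
        self.num_calls += 1
        if self.num_calls % self.save_freq == 0:
            os.makedirs(self.save_path, exist_ok=True)
            path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_calls}_steps")
            self.agent.save(path)  # assumed agent API
        return True  # returning False would stop training (see the driver changes below)
```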
8 changes: 5 additions & 3 deletions examples/cartpole/train_ppo.py
@@ -8,11 +8,13 @@


def train():
    # create environment, set environment parallelism to 9
    env = make("CartPole-v1", env_num=9)
    # create the neural network
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args()

    # create environment, set environment parallelism to 9
    env = make("CartPole-v1", env_num=9)

    net = Net(
        env,
        cfg=cfg,
@@ -47,4 +49,4 @@ def evaluation(agent):

if __name__ == "__main__":
    agent = train()
    evaluation(agent)
    # evaluation(agent)
2 changes: 1 addition & 1 deletion examples/mpe/train_ppo.py
@@ -53,4 +53,4 @@ def evaluation(agent):

if __name__ == "__main__":
    agent = train()
    evaluation(agent)
    evaluation(agent)
5 changes: 5 additions & 0 deletions openrl/configs/config.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
from typing import List

from jsonargparse import ActionConfigFile, ArgumentParser


@@ -32,6 +34,9 @@ def create_config_parser():
    parser.add_argument("--n_head", type=int, default=1)
    parser.add_argument("--dec_actor", action="store_true", default=False)
    parser.add_argument("--share_actor", action="store_true", default=False)

    parser.add_argument("--callbacks", type=List[dict])

    # For Hierarchical RL
    parser.add_argument(
        "--step_difference",
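With the new `--callbacks` option, `cfg.callbacks` arrives as a list of dicts shaped like the cartpole `callbacks.yaml` above (an `id` plus an `args` mapping). Below is a hedged sketch of how such entries could be turned into callback objects; the registry here is hypothetical and not part of OpenRL's public API.

```python
from typing import Any, Dict, List, Type

# Hypothetical registry; in practice this would map ids such as
# "CheckpointCallback" to concrete callback classes.
CALLBACK_REGISTRY: Dict[str, Type] = {}


def build_callbacks(callback_cfgs: List[Dict[str, Any]]) -> list:
    """Instantiate one callback per {"id": ..., "args": {...}} entry."""
    callbacks = []
    for entry in callback_cfgs or []:
        cls = CALLBACK_REGISTRY[entry["id"]]
        callbacks.append(cls(**entry.get("args", {})))
    return callbacks
```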
13 changes: 9 additions & 4 deletions openrl/drivers/offpolicy_driver.py
@@ -33,13 +33,14 @@ def __init__(
        config: Dict[str, Any],
        trainer,
        buffer,
        agent,
        rank: int = 0,
        world_size: int = 1,
        client=None,
        logger: Optional[Logger] = None,
    ) -> None:
        super(OffPolicyDriver, self).__init__(
            config, trainer, buffer, rank, world_size, client, logger
            config, trainer, buffer, agent, rank, world_size, client, logger
        )

        self.buffer_minimal_size = int(config["cfg"].buffer_size * 0.2)
@@ -55,7 +56,10 @@ def __init__(

    def _inner_loop(
        self,
    ) -> None:
    ) -> bool:
        """
        :return: True if training should continue, False if training should stop
        """
        rollout_infos = self.actor_rollout()

        if self.buffer.get_buffer_size() >= 0:
@@ -73,6 +77,8 @@ def _inner_loop(
            self.logger.log_info(rollout_infos, step=self.total_num_steps)
            self.logger.log_info(train_infos, step=self.total_num_steps)

        return True

    def add2buffer(self, data):
        (
            obs,
@@ -135,7 +141,7 @@ def actor_rollout(self):
            }

            next_obs, rewards, dones, infos = self.envs.step(actions, extra_data)
            if type(self.episode_steps)==int:
            if type(self.episode_steps) == int:
                if not dones:
                    self.episode_steps += 1
                else:
@@ -157,7 +163,6 @@ def actor_rollout(self):
# "actions: ", actions)
# print("rewards: ", rewards)


data = (
obs,
next_obs,
42 changes: 35 additions & 7 deletions openrl/drivers/onpolicy_driver.py
@@ -15,15 +15,17 @@
# limitations under the License.

""""""
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel

from openrl.drivers.rl_driver import RLDriver
from openrl.envs.vec_env.utils.util import prepare_available_actions
from openrl.runners.common.base_agent import BaseAgent
from openrl.utils.logger import Logger
from openrl.utils.type_aliases import MaybeCallback
from openrl.utils.util import _t2n


@@ -33,19 +35,35 @@ def __init__(
        config: Dict[str, Any],
        trainer,
        buffer,
        agent,
        rank: int = 0,
        world_size: int = 1,
        client=None,
        logger: Optional[Logger] = None,
        callback: MaybeCallback = None,
    ) -> None:
        super(OnPolicyDriver, self).__init__(
            config, trainer, buffer, rank, world_size, client, logger
            config,
            trainer,
            buffer,
            agent,
            rank,
            world_size,
            client,
            logger,
            callback=callback,
        )

    def _inner_loop(
        self,
    ) -> None:
        rollout_infos = self.actor_rollout()
    ) -> bool:
        """
        :return: True if training should continue, False if training should stop
        """
        rollout_infos, continue_training = self.actor_rollout()
        if not continue_training:
            return False

        train_infos = self.learner_update()
        self.buffer.after_update()

@@ -57,6 +75,7 @@ def _inner_loop(
            # rollout_infos can only be used when env is wrapped with VevMonitor
            self.logger.log_info(rollout_infos, step=self.total_num_steps)
            self.logger.log_info(train_infos, step=self.total_num_steps)
        return True

    def add2buffer(self, data):
        (
@@ -134,7 +153,9 @@ def add2buffer(self, data):
            available_actions=available_actions,
        )

    def actor_rollout(self):
    def actor_rollout(self) -> Tuple[Dict[str, Any], bool]:
        self.callback.on_rollout_start()

        self.trainer.prep_rollout()
        import time

@@ -151,6 +172,11 @@
            }

            obs, rewards, dones, infos = self.envs.step(actions, extra_data)
            self.agent.num_time_steps += self.envs.parallel_env_num
            # Give access to local variables
            self.callback.update_locals(locals())
            if self.callback.on_step() is False:
                return {}, False

            data = (
                obs,
@@ -168,12 +194,14 @@

        batch_rew_infos = self.envs.batch_rewards(self.buffer)

        self.callback.on_rollout_end()

        if self.envs.use_monitor:
            statistics_info = self.envs.statistics(self.buffer)
            statistics_info.update(batch_rew_infos)
            return statistics_info
            return statistics_info, True
        else:
            return batch_rew_infos
            return batch_rew_infos, True

    @torch.no_grad()
    def compute_returns(self):
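As the diff shows, the rollout is now bracketed by callback hooks: `on_rollout_start()` before collection, `update_locals()` and `on_step()` after every vectorized step (an `on_step()` returning `False` aborts training), and `on_rollout_end()` afterwards. Here is a minimal callback honoring that protocol, as an illustration rather than OpenRL's actual base class:

```python
class StopAfterNStepsCallback:
    """Illustrative callback: ask the driver to stop once a step budget is spent."""

    def __init__(self, max_steps: int):
        self.max_steps = max_steps
        self.num_calls = 0
        self.locals = {}

    def on_rollout_start(self) -> None:
        pass  # e.g. reset per-rollout statistics

    def update_locals(self, locals_: dict) -> None:
        # The driver hands over its local variables (obs, rewards, dones, infos, ...).
        self.locals = locals_

    def on_step(self) -> bool:
        self.num_calls += 1
        # Returning False makes actor_rollout() return (..., False),
        # which propagates through _inner_loop() and stops run().
        return self.num_calls < self.max_steps

    def on_rollout_end(self) -> None:
        pass  # e.g. flush logs
```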
14 changes: 12 additions & 2 deletions openrl/drivers/rl_driver.py
@@ -21,6 +21,7 @@
from openrl.drivers.base_driver import BaseDriver
from openrl.envs.vec_env.utils.util import prepare_available_actions
from openrl.utils.logger import Logger
from openrl.utils.type_aliases import MaybeCallback


class RLDriver(BaseDriver, ABC):
@@ -29,10 +30,12 @@ def __init__(
        config: Dict[str, Any],
        trainer,
        buffer,
        agent,
        rank: int = 0,
        world_size: int = 1,
        client=None,
        logger: Optional[Logger] = None,
        callback: MaybeCallback = None,
    ) -> None:
        self.trainer = trainer
        self.buffer = buffer
@@ -45,6 +48,8 @@ def __init__(
        self.program_type = cfg.program_type
        self.envs = config["envs"]
        self.device = config["device"]
        self.callback = callback
        self.agent = agent

        assert not (
            self.program_type != "actor" and self.world_size is None
@@ -104,7 +109,10 @@ def __init__(
        self.cfg = cfg

    @abstractmethod
    def _inner_loop(self):
    def _inner_loop(self) -> bool:
        """
        :return: True if training should continue, False if training should stop
        """
        raise NotImplementedError

    def reset_and_buffer_init(self):
@@ -143,7 +151,9 @@ def run(self) -> None:
        for episode in range(episodes):
            self.logger.info("Episode: {}/{}".format(episode, episodes))
            self.episode = episode
            self._inner_loop()
            continue_training = self._inner_loop()
            if not continue_training:
                break

    def learner_update(self):
        if self.use_linear_lr_decay:
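The contract is now: `_inner_loop()` returns `True` to keep training and `False` to stop, and `run()` breaks out of its episode loop on `False`. A standalone toy sketch of that contract (not the actual `RLDriver`):

```python
class ToyDriver:
    """Standalone illustration of the run()/_inner_loop() early-stop contract."""

    def __init__(self, episodes: int):
        self.episodes = episodes
        self.episode = 0

    def _inner_loop(self) -> bool:
        # A real driver would return False when a callback's on_step() requests a stop.
        return self.episode < 3

    def run(self) -> None:
        for episode in range(self.episodes):
            self.episode = episode
            if not self._inner_loop():
                break  # stop training early


ToyDriver(episodes=10).run()  # runs episodes 0-2, then stops at episode 3
```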
45 changes: 44 additions & 1 deletion openrl/envs/vec_env/__init__.py
@@ -1,6 +1,49 @@
from typing import Optional, Type, Union

from gymnasium import Env as GymEnv

from openrl.envs.vec_env.async_venv import AsyncVectorEnv
from openrl.envs.vec_env.base_venv import BaseVecEnv
from openrl.envs.vec_env.sync_venv import SyncVectorEnv
from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper
from openrl.envs.vec_env.wrappers.reward_wrapper import RewardWrapper
from openrl.envs.vec_env.wrappers.vec_monitor_wrapper import VecMonitorWrapper

__all__ = ["SyncVectorEnv", "AsyncVectorEnv", "VecMonitorWrapper", "RewardWrapper"]
__all__ = [
    "BaseVecEnv",
    "SyncVectorEnv",
    "AsyncVectorEnv",
    "VecMonitorWrapper",
    "RewardWrapper",
]


def unwrap_vec_wrapper(
    env: Union[GymEnv, BaseVecEnv], vec_wrapper_class: Type[VecEnvWrapper]
) -> Optional[VecEnvWrapper]:
    """
    Retrieve a ``VecEnvWrapper`` object by recursively searching.
    :param env:
    :param vec_wrapper_class:
    :return:
    """
    env_tmp = env
    while isinstance(env_tmp, VecEnvWrapper):
        if isinstance(env_tmp, vec_wrapper_class):
            return env_tmp
        env_tmp = env_tmp.venv
    return None


def is_vecenv_wrapped(
    env: Union[GymEnv, BaseVecEnv], vec_wrapper_class: Type[VecEnvWrapper]
) -> bool:
    """
    Check if an environment is already wrapped by a given ``VecEnvWrapper``.
    :param env:
    :param vec_wrapper_class:
    :return:
    """
    return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
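A short usage sketch for the new helpers. The `make(...)` import path and the presence of a monitor wrapper are assumptions; the helpers themselves are exactly those defined above.

```python
from openrl.envs.common import make  # assumed import path for OpenRL's env factory
from openrl.envs.vec_env import (
    VecMonitorWrapper,
    is_vecenv_wrapped,
    unwrap_vec_wrapper,
)

env = make("CartPole-v1", env_num=2)

if is_vecenv_wrapped(env, VecMonitorWrapper):
    # Walk the wrapper chain and grab the monitor itself.
    monitor = unwrap_vec_wrapper(env, VecMonitorWrapper)
else:
    monitor = None  # the env was created without a monitor wrapper
```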
2 changes: 1 addition & 1 deletion openrl/modules/networks/utils/nlp/base_policy.py
@@ -145,7 +145,7 @@ def __init__(
            model_name (str): name of the causal or seq2seq model from transformers library
            optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}.
            weight_decay (float, optional): weight decay. Defaults to 1e-6.
            use_sde (bool, optional): Use state-dependent exploration. Defaults to None. (Unused parameter from stable-baselines3)
            use_sde (bool, optional): Use state-dependent exploration. Defaults to None.
            apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True.
            optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW.
            generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}.