Population-Based Training (PBT) is an approach to hyperparameter optimisation that jointly optimises a population of models and their hyperparameters to maximise performance. PBT takes its inspiration from genetic algorithms: each member of the population can exploit information from the rest of the population.
Illustration of PBT training process (Liebig, Jan Frederik, Evaluating Population based Reinforcement Learning for Transfer Learning, 2021)
To scale the population of agents to extreme sizes on high-performance computing (HPC) systems, this repo, EVO, provides a PBT implementation for RL using the Message Passing Interface (MPI).
MPI provides a powerful, efficient, and portable way to express parallel programs, and it is the dominant model used in high-performance computing. MPI is a communication protocol for parallel computing nodes that supports both point-to-point and collective communication. Sockets and TCP can be used at the transport layer. It is the main communication method on distributed-memory supercomputers, and it can also run on shared-memory machines.
mpi4py provides a Python interface that resembles the message passing interface (MPI), and hence allows Python programs to exploit multiple processors on multiple compute nodes.
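The two communication styles mentioned above can be sketched with the pickle-based `send`, `recv`, and `gather` methods that mpi4py communicators expose. The helper names and the `LocalComm` stand-in are hypothetical and not part of EVO or mpi4py; `LocalComm` only exists so that the sketch runs in a single process without an MPI launcher.

```python
class LocalComm:
    """Hypothetical single-process stand-in for MPI.COMM_WORLD (not part of mpi4py)."""

    def Get_rank(self):
        return 0

    def Get_size(self):
        return 1

    def gather(self, sendobj, root=0):
        # With one process, gathering yields a one-element list on the root.
        return [sendobj]


def gather_rewards_p2p(comm, reward):
    """Collect one reward per rank on rank 0 using point-to-point send/recv."""
    rank, size = comm.Get_rank(), comm.Get_size()
    if rank != 0:
        comm.send(reward, dest=0, tag=0)
        return None
    rewards = [reward]
    for source in range(1, size):
        rewards.append(comm.recv(source=source, tag=0))
    return rewards


def gather_rewards_collective(comm, reward):
    """Collect one reward per rank on rank 0 with a single collective call."""
    return comm.gather(reward, root=0)


# With mpi4py installed: from mpi4py import MPI; comm = MPI.COMM_WORLD
comm = LocalComm()
p2p_result = gather_rewards_p2p(comm, 1.5)
collective_result = gather_rewards_collective(comm, 1.5)
print(p2p_result, collective_result)
```

When `LocalComm()` is replaced by `MPI.COMM_WORLD` and the script is launched with mpiexec, rank 0 should end up with one reward per process in both variants, while the other ranks receive `None`.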
Prerequisites:
- Python 3.8
- Conda
- (Poetry)
- (PyTorch), see footnote 1
Clone this repo:
git clone https://github.com/yyzpiero/evo.git
Create a conda environment:
conda create -p ./venv python==X.X
and use poetry to install all Python packages:
poetry install
Please use pip or poetry to install mpi4py:
pip install mpi4py
or
poetry add mpi4py
Installing mpi4py with conda may lead to unexpected issues.
Activate the conda environment:
conda activate ./venv
Please use mpiexec or mpirun to run experiments:
mpiexec -n 4 python pbt_rl_wta.py --num-agents 4 --env-id CartPole-v1
EVO also supports experiment monitoring with TensorBoard. Example command line to run an experiment with TensorBoard monitoring:
mpiexec -n 4 python pbt_rl_truct_collective.py --num-agents 4 --env-id CartPole-v1 --tb-writer True
The toy example is reproduced from Fig. 2 of the PBT paper.
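A minimal sketch of such a toy problem follows. It assumes the quadratic objective Q(θ) = 1.2 − (θ₀² + θ₁²) and the surrogate Q̂(θ | h) = 1.2 − (h₀θ₀² + h₁θ₁²) with two workers. The learning rate, the readiness interval, the noise scale, and all function names are illustrative assumptions, not values taken from this repo.

```python
import random


def true_objective(theta):
    """Assumed toy objective: Q(theta) = 1.2 - (theta0^2 + theta1^2)."""
    return 1.2 - (theta[0] ** 2 + theta[1] ** 2)


def surrogate_step(theta, h, lr=0.01):
    """One gradient-ascent step on the surrogate 1.2 - (h0*theta0^2 + h1*theta1^2)."""
    return [t - lr * 2.0 * hi * t for t, hi in zip(theta, h)]


def run_toy_pbt(steps=200, ready_every=5, noise=0.1, seed=0):
    """Train two workers and return the best true objective value reached."""
    rng = random.Random(seed)
    thetas = [[0.9, 0.9], [0.9, 0.9]]   # both workers start from the same weights
    hypers = [[1.0, 0.0], [0.0, 1.0]]   # each worker initially sees one dimension
    for step in range(1, steps + 1):
        thetas = [surrogate_step(t, h) for t, h in zip(thetas, hypers)]
        if step % ready_every != 0:
            continue
        scores = [true_objective(t) for t in thetas]
        best = scores.index(max(scores))
        for i in range(len(thetas)):
            if i == best:
                continue
            # Exploit: copy the better worker's weights.
            thetas[i] = list(thetas[best])
            # Explore: perturb this worker's hyperparameters, clipped to [0, 1].
            hypers[i] = [min(1.0, max(0.0, h + rng.gauss(0.0, noise)))
                         for h in hypers[i]]
    return max(true_objective(t) for t in thetas)


initial_q = true_objective([0.9, 0.9])
final_q = run_toy_pbt()
print(initial_q, final_q)
```

Neither worker can optimise both dimensions alone with its initial hyperparameters, so progress on the true objective depends on the exploit and explore steps.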
A PPO agent from Stable-Baselines3 with default settings is used as the reinforcement learning agent.
self.model = PPO("MlpPolicy", env=self.env, verbose=0, create_eval_env=True)
However, it can be replaced by any other reinforcement learning algorithm.
A simple selection mechanism: in each generation, only the best-performing agent is kept, and its NN parameters are copied to all other agents. .py provides an implementation of such a mechanism using collective communications.
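The selection rule described above can be sketched in plain Python over an in-memory population. The function name and the dictionary layout of the parameters are hypothetical; in an MPI setting the rewards could, for example, be shared with `allgather` and the winner's parameters distributed with `bcast`, but that mapping is an assumption rather than a description of the repo's code.

```python
import copy


def winner_takes_all(rewards, params):
    """Return the winner's index and new per-agent parameters copied from it."""
    best = max(range(len(rewards)), key=lambda i: rewards[i])
    # Deep copies keep agents from sharing one mutable parameter object.
    return best, [copy.deepcopy(params[best]) for _ in params]


# Four agents with made-up rewards and tiny made-up parameter dictionaries.
rewards = [12.0, 45.5, 30.1, 8.7]
params = [{"w": [0.1 * i, 0.2 * i]} for i in range(4)]

best, new_params = winner_takes_all(rewards, params)
print(best, new_params[0])
```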
Truncation selection is the default selection strategy in the PBT paper for RL training, and it is widely used in other PBT-based methods.
All agents in the entire population are ranked by their episodic rewards. If the agent is in the bottom
Variants | Description |
---|---|
pbt_rl_truct.py | Implementation using point-2-point communications via send and recv. |
pbt_rl_truct_collective.py | Implementation using collective communications. |
For small clusters with a limited number of nodes, we suggest the point-2-point method, which is faster than the collective method. However, for large HPC clusters, the collective method is much faster and more robust.
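Truncation selection can be sketched in plain Python as follows. The cut-off `fraction` is a free parameter here, and its default of 0.2 is an assumption, since the text above does not state the value. The function name and data layout are hypothetical, and only parameters are copied in this sketch.

```python
import copy
import random


def truncation_selection(rewards, params, fraction=0.2, rng=random):
    """Replace each bottom-ranked agent with a copy of a random top-ranked agent."""
    n = len(rewards)
    cutoff = max(1, int(n * fraction))
    ranked = sorted(range(n), key=lambda i: rewards[i], reverse=True)
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    new_params = [copy.deepcopy(p) for p in params]
    donors = {}
    for loser in bottom:
        if loser in top:
            continue  # tiny populations: never overwrite a top agent
        donor = rng.choice(top)
        new_params[loser] = copy.deepcopy(params[donor])
        donors[loser] = donor
    return new_params, donors


# Ten agents with made-up rewards; each agent's parameters just record its id.
rewards = [10, 50, 30, 5, 40, 20, 60, 15, 35, 25]
params = [{"id": i} for i in range(10)]

new_params, donors = truncation_selection(rewards, params, rng=random.Random(0))
print(donors)
```

Agents outside the bottom group keep their own parameters, which is the main difference from the winner-takes-all rule.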
We used the continuous-control AntBulletEnv-v0 scenario from the PyBullet environments to test our implementations.
The results of the experiments are presented in the figure below:
Left Figure: Reward per generation using PBT | Right Figure: Reward per step using single SB3 agent
Some key observations:
- Using PBT to train PPO agents can achieve better results than a single SAC agent.
  - Note: SAC should outperform PPO (see OpenRL) in most PyBullet environments.
- "Winner-takes-all" outperforms the Truncation Selection mechanism in this scenario.
This repo is inspired by graf and angusfung's population based training repo.
Footnotes
1. Please use the CPU-only version if possible, as most HPC clusters don't have GPUs.
2. This article briefly introduces the difference between point-2-point communications and collective communications in MPI.