Offline meta-reinforcement learning (OMRL) utilizes pre-collected offline datasets to enhance the agent's generalization ability on unseen tasks. However, the context shift problem arises due to the distribution discrepancy between the contexts used for training (from the behavior policy) and testing (from the exploration policy). The context shift problem leads to incorrect task inference and further deteriorates the generalization ability of the meta-policy. Existing OMRL methods either overlook this problem or attempt to mitigate it with additional information. In this paper, we propose a novel approach called Context Shift Reduction for OMRL (CSRO) to address the context shift problem with only offline datasets. The key insight of CSRO is to minimize the influence of the policy on the context during both the meta-training and meta-test phases. During meta-training, we design a max-min mutual information representation learning mechanism to diminish the impact of the behavior policy on task representation. In the meta-test phase, we introduce a non-prior context collection strategy to reduce the effect of the exploration policy. Experimental results demonstrate that CSRO significantly reduces the context shift and improves the generalization ability, surpassing previous methods across various challenging domains.
To install locally, first install MuJoCo. For task distributions in which the reward function varies (Cheetah, Ant, Humanoid), install MuJoCo200. Set LD_LIBRARY_PATH to point to the MuJoCo binaries ($HOME/.mujoco/mujoco200/bin).
For the remaining dependencies, create the conda environment with:
conda env create -f environment.yaml
For the Walker and Hopper environments, MuJoCo131 is required; install it the same way as MuJoCo200. To switch between MuJoCo versions, export the variables for the version you need.
MuJoCo131:
export MUJOCO_PY_MJPRO_PATH=~/.mujoco/mjpro131
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mjpro131/bin
MuJoCo200:
export MUJOCO_PY_MJPRO_PATH=~/.mujoco/mujoco200
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
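Because these exports append to LD_LIBRARY_PATH, switching back and forth in one shell accumulates both bin directories. The following is a minimal Python sketch, not part of this repository, that instead builds a clean environment for a child process. It assumes the default install locations used above (~/.mujoco/mjpro131 and ~/.mujoco/mujoco200); the helper name `mujoco_env` and the `MUJOCO_DIRS` table are hypothetical.

```python
import os

# Install locations assumed from the export commands above (hypothetical table).
MUJOCO_DIRS = {
    "131": "~/.mujoco/mjpro131",   # Walker, Hopper
    "200": "~/.mujoco/mujoco200",  # Cheetah, Ant, Humanoid
}


def mujoco_env(version, base_env=None):
    """Return a copy of base_env (default: os.environ) set up for one MuJoCo version."""
    env = dict(os.environ if base_env is None else base_env)
    root = os.path.expanduser(MUJOCO_DIRS[version])
    bin_dir = os.path.join(root, "bin")
    # bin directories of every other MuJoCo version; these are removed so they cannot shadow this one
    other_bins = {
        os.path.join(os.path.expanduser(path), "bin")
        for v, path in MUJOCO_DIRS.items()
        if v != version
    }
    entries = [
        p for p in env.get("LD_LIBRARY_PATH", "").split(":")
        if p and p not in other_bins
    ]
    if bin_dir not in entries:
        entries.append(bin_dir)
    env["MUJOCO_PY_MJPRO_PATH"] = root
    env["LD_LIBRARY_PATH"] = ":".join(entries)
    return env
```

For example, `subprocess.run(["python", "launch_experiment.py", config_path], env=mujoco_env("131"))` would run a Walker or Hopper config with MuJoCo131. Passing the variables to a fresh process means they are already set when that process starts and loads the MuJoCo libraries.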
The environments use the rand_param_envs module, which is included in this repository as a submodule (https://github.com/dennisl88/rand_param_envs). We modified some of the environment parameters in rand_param_envs.
The pipeline consists of two stages: data generation and offline RL experiments.
CSRO requires fixed offline (batch) data for meta-training and meta-testing, generated by trained SAC behavior policies. Experiments at this stage are configured via train.yaml and train_point.yaml, located in ./rlkit/torch/sac/pytorch_sac/config/.
The following command divides all environments into 8 parts and trains the environments in part 0 on GPU 0:
CUDA_VISIBLE_DEVICES=0 python policy_train.py --config ./configs/[ENV].json --split 8 --split_idx 0
Generated data will be saved in ./offline_dataset/
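To cover every part, the command above can be repeated with --split_idx 0 through 7. Below is a minimal Python sketch that builds those commands using only the flags shown above and, optionally, runs them in parallel. It assumes that part i should run on GPU i modulo the number of available GPUs; the helper names `split_commands` and `launch` are hypothetical, and by default nothing is executed (dry run).

```python
import os
import subprocess


def split_commands(config, num_splits=8, num_gpus=8):
    """Build one (env, argv) pair per part; part i runs on GPU i % num_gpus."""
    jobs = []
    for idx in range(num_splits):
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(idx % num_gpus))
        argv = ["python", "policy_train.py", "--config", config,
                "--split", str(num_splits), "--split_idx", str(idx)]
        jobs.append((env, argv))
    return jobs


def launch(config, num_splits=8, num_gpus=8, dry_run=True):
    """Print the commands (dry run), or start them in parallel and wait for all exit codes."""
    jobs = split_commands(config, num_splits, num_gpus)
    if dry_run:
        for env, argv in jobs:
            print("CUDA_VISIBLE_DEVICES=" + env["CUDA_VISIBLE_DEVICES"], " ".join(argv))
        return []
    procs = [subprocess.Popen(argv, env=env) for env, argv in jobs]
    return [p.wait() for p in procs]
```

Calling `launch("./configs/[ENV].json")` with [ENV] replaced by a config name prints the 8 commands; pass `dry_run=False` to start them. With fewer GPUs, lowering `num_gpus` makes several parts share a GPU.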
Experiments are configured via JSON configuration files located in ./configs. Basic settings are defined and described in ./configs/default.py. To reproduce an experiment, run:
CUDA_VISIBLE_DEVICES=0 python launch_experiment.py ./configs/[ENV].json --seed 0
Output files will be written to ./output/[ENV]/[EXP NAME]/seed[seed], where the experiment name corresponds to the process start time. The file progress.csv contains statistics logged over the course of training. We recommend viskit (https://github.com/vitchyr/viskit) for visualizing learning curves. Network weights are also snapshotted during training.
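For a quick look at results without viskit, the most recent run for a given seed can be located and its progress.csv loaded with the standard library. This is a minimal sketch that assumes progress.csv sits directly in ./output/[ENV]/[EXP NAME]/seed[seed]; it picks the latest run by file modification time rather than parsing the timestamp-based experiment name, whose exact format is not documented here. The helper names are hypothetical, and because the logged column names are also not documented here, every column is returned as a list of floats (None for empty or non-numeric cells).

```python
import csv
import glob
import os


def latest_progress_file(output_root, env_name, seed):
    """Return the path of the most recently modified progress.csv for this seed, or None."""
    pattern = os.path.join(output_root, env_name, "*", f"seed{seed}", "progress.csv")
    candidates = glob.glob(pattern)
    return max(candidates, key=os.path.getmtime) if candidates else None


def load_progress(path):
    """Read progress.csv into {column: [float or None, ...]}."""
    columns = {}
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            for key, value in row.items():
                try:
                    parsed = float(value)
                except (TypeError, ValueError):  # empty cell, missing cell, or text
                    parsed = None
                columns.setdefault(key, []).append(parsed)
    return columns
```

For example, `load_progress(latest_progress_file("./output", env_name, 0))` returns the columns of the newest seed-0 run for `env_name`. For full learning-curve comparisons across runs, viskit remains the recommended tool.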
Example invocation of show_path.py for the point-robot environment:
python show_path.py configs/point-robot.json --gpu 0 --seed 1 --exp_name classifier_mix_z0_hvar_weighted-x --algo_type CLASSIFIER --train_z0_policy true --use_hvar true --z_strategy weighted