LightZero is an MCTS+RL reinforcement learning framework that provides a set of high-level APIs, enabling users to customize their algorithms within it. Here are some steps and considerations on how to customize an algorithm in LightZero.
Before you start coding your custom algorithms, you need to have a basic understanding of the LightZero framework's structure. The LightZero pipeline is illustrated in the following diagram.
The repository's folder consists primarily of two parts: lzero
and zoo
. The lzero
folder contains the core modules required for the LightZero framework's workflow. The zoo
folder provides a set of predefined environments (envs
) and their corresponding configuration (config
) files. The lzero
folder includes several core modules, including the policy
, model
, worker
, and entry
. These modules work together to implement complex reinforcement learning algorithms.
-
In this architecture, the
policy
module is responsible for implementing the algorithm's decision-making logic, such as action selection during agent-environment interaction and how to update the policy based on collected data. Themodel
module is responsible for implementing the neural network structures required by the algorithm. -
The
worker
module consists of two classes: Collector and Evaluator. An instance of the Collector class handles the agent-environment interaction to collect the necessary data for training, while an instance of the Evaluator class evaluates the performance of the current policy. -
The
entry
module is responsible for initializing the environment, model, policy, etc., and its main loop implements core processes such as data collection, model training, and policy evaluation. -
There are close interactions among these modules. Specifically, the
entry
module calls the Collector and Evaluator of theworker
module to perform data collection and algorithm evaluation. The decision functions of thepolicy
module are called by the Collector and Evaluator to determine the agent's actions in a specific environment. The neural network models implemented in themodel
module are embedded in thepolicy
object for action generation during interaction and for updates during the training process. -
In the
policy
module, you can find implementations of various algorithms. For example, the MuZero policy is implemented in themuzero.py
file.
Create a new Python file under the lzero/policy
directory. This file will contain your algorithm implementation. For example, if your algorithm is called MyAlgorithm, you can create a file named my_algorithm.py
.
Within your policy file, you need to define a class to implement your strategy. This class should inherit from the Policy
class in DI-engine and implement required methods. Below is a basic framework for a policy class:
@POLICY_REGISTRY.register('my_algorithm')
class MyAlgorithmPolicy(Policy):
"""
Overview:
The policy class for MyAlgorithm.
"""
config = dict(
# Add your config here
)
def __init__(self, cfg, **kwargs):
super().__init__(cfg, **kwargs)
# Initialize your policy here
def default_model(self) -> Tuple[str, List[str]]:
# Set the default model name and the import path so that the default model can be loaded during policy initialization
def _init_learn(self):
# Initialize the learn mode here
def _forward_learn(self, data):
# Implement the forward function for learning mode here
def _init_collect(self):
# Initialize the collect mode here
def _forward_collect(self, data, **kwargs):
# Implement the forward function for collect mode here
def _init_eval(self):
# Initialize the eval mode here
def _forward_eval(self, data, **kwargs):
# Implement the forward function for eval mode here
- In
default_model
, set the class name of the default model used by the current policy and the corresponding reference path. - The
_init_collect
and_init_eval
functions are responsible for instantiating the action selection policy, and the respective policy instances will be called by the _forward_collect and _forward_eval functions. - The
_forward_collect
function takes the current state of the environment and selects a step action by calling the instantiated policy in_init_collect
. The function returns the selected action list and other relevant information. During training, this function is called through thecollector.collect
method of the Collector object created by the Entry file. - The logic of the
_forward_eval
function is similar to that of the_forward_collect
function. The only difference is that the policy used in_forward_collect
is more focused on exploration to collect diverse training information, while the policy used in_forward_eval
is more focused on exploitation to obtain the optimal performance of the current policy. During training, this function is called through theevaluator.eval
method of the Evaluator object created by the Entry file.
- The
_init_learn
function initializes the network model, optimizer, and other objects required during training using the associated parameters of the strategy, such as learning rate, update frequency, optimizer type, passed in from the config file. - The
_forward_learn
function is responsible for updating the network. Typically, the_forward_learn
function receives the data collected by the Collector, calculates the loss function based on this data, and performs gradient updates. The function returns the various losses during the update process and the relevant parameters used for the update, for experimental recording purposes. During training, this function is called through thelearner.train
method of the Learner object created by the Entry file.
To make LightZero recognize your policy, you need to use the @POLICY_REGISTRY.register('my_algorithm')
decorator above your policy class to register your policy. This way, LightZero can refer to your policy by the name 'my_algorithm'. Specifically, in the experiment's configuration file, the corresponding algorithm is specified through the create_config
section:
create_config = dict(
...
policy=dict(
type='my_algorithm',
import_names=['lzero.policy.my_algorithm'],
),
...
)
Here, type
should be set to the registered policy name, and import_names
should be set to the location of the policy package.
- Model: The LightZero
model.common
package provides some common network structures, such as theRepresentationNetwork
that maps 2D images to a latent space representation and thePredictionNetwork
used in MCTS for predicting probabilities and node values. If a custom policy requires a specific network model, you need to implement the corresponding model under themodel
folder. For example, the model for the MuZero algorithm is saved in themuzero_model.py
file, which implements theDynamicsNetwork
required by the MuZero algorithm and ultimately creates theMuZeroModel
by calling the existing network structures in themodel.common
package. - Worker: LightZero provides corresponding
worker
for AlphaZero and MuZero. Subsequent algorithms like EfficientZero and GumbelMuzero inherit theworker
from MuZero. If your algorithm has different logic for data collection, you need to implement the correspondingworker
. For example, if your algorithm requires preprocessing of collected transitions, you can add this segment under thecollect
function of the collector, in which theget_train_sample
function implements the specific data processing process.
if timestep.done:
# Prepare trajectory data.
transitions = to_tensor_transitions(self._traj_buffer[env_id])
# Use ``get_train_sample`` to process the data.
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
self._traj_buffer[env_id].clear()
After implementing your strategy, it is crucial to ensure its correctness and effectiveness. To do so, you should write some unit tests to verify that your strategy is functioning correctly. For example, you can test if the strategy can execute in a specific environment and if the output of the strategy matches the expected results. You can refer to the documentation in the DI-engine for guidance on how to write unit tests. You can add your tests in the lzero/policy/tests
. When writing tests, try to consider all possible scenarios and boundary conditions to ensure your strategy can run properly in various situations.
Here is an example of unit testing in LightZero. In this example, we test the inverse_scalar_transform
and InverseScalarTransform
methods. Both methods reverse the transformation of a given value, but they have different implementations. In the unit test, we apply these two methods to the same set of data and compare the output results. If the results are the same, the test passes.
import pytest
import torch
from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform
@pytest.mark.unittest
def test_scaling_transform():
import time
logit = torch.randn(16, 601)
start = time.time()
output_1 = inverse_scalar_transform(logit, 300)
print('t1', time.time() - start)
handle = InverseScalarTransform(300)
start = time.time()
output_2 = handle(logit)
print('t2', time.time() - start)
assert output_1.shape == output_2.shape == (16, 1)
assert (output_1 == output_2).all()
In the unit test file, you need to mark the tests with @pytest.mark.unittest
to include them in the Python testing framework. This allows you to run the unit test file directly by entering pytest -sv xxx.py
in the command line. -sv
is a command option that, when used, prints detailed information to the terminal during the test execution for easier inspection.
- After ensuring the basic functionality of the policy, you need to use classic environments like cartpole to conduct comprehensive correctness and convergence tests on your policy. This is to verify that your policy can work effectively not only in unit tests but also in real game environments.
- You can write related configuration files and entry programs by referring to
cartpole_muzero_config.py
. During the testing process, pay attention to record performance data of the policy, such as the score of each round, the convergence speed of the policy, etc., for analysis and improvement.
-
After completing all the above steps, if you wish to contribute your policy to the LightZero repository, you can submit a Pull Request on the official repository. Before submission, ensure your code complies with the repository's coding standards, all tests have passed, and there are sufficient documents and comments to explain your code and policy.
-
In the description of the PR, explain your policy in detail, including its working principle, your implementation method, and its performance in tests. This will help others understand your contribution and speed up the PR review process.
- After implementing and testing the policy, consider sharing your results and experiences with the community. You can post your policy and test results on forums, blogs, or social media and invite others to review and discuss your work. This not only allows you to receive feedback from others but also helps you build a professional network and may trigger new ideas and collaborations.
- Based on your test results and community feedback, continuously improve and optimize your policy. This may involve adjusting policy parameters, improving code performance, or solving problems and bugs that arise. Remember, policy development is an iterative process, and there's always room for improvement.
- Ensure that your code complies with the Python PEP8 coding standards.
- When implementing methods like
_forward_learn
,_forward_collect
, and_forward_eval
, ensure that you correctly handle input and returned data. - When writing your policy, ensure that you consider different types of environments. Your policy should be able to handle various environments.
- When implementing your policy, try to make your code as modular as possible, facilitating others to understand and reuse your code.
- Write clear documentation and comments describing how your policy works and how your code implements this policy. Strive to maintain the core meaning of the content while enhancing its professionalism and fluency.