Skip to content

Latest commit

 

History

History
166 lines (119 loc) · 11.7 KB

customize_algos_zh.md

File metadata and controls

166 lines (119 loc) · 11.7 KB

LightZero 中如何自定义算法?

LightZero 是一个 MCTS+RL 强化学习框架,它提供了一组高级 API,使得用户可以在其中自定义自己的算法。以下是一些关于如何在 LightZero 中自定义算法的步骤和注意事项。

基本步骤

1. 理解框架结构

在开始编写自定义算法之前,你需要对 LightZero 的框架结构有一个基本的理解,LightZero 的流程如图所示。

Image

仓库的文件夹主要由 lzerozoo 这两部分组成。lzero 中实现了LightZero框架流程所需的核心模块。而 zoo 提供了一系列预定义的环境(envs)以及对应的配置(config)文件。 lzero 文件夹下包括多个核心模块,包括策略(policy)、模型(model)、工作件(worker)以及入口(entry)等。这些模块在一起协同工作,实现复杂的强化学习算法。

  • 在此架构中,policy 模块负责实现算法的决策逻辑,如在智能体与环境交互时的动作选择,以及如何根据收集到的数据更新策略。 model 模块则负责实现算法所需的神经网络结构。
  • worker 模块包含 Collector 和 Evaluator 两个类。 Collector 实例负责执行智能体与环境的交互,以收集训练所需的数据,而 Evaluator 实例则负责评估当前策略的性能。
  • entry 模块负责初始化环境、模型、策略等,并在其主循环中负责实现数据收集、模型训练以及策略评估等核心过程。
  • 在这些模块之间,存在着紧密的交互关系。具体来说, entry 模块会调用 worker 模块的Collector和Evaluator来完成数据收集和算法评估。同时, policy 模块的决策函数会被Collector和Evaluator调用,以决定智能体在特定环境中的行动。而 model 模块实现的神经网络模型,则被嵌入到 policy 对象中,用于在交互过程中生成动作,以及在训练过程中进行更新。
  • policy 模块中,你可以找到多种算法的实现,例如,MuZero策略就在 muzero.py 文件中实现。

2. 创建新的策略文件

lzero/policy 目录下创建一个新的 Python 文件。这个文件将包含你的算法实现。例如,如果你的算法名为 MyAlgorithm ,你可以创建一个名为 my_algorithm.py 的文件。

3. 实现你的策略

在你的策略文件中,你需要定义一个类来实现你的策略。这个类应该继承自 DI-engine中的 Policy 类,并实现所需的方法。

以下是一个基本的策略类的框架:

@POLICY_REGISTRY.register('my_algorithm')
class MyAlgorithmPolicy(Policy):
    """
    Overview:
        The policy class for MyAlgorithm.
    """
    
    config = dict(
        # Add your config here
    )
    
    def __init__(self, cfg, **kwargs):
        super().__init__(cfg, **kwargs)
        # Initialize your policy here

    def default_model(self) -> Tuple[str, List[str]]:
        # Set the default model name and the import path so that the default model can be loaded during policy initialization
    
    def _init_learn(self):
        # Initialize the learn mode here
    
    def _forward_learn(self, data):
        # Implement the forward function for learning mode here
    
    def _init_collect(self):
        # Initialize the collect mode here
    
    def _forward_collect(self, data, **kwargs):
        # Implement the forward function for collect mode here
    
    def _init_eval(self):
        # Initialize the eval mode here
    
    def _forward_eval(self, data, **kwargs):
        # Implement the forward function for eval mode here

收集数据与评估模型

  • default_model 中设置当前策略使用的默认模型的类名和相应的引用路径。
  • _init_collect_init_eval 函数均负责实例化动作选取策略,相应的策略实例会被 _forward_collect_forward_eval 函数调用。
  • _forward_collect 函数会接收当前环境的状态,并通过调用 _init_collect 中实例化的策略来选择一步动作。函数会返回所选的动作列表以及其他相关信息。在训练期间,该函数会通过由Entry文件创建的Collector对象的 collector.collect 方法进行调用。
  • _forward_eval 函数的逻辑与 _forward_collect 函数基本一致。唯一的区别在于, _forward_collect 中采用的策略更侧重于探索,以收集尽可能多样的训练信息;而在 _forward_eval 函数中,所采用的策略更侧重于利用,以获取当前策略的最优性能。在训练期间,该函数会通过由Entry文件创建的Evaluator对象的 evaluator.eval 方法进行调用。

策略的学习

  • _init_learn 函数会利用 config 文件传入的学习率、更新频率、优化器类型等策略的关联参数初始化网络模型、优化器以及训练过程中所需的其他对象。
  • _forward_learn 函数则负责实现网络的更新。通常, _forward_learn 函数会接收 Collector 所收集的数据,根据这些数据计算损失函数并进行梯度更新。函数会返回更新过程中的各项损失以及更新所采用的相关参数,以便进行实验记录。在训练期间,该函数会通过由 Entry 文件创建的 Learner 对象的 learner.train 方法进行调用。

4. 注册你的策略

为了让 LightZero 能够识别你的策略,你需要在你的策略类上方使用 @POLICY_REGISTRY.register('my_algorithm') 这个装饰器来注册你的策略。这样, LightZero 就可以通过 'my_algorithm' 这个名字来引用你的策略了。 具体而言,在实验的配置文件中,通过 create_config 部分来指定相应的算法:

create_config = dict(
    ...
    policy=dict(
        type='my_algorithm',
        import_names=['lzero.policy.my_algorithm'],
    ),
    ...
)

其中 type 要设定为所注册的策略名, import_names 则设置为策略包的位置。

5. 可能的其他更改

  • 模型(model):在 LightZero 的 model.common 包中提供了一些通用的网络结构,例如将2D图像映射到隐空间中的表征网络 RepresentationNetwork ,在MCTS中用于预测概率和节点价值的预测网络 PredictionNetwork 等。如果自定义的策略需要专门的网络模型,则需要自行在 model 文件夹下实现相应的模型。例如 Muzero 算法的模型保存在 muzero_model.py 文件中,该文件实现了 Muzero 算法所需要的 DynamicsNetwork ,并通过调用 model.common 包中现成的网络结构最终实现了 MuZeroModel
  • 工作件(worker):在 LightZero 中实现了 AlphaZero 和 MuZero 的相应 worker 。后续的 EfficientZero 和 GumbelMuzero 等算法沿用了 MuZero 的 worker 。如果你的算法在数据采集的逻辑上有所不同,则需要自行实现相应的 worker 。例如,如果你的算法需要对采集到的transitions 进行预处理,可以在 collector 文件中的 collect 函数下加入下面这一片段。其中 get_train_sample 函数实现了具体的数据处理过程。
if timestep.done:
    # Prepare trajectory data.
    transitions = to_tensor_transitions(self._traj_buffer[env_id])
    # Use ``get_train_sample`` to process the data.
    train_sample = self._policy.get_train_sample(transitions)
    return_data.extend(train_sample)
    self._traj_buffer[env_id].clear()

6. 测试你的策略

在你实现你的策略之后,确保策略的正确性和有效性是非常重要的。为此,你应该编写一些单元测试来验证你的策略是否正常工作。比如,你可以测试策略是否能在特定的环境中执行,策略的输出是否符合预期等。单元测试的编写及意义可以参考 DI-engine 中的单元测试指南 ,你可以在 lzero/policy/tests 目录下添加你的测试。在编写测试时,尽可能考虑到所有可能的场景和边界条件,确保你的策略在各种情况下都能正常运行。 下面是一个 LightZero 中单元测试的例子。在这个例子中,所测试的对象是 inverse_scalar_transformInverseScalarTransform 方法。这两个方法都将经过变换的 value 逆变换为原本的值,但是采取了不同的实现。单元测试时,用这两个方法对同一组数据进行处理,并比较输出的结果是否相同。如果相同,则会通过测试。

import pytest
import torch
from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform

@pytest.mark.unittest
def test_scaling_transform():
    import time
    logit = torch.randn(16, 601)
    start = time.time()
    output_1 = inverse_scalar_transform(logit, 300)
    print('t1', time.time() - start)
    handle = InverseScalarTransform(300)
    start = time.time()
    output_2 = handle(logit)
    print('t2', time.time() - start)
    assert output_1.shape == output_2.shape == (16, 1)
    assert (output_1 == output_2).all()

在单元测试文件中,要将测试通过 @pytest.mark.unittest 标记到python的测试框架中,这样就可以通过在命令行输入 pytest -sv xxx.py 直接运行单元测试文件。其中 -sv 是一个命令选项,表示在测试运行过程中将详细的信息打印到终端以便查看。

7. 完整测试与运行

在确保策略的基本功能正常之后,你需要利用如 cartpole 等经典环境,对你的策略进行完整的正确性和收敛性测试。这是为了验证你的策略不仅能在单元测试中工作,而且能在实际游戏环境中有效工作。

你可以仿照 cartpole_muzero_config.py 编写相关的配置文件和入口程序。在测试过程中,注意记录策略的性能数据,如每轮的得分、策略的收敛速度等,以便于分析和改进。

8. 贡献

在你完成了所有以上步骤后,如果你希望把你的策略贡献到 LightZero 仓库中,你可以在官方仓库上提交 Pull Request 。在提交之前,请确保你的代码符合仓库的编码规范,所有测试都已通过,并且已经有足够的文档和注释来解释你的代码和策略。

在 PR 的描述中,详细说明你的策略,包括它的工作原理,你的实现方法,以及在测试中的表现。这会帮助其他人理解你的贡献,并加速 PR 的审查过程。

9. 分享讨论,反馈改进

完成策略实现和测试后,考虑将你的结果和经验分享给社区。你可以在论坛、博客或者社交媒体上发布你的策略和测试结果,邀请其他人对你的工作进行评价和讨论。这不仅可以得到其他人的反馈,还能帮助你建立专业网络,并可能引发新的想法和合作。

基于你的测试结果和社区的反馈,不断改进和优化你的策略。这可能涉及到调整策略的参数,改进代码的性能,或者解决出现的问题和 bug 。记住,策略的开发是一个迭代的过程,永远有提升的空间。

注意事项

  • 请确保你的代码符合 python PEP8 编码规范。
  • 当你在实现 _forward_learn_forward_collect_forward_eval 等方法时,请确保正确处理输入和返回的数据。
  • 在编写策略时,请确保考虑到不同的环境类型。你的策略应该能够处理不同的环境。
  • 在实现你的策略时,请尽可能使你的代码模块化,以便于其他人理解和重用你的代码。
  • 请编写清晰的文档和注释,描述你的策略如何工作,以及你的代码是如何实现这个策略的。