
[AutoParallel] Generate spmd rule and reshard impl in phi api #56831

Merged
merged 15 commits into PaddlePaddle:develop on Sep 6, 2023

Conversation

@chenwhql (Contributor) commented Aug 31, 2023

PR types

New features

PR changes

Others

Description

Pcard-73145

[AutoParallel] Generate spmd rule and reshard impl in phi api

This PR generates the logic for sharding propagation (SPMD inference) and sharding transformation (reshard) in the PHI forward API.

  1. Sharding propagation (SPMD rule inference)

Specifically, taking matmul as an example, the function implementing its sharding propagation rule is added under the infer_meta field in the YAML:

- op : matmul
  args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
  output : Tensor
  infer_meta :
    func : MatmulInferMeta
    spmd_rule : MatmulSpmdInferForward
  kernel :
    func : matmul
  backward : matmul_grad

Considerations:

  • Sharding propagation still belongs to the category of Tensor meta-information inference, and it is expected to remain an optional field for quite some time, so it is added as a sub-field of infer_meta rather than as a new top-level field.
  • The sharding propagation rule currently reuses infer_meta's param information for its input arguments; if the two ever diverge, a param sub-field may also need to be added under spmd_rule (a hypothetical sketch follows this list).
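
For illustration, such an extension might look like the following, where spmd_rule becomes a mapping with its own func and param entries (a hypothetical sketch, not something this PR adds):

- op : matmul
  args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
  output : Tensor
  infer_meta :
    func : MatmulInferMeta
    spmd_rule :
      func : MatmulSpmdInferForward
      param : [x, y, transpose_x, transpose_y]  # hypothetical param sub-field
  kernel :
    func : matmul
  backward : matmul_grad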

A call to MatmulSpmdInferForward is then generated in the forward API implementation:

    // 1. InferSpmd (Infer DistAttr of Inputs&Outputs)
    auto meta_dist_x = MakeDistMetaTensor(*x.impl());
    auto meta_dist_y = MakeDistMetaTensor(*y.impl());
    auto spmd_info = phi::distributed::MatmulSpmdInferForward(meta_dist_x, meta_dist_y, transpose_x, transpose_y);
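
The generated code indexes the result as spmd_info.first[i] for inputs and spmd_info.second[j] for outputs, so the rule's return type behaves like a pair of DistAttr vectors. A minimal illustrative declaration, assumed from this usage rather than copied from the PHI headers:

    // Assumed shape of the SPMD rule's result, inferred from how the
    // generated code indexes it (spmd_info.first[i] / spmd_info.second[j]).
    using SpmdInfo =
        std::pair<std::vector<phi::distributed::TensorDistAttr>,   // input DistAttrs
                  std::vector<phi::distributed::TensorDistAttr>>;  // output DistAttrs
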
  2. Sharding transformation (reshard)

The forward API only needs to reshard its inputs, so the originally assumed flow was adjusted slightly. The generated code is as follows:

    // 5. Reshard Input
    auto dist_input_x = ReshardDistTensor(dev_ctx, x, spmd_info.first[0]);
    auto dist_input_y = ReshardDistTensor(dev_ctx, y, spmd_info.first[1]);
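
Conceptually, this step converts each input DistTensor from its current placement to the placement inferred by the SPMD rule. A minimal sketch of the idea, with hypothetical helper names rather than Paddle's actual ReshardDistTensor implementation:

    // Hypothetical illustration of the input reshard step: if the tensor's
    // current DistAttr already matches the target inferred by the SPMD rule,
    // it is passed through unchanged; otherwise a conversion is performed
    // (e.g. communication such as all-gather, or a local slice).
    std::shared_ptr<phi::distributed::DistTensor> ReshardInputSketch(
        phi::DeviceContext* dev_ctx,
        const std::shared_ptr<phi::distributed::DistTensor>& input,
        const phi::distributed::TensorDistAttr& target_dist_attr) {
      if (input->dist_attr() == target_dist_attr) {
        return input;  // placements already agree, nothing to do
      }
      // ConvertPlacement is a hypothetical stand-in for the real reshard logic.
      return ConvertPlacement(dev_ctx, input, target_dist_attr);
    }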

The generated dynamic semi-auto-parallel branch of the forward API after this adjustment:

  // Auto Parallel condition
  if (AllInputsAreDistTensor(x, y)) {
    // 1. InferSpmd (Infer DistAttr of Inputs&Outputs)
    auto meta_dist_x = MakeDistMetaTensor(*x.impl());
    auto meta_dist_y = MakeDistMetaTensor(*y.impl());
    auto spmd_info = phi::distributed::MatmulSpmdInferForward(meta_dist_x, meta_dist_y, transpose_x, transpose_y);

    // 2. Create API Output & Prepare Dist and Dense Output
    Tensor api_output;

    auto dist_out = SetKernelDistOutput(&api_output, spmd_info.second[0]);
    auto dense_out = dist_out->unsafe_mutable_value();

    // 3. Infer DistTensor's Global Shape
    phi::MetaTensor meta_dist_out(dist_out);
    phi::MatmulInferMeta(meta_dist_x, meta_dist_y, transpose_x, transpose_y, &meta_dist_out);

    // 4. Select Kernel
    VLOG(6) << "matmul API dist branch: kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
    auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
        "matmul", {kernel_backend, kernel_layout, kernel_data_type});
    const auto& kernel = kernel_result.kernel;
    VLOG(6) << "matmul kernel: " << kernel;
    auto* dev_ctx = GetDeviceContextByBackend(kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);

    // 5. Reshard Input
    auto dist_input_x = ReshardDistTensor(dev_ctx, x, spmd_info.first[0]);
    auto dist_input_y = ReshardDistTensor(dev_ctx, y, spmd_info.first[1]);

    // 6. PrepareData (DataTransform & Prepare Dense Input)
    dist_input_x = PrepareDataForDistTensor(dist_input_x, GetKernelInputArgDef(kernel.InputAt(0), kernel_backend), {}, kernel_result.is_stride_kernel);
    auto input_x = &dist_input_x->value();

    dist_input_y = PrepareDataForDistTensor(dist_input_y, GetKernelInputArgDef(kernel.InputAt(1), kernel_backend), {}, kernel_result.is_stride_kernel);
    auto input_y = &dist_input_y->value();

    // 7. Infer Local DenseTensor Meta
    phi::MetaTensor meta_dense_out(dense_out);
    phi::MatmulInferMeta(MakeMetaTensor(*input_x), MakeMetaTensor(*input_y), transpose_x, transpose_y, &meta_dense_out);

    // 8. DenseTensor Kernel Call
    using kernel_signature = void(*)(const phi::DeviceContext&, const phi::DenseTensor&, const phi::DenseTensor&, bool, bool, phi::DenseTensor*);
    auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
    (*kernel_fn)(*dev_ctx, *input_x, *input_y, transpose_x, transpose_y, dense_out);

    // 9. Return
    return api_output;
  }
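
Note that InferMeta runs twice in this branch: step 3 infers the DistTensor's global shape from the distributed meta tensors, while step 7 infers the local DenseTensor's meta from the (possibly resharded) local inputs, so the dense output buffer is sized for the current rank's shard rather than for the global tensor.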

TODO (to be done in the next PR):

  1. The backward flow needs to be reorganized and refined.
  2. Update the naming of the InferSpmd functions.
  3. Special case: if a forward output is partial, it needs to be resharded to replicated (a hypothetical sketch follows this list).
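
For the third item, a hypothetical sketch of the intended handling (the accessor, mutator, and reuse of ReshardDistTensor below are assumptions, not code from this PR): after the kernel call, an output marked partial would get one extra reshard that reduces it to a replicated placement before the API returns.

    // Hypothetical sketch for TODO 3: a partial output holds per-rank partial
    // results, so one extra reshard/reduction would make it replicated.
    if (dist_out->dist_attr().is_partial()) {        // assumed accessor
      auto replicated_attr = dist_out->dist_attr();
      replicated_attr.clean_partial_status();        // assumed mutator
      // Assumed reuse of the reshard helper, for illustration only.
      auto replicated_out =
          ReshardDistTensor(dev_ctx, api_output, replicated_attr);
    }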

@paddle-bot commented Aug 31, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@chenwhql changed the title from "[AutoParallel] Adapt general spmd rule for static and dynamic mode" to "[AutoParallel] Generate spmd rule and reshard impl in phi api" on Aug 31, 2023
@LiYuRio (Contributor) left a comment:

LGTM

@GhostScreaming (Contributor) left a comment:

LGTM

@chenwhql chenwhql closed this Sep 6, 2023
@chenwhql chenwhql reopened this Sep 6, 2023
@chenwhql chenwhql merged commit e9364a3 into PaddlePaddle:develop Sep 6, 2023
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
…Paddle#56831)

* add spmd and reshard code gen

* add backward reshard code gen

* test matmul forward success

* polish test impl

* add unsafe mutable value

* polish details and add test

* fix unittest time out

* fix typo

* refactor reshard input generate impl

* resolve conflict with develop

* fix compile error