Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

op unittest for split #1453

Merged
merged 2 commits into from
May 24, 2023
Merged

op unittest for split #1453

merged 2 commits into from
May 24, 2023

Conversation

zzk0
Copy link
Contributor

@zzk0 zzk0 commented May 23, 2023

描述

From #1378

给 split 算子添加单元测试。

算子类型

  • ElementWise:输入张量索引和输出张量索引之间存在一对一的对应关系
  • Broadcast:输入张量索引和输出张量索引之间存在一对多的对应关系
  • Injective:单射算子,可以将一个输出 axis 映射到一个输入 axis
  • Reduction:输入张量索引和输出张量索引之间存在多对一的对应关系
  • OutFusible:复杂算子,仍然可以将一对一的算子融合到其输出中。
  • NonFusible:无法融合的算子

Test Cases Checklist

张量维度

  • 1D 张量
  • 2D 张量
  • 3D 张量
  • 4D 张量

special shape

挑选 2D/3D/4D 张量中的一个,测试下面的特殊情况。

  • 其中一个维度为 1
  • 其中一个维度小于 1024
  • 其中一个维度大于 1024
  • 张量的所有维度都是 1

张量数据类型

  • int32
  • int64
  • float16
  • float32
  • float64

广播

  • 这个算子是否支持广播?
  • 广播的测试样例

算子属性

  • num_or_sections: 为整数时表示沿 axis 等分的份数,为列表时表示切分后各段的大小
  • axis: 指定 split 的维度

@paddle-bot
Copy link

paddle-bot bot commented May 23, 2023

Thanks for your contribution!

@zzk0
Copy link
Contributor Author

zzk0 commented May 23, 2023

遗留问题:

TestSplitOpAttributeLargeNum 中 num_or_sections: [512] 这个 test case 会生成如下 kernel,在 num_or_sections 数量超过 128 的时候,一个测试用例的整体运行时间将超过 1 分钟。代码生成之后运行起来其实还好,不会太慢。猜想可能的原因有:中间的优化步骤太慢?代码生成太慢?

extern "C" {

__global__
void __launch_bounds__(8) fn_split_0_kernel(const float* __restrict__ x, float* __restrict__ var_128, float* __restrict__ var_129, float* __restrict__ var_130, float* __restrict__ var_131, float* __restrict__ var_132, float* __restrict__ var_133, float* __restrict__ var_134, float* __restrict__ var_135, float* __restrict__ var_136, float* __restrict__ var_137, float* __restrict__ var_138, float* __restrict__ var_139, float* __restrict__ var_140, float* __restrict__ var_141, float* __restrict__ var_142, float* __restrict__ var_143, float* __restrict__ var_144, float* __restrict__ var_145, float* __restrict__ var_146, float* __restrict__ var_147, float* __restrict__ var_148, float* __restrict__ var_149, float* __restrict__ var_150, float* __restrict__ var_151, float* __restrict__ var_152, float* __restrict__ var_153, float* __restrict__ var_154, float* __restrict__ var_155, float* __restrict__ var_156, float* __restrict__ var_157, float* __restrict__ var_158, float* __restrict__ var_159, float* __restrict__ var_160, float* __restrict__ var_161, float* __restrict__ var_162, float* __restrict__ var_163, float* __restrict__ var_164, float* __restrict__ var_165, float* __restrict__ var_166, float* __restrict__ var_167, float* __restrict__ var_168, float* __restrict__ var_169, float* __restrict__ var_170, float* __restrict__ var_171, float* __restrict__ var_172, float* __restrict__ var_173, float* __restrict__ var_174, float* __restrict__ var_175, float* __restrict__ var_176, float* __restrict__ var_177, float* __restrict__ var_178, float* __restrict__ var_179, float* __restrict__ var_180, float* __restrict__ var_181, float* __restrict__ var_182, float* __restrict__ var_183, float* __restrict__ var_184, float* __restrict__ var_185, float* __restrict__ var_186, float* __restrict__ var_187, float* __restrict__ var_188, float* __restrict__ var_189, float* __restrict__ var_190, float* __restrict__ var_191, float* __restrict__ var_192, float* __restrict__ var_193, float* 
__restrict__ var_194, float* __restrict__ var_195, float* __restrict__ var_196, float* __restrict__ var_197, float* __restrict__ var_198, float* __restrict__ var_199, float* __restrict__ var_200, float* __restrict__ var_201, float* __restrict__ var_202, float* __restrict__ var_203, float* __restrict__ var_204, float* __restrict__ var_205, float* __restrict__ var_206, float* __restrict__ var_207, float* __restrict__ var_208, float* __restrict__ var_209, float* __restrict__ var_210, float* __restrict__ var_211, float* __restrict__ var_212, float* __restrict__ var_213, float* __restrict__ var_214, float* __restrict__ var_215, float* __restrict__ var_216, float* __restrict__ var_217, float* __restrict__ var_218, float* __restrict__ var_219, float* __restrict__ var_220, float* __restrict__ var_221, float* __restrict__ var_222, float* __restrict__ var_223, float* __restrict__ var_224, float* __restrict__ var_225, float* __restrict__ var_226, float* __restrict__ var_227, float* __restrict__ var_228, float* __restrict__ var_229, float* __restrict__ var_230, float* __restrict__ var_231, float* __restrict__ var_232, float* __restrict__ var_233, float* __restrict__ var_234, float* __restrict__ var_235, float* __restrict__ var_236, float* __restrict__ var_237, float* __restrict__ var_238, float* __restrict__ var_239, float* __restrict__ var_240, float* __restrict__ var_241, float* __restrict__ var_242, float* __restrict__ var_243, float* __restrict__ var_244, float* __restrict__ var_245, float* __restrict__ var_246, float* __restrict__ var_247, float* __restrict__ var_248, float* __restrict__ var_249, float* __restrict__ var_250, float* __restrict__ var_251, float* __restrict__ var_252, float* __restrict__ var_253, float* __restrict__ var_254, float* __restrict__ var_255)
{
  if (((int)blockIdx.x < 1)) {
    if (((int)threadIdx.x < 8)) {
    {
      var_128[(int)threadIdx.x] = x[(int)threadIdx.x];
      var_129[(int)threadIdx.x] = x[(8 + (int)threadIdx.x)];
      var_130[(int)threadIdx.x] = x[(16 + (int)threadIdx.x)];
      var_131[(int)threadIdx.x] = x[(24 + (int)threadIdx.x)];
      var_132[(int)threadIdx.x] = x[(32 + (int)threadIdx.x)];
      var_133[(int)threadIdx.x] = x[(40 + (int)threadIdx.x)];
      var_134[(int)threadIdx.x] = x[(48 + (int)threadIdx.x)];
      var_135[(int)threadIdx.x] = x[(56 + (int)threadIdx.x)];
      var_136[(int)threadIdx.x] = x[(64 + (int)threadIdx.x)];
      var_137[(int)threadIdx.x] = x[(72 + (int)threadIdx.x)];
      var_138[(int)threadIdx.x] = x[(80 + (int)threadIdx.x)];
      ......

@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label May 23, 2023
@thisjiang
Copy link
Collaborator

遗留问题:

TestSplitOpAttributeLargeNum 中 num_or_sections: [512] 这个 test case 会生成如下 kernel,在 num_or_sections 数量超过 128 的时候,一个测试用例的整体运行时间将超过 1 分钟。代码生成之后运行起来其实还好,不会太慢。猜想可能的原因有:中间的优化步骤太慢?代码生成太慢?

extern "C" {

__global__
void __launch_bounds__(8) fn_split_0_kernel(const float* __restrict__ x, float* __restrict__ var_128, float* __restrict__ var_129, float* __restrict__ var_130, float* __restrict__ var_131, float* __restrict__ var_132, float* __restrict__ var_133, float* __restrict__ var_134, float* __restrict__ var_135, float* __restrict__ var_136, float* __restrict__ var_137, float* __restrict__ var_138, float* __restrict__ var_139, float* __restrict__ var_140, float* __restrict__ var_141, float* __restrict__ var_142, float* __restrict__ var_143, float* __restrict__ var_144, float* __restrict__ var_145, float* __restrict__ var_146, float* __restrict__ var_147, float* __restrict__ var_148, float* __restrict__ var_149, float* __restrict__ var_150, float* __restrict__ var_151, float* __restrict__ var_152, float* __restrict__ var_153, float* __restrict__ var_154, float* __restrict__ var_155, float* __restrict__ var_156, float* __restrict__ var_157, float* __restrict__ var_158, float* __restrict__ var_159, float* __restrict__ var_160, float* __restrict__ var_161, float* __restrict__ var_162, float* __restrict__ var_163, float* __restrict__ var_164, float* __restrict__ var_165, float* __restrict__ var_166, float* __restrict__ var_167, float* __restrict__ var_168, float* __restrict__ var_169, float* __restrict__ var_170, float* __restrict__ var_171, float* __restrict__ var_172, float* __restrict__ var_173, float* __restrict__ var_174, float* __restrict__ var_175, float* __restrict__ var_176, float* __restrict__ var_177, float* __restrict__ var_178, float* __restrict__ var_179, float* __restrict__ var_180, float* __restrict__ var_181, float* __restrict__ var_182, float* __restrict__ var_183, float* __restrict__ var_184, float* __restrict__ var_185, float* __restrict__ var_186, float* __restrict__ var_187, float* __restrict__ var_188, float* __restrict__ var_189, float* __restrict__ var_190, float* __restrict__ var_191, float* __restrict__ var_192, float* __restrict__ var_193, float* 
__restrict__ var_194, float* __restrict__ var_195, float* __restrict__ var_196, float* __restrict__ var_197, float* __restrict__ var_198, float* __restrict__ var_199, float* __restrict__ var_200, float* __restrict__ var_201, float* __restrict__ var_202, float* __restrict__ var_203, float* __restrict__ var_204, float* __restrict__ var_205, float* __restrict__ var_206, float* __restrict__ var_207, float* __restrict__ var_208, float* __restrict__ var_209, float* __restrict__ var_210, float* __restrict__ var_211, float* __restrict__ var_212, float* __restrict__ var_213, float* __restrict__ var_214, float* __restrict__ var_215, float* __restrict__ var_216, float* __restrict__ var_217, float* __restrict__ var_218, float* __restrict__ var_219, float* __restrict__ var_220, float* __restrict__ var_221, float* __restrict__ var_222, float* __restrict__ var_223, float* __restrict__ var_224, float* __restrict__ var_225, float* __restrict__ var_226, float* __restrict__ var_227, float* __restrict__ var_228, float* __restrict__ var_229, float* __restrict__ var_230, float* __restrict__ var_231, float* __restrict__ var_232, float* __restrict__ var_233, float* __restrict__ var_234, float* __restrict__ var_235, float* __restrict__ var_236, float* __restrict__ var_237, float* __restrict__ var_238, float* __restrict__ var_239, float* __restrict__ var_240, float* __restrict__ var_241, float* __restrict__ var_242, float* __restrict__ var_243, float* __restrict__ var_244, float* __restrict__ var_245, float* __restrict__ var_246, float* __restrict__ var_247, float* __restrict__ var_248, float* __restrict__ var_249, float* __restrict__ var_250, float* __restrict__ var_251, float* __restrict__ var_252, float* __restrict__ var_253, float* __restrict__ var_254, float* __restrict__ var_255)
{
  if (((int)blockIdx.x < 1)) {
    if (((int)threadIdx.x < 8)) {
    {
      var_128[(int)threadIdx.x] = x[(int)threadIdx.x];
      var_129[(int)threadIdx.x] = x[(8 + (int)threadIdx.x)];
      var_130[(int)threadIdx.x] = x[(16 + (int)threadIdx.x)];
      var_131[(int)threadIdx.x] = x[(24 + (int)threadIdx.x)];
      var_132[(int)threadIdx.x] = x[(32 + (int)threadIdx.x)];
      var_133[(int)threadIdx.x] = x[(40 + (int)threadIdx.x)];
      var_134[(int)threadIdx.x] = x[(48 + (int)threadIdx.x)];
      var_135[(int)threadIdx.x] = x[(56 + (int)threadIdx.x)];
      var_136[(int)threadIdx.x] = x[(64 + (int)threadIdx.x)];
      var_137[(int)threadIdx.x] = x[(72 + (int)threadIdx.x)];
      var_138[(int)threadIdx.x] = x[(80 + (int)threadIdx.x)];
      ......

这个问题之前就遇到过,的确是中间的 AST 优化步骤太慢了。当前 CINN 里的 AST 优化是深度遍历 + 循环嵌套调用,导致 AST 优化耗时特别长。要想修复,或许只能等到 AST 重构了。

Copy link
Collaborator

@thisjiang thisjiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@thisjiang thisjiang merged commit 169126b into PaddlePaddle:develop May 24, 2023
jiahy0825 pushed a commit to jiahy0825/CINN that referenced this pull request May 25, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants