Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

O(n) Embedding #80

Merged
merged 10 commits into from
Sep 14, 2024
Merged

O(n) Embedding #80

merged 10 commits into from
Sep 14, 2024

Conversation

hikettei
Copy link
Owner

@hikettei hikettei commented Sep 13, 2024

  • Fuse INDEX-COMPONENT in MultiExpr. (Arange is an alias for Index-Components)
  • Optimize Embedding
  • Writing a test for embedding
  • Test for Norms
  • Test for ConvND
  • (caten (!where (!eq (iconst 2) (ax+b `(3 3) 1 0)) (iconst 2) (iconst 3))) in a single line
  • (caten (!where (!eq (iconst 2) (!index-components `(3 3))) (iconst 2) (iconst 3)))
  • Collapse outermost (batch_size, sequence_len) loops
  • Loop Collapse:
    • ModuleのShapeTracker e.g. A[~] -> A[~]があったら
    • 入力を(!reshape x `(-1 ... minimum_rank))までReshapeする
    • 出力を元々のShapeに戻す
    • でOK
    • EmbeddingはElementWiseである
    • defmethod callでwrap
    • TODO: reshape -1
  • In order to generate the equivalent kernel as tinygrad, we must group all ops in Embedding in the same schedule_group
  • Scheduleした後にMultiExpr適用
  • render-isl-aref: EXPR生成するλlambda関数のやつ作る (format nil no kawrai ni make-expr
  • Loop Permutationが必要: Embedding_dim is the most inner loop

@hikettei
Copy link
Owner Author

Optimal Kernel for Embeding should looks like:

// val_36 for result
// val_31 for weight
// val_8 for input sequence
void main6598825_e26_k0(float* val_36, float* val_31, float* val_8) {
  for(int _gid0=0;(_gid0<=7);_gid0+=1) {
    for(int _gid1=0;(_gid1<=29);_gid1+=1) {
      val_36[240*0+30*_gid0+30*0+_gid1] = 0.0;
      for(int _gid2=0;(_gid2<=29);_gid2+=1) {
        val_36[240*0+30*_gid0+0+_gid1] += (!((30*0+30*0+1*_gid2+1*0)!=val_8[8*0+_gid0+0+0]) ? 1.0 : 0.0) * val_31[900*0+0+30*_gid2+_gid1];
      }
    }
  }
}

@hikettei hikettei changed the title Embedding/Norm/ConvND O(n) Embedding Sep 13, 2024
@hikettei
Copy link
Owner Author

hikettei commented Sep 13, 2024

void main9860200_e24_k0(boolean* val_12, float* val_30, float* val_17, float* val_8, float* val_3) {
  for(int _gid0=0;(_gid0<=7);_gid0+=1) {
    for(int _gid1=0;(_gid1<=29);_gid1+=1) {
      val_30[240*0+30*_gid0+30*0+_gid1] = 0.0;
      for(int _gid2=0;(_gid2<=29);_gid2+=1) {
        if ((_gid0==0)&(_gid1==0)) {
          val_3[30*0+30*0+_gid2+0] = (30*0+30*0+1*_gid2+1*0);
        }
        if (_gid1==0) {
          float val_2 = val_3[30*0+0+_gid2+0];
          boolean val_1 = (val_2!=val_8[8*0+_gid0+0+0]);
          val_12[240*0+30*_gid0+_gid2+0] = !val_1;
        }
        float val_16 = (val_12[240*0+30*_gid0+_gid2+0] ? val_17[900*0+0+30*_gid2+_gid1] : 0.0);
        val_30[240*0+30*_gid0+0+_gid1] = (val_30[240*0+30*_gid0+0+_gid1]+val_16);
      }
    }
  }
}
  • no need to consider loop permutation
  • Infer val_3=_gid2
  • 原因: index-componentsのbroadcasting (bound=1でfuseして欲しいんだけどISLはFuseしない)

@hikettei
Copy link
Owner Author

hikettei commented Sep 13, 2024

最適化は後回しにして,一旦テストと速度検証のためのPipelineを作成する。(llama3をCompileするのが優先,その次に最適化) Don't guess, measure!

@hikettei hikettei marked this pull request as ready for review September 14, 2024 02:34
@hikettei hikettei merged commit 26265a6 into main Sep 14, 2024
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant