
JIT: Post-MultiExpr Optimization between different iteration spaces (Final) #117

Merged: 24 commits into main, Sep 28, 2024

Conversation


@hikettei commented on Sep 27, 2024

  • Getting Optimal Embedding Kernel
    • Fix the behaviour by writing tests
  • TODO: Index-Component does not take the first argument. (Embedding within 3 tensors)
  • Getting Optimal Matmul+Transpose Kernel
    • Tests
  • Getting Optimal threefry2x32/randn Kernel
    • Tests
  • Post-MultiExpr (fusing expressions whose subdomains are equivalent)
  • Index-Component-Fusion (Tested by view)
  • Add the following transform pattern to EXPR (hoisting the repeated load into a temporary keeps the element in a register and makes the in-place store read-once):
val_445[(_gid0+0)] = (val_456_0^((val_445[(_gid0+0)]*32768)+(uint32_t)((float)val_445[(_gid0+0)]*7.6293945e-6)));
->
float _tmp_0 = val_445[_gid0+0];
val_445[_gid0+0] = (val_456_0^((_tmp_0*32768)+(uint32_t)((float)_tmp_0*7.6293945e-6)));
  • Fix: 6D transposed matmul scheduler

  • randn in no more than 2 kernels, including the scalar parts

    • Tests for In-place Embedding, 2 Kernel Randn
    • Accuracy tests for RMSNorm, Embedding, etc.
  • Deleting dead code such as the unused load below (a cleaned-up version follows this list):

void main500811_e229_k0(uint32_t* val_400, uint32_t* val_123, uint32_t val_397) {
  (*val_123) = 4;
  (*val_400) = (val_397+(*val_123));
  uint32_t val_124 = (*val_400);
}
  • More restrictive schedule tests in nn or ajit (e.g., Embedding in one kernel)
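As a sketch of what the dead-code pass should leave behind (illustrative, not verbatim compiler output): the load into val_124 has no users and disappears, while the stores survive because val_400 and val_123 are kernel outputs.

void main500811_e229_k0(uint32_t* val_400, uint32_t* val_123, uint32_t val_397) {
  (*val_123) = 4;                      /* store to an output buffer: kept */
  (*val_400) = (val_397+(*val_123));   /* store to an output buffer: kept */
  /* `uint32_t val_124 = (*val_400);` had no users and is deleted */
}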

@hikettei commented on Sep 27, 2024

TODO: Keep val_650 scalar for the first, second, 3rd, ..., (n-1)th loads, and assign to the output only at the last one (if other kernels need it). A sketch of the target kernel follows the listing below.

(proceed (!normal `(3 3) :mean 1.0))

...

void main6948185_e255_k1(float* val_650, float* val_637) {
  for(int _gid0=0;(_gid0<=99);_gid0+=1) {
    for(int _gid1=0;(_gid1<=99);_gid1+=1) {
      val_650[((100*_gid0)+_gid1)] = (sin(((val_637[((100*_gid0)+_gid1)]*6.2831855)+1.5707964))*sqrt(((log2((1.0+-(val_637[((10000+(100*_gid0))+_gid1)])))*0.6931472)*-2.0)));
      val_650[((100*_gid0)+_gid1)] = (val_650[((100*_gid0)+_gid1)]+1.0);
    }
  }
}
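A minimal sketch of the kernel this TODO aims for, reusing the indexing above (illustrative, not actual compiler output): the Box-Muller-style intermediate stays in a scalar register, and val_650 receives exactly one store per element.

/* math.h assumed, as in the generated kernels */
void main6948185_e255_k1(float* val_650, float* val_637) {
  for(int _gid0=0;(_gid0<=99);_gid0+=1) {
    for(int _gid1=0;(_gid1<=99);_gid1+=1) {
      /* keep the intermediate sample in a register instead of storing it */
      float _tmp_0 = (sin(((val_637[((100*_gid0)+_gid1)]*6.2831855)+1.5707964))*sqrt(((log2((1.0+-(val_637[((10000+(100*_gid0))+_gid1)])))*0.6931472)*-2.0)));
      /* single store to the output, with the :mean 1.0 shift folded in */
      val_650[((100*_gid0)+_gid1)] = (_tmp_0+1.0);
    }
  }
}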

@hikettei commented on Sep 27, 2024

Optimal Embedding Kernel! (TODO: delete val_41, which originates from index_components and is only kept for debugging, after merging this PR)

void main12246031_e24_k0(float* val_55, float* val_31, float* val_37, float* val_41) {
  for(int _gid0=0;(_gid0<=9);_gid0+=1) {
    for(int _gid1=0;(_gid1<=1);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_55[(((100*_gid1)+(10*_gid0))+_gid2)] = 0.0;
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          float val_42 = _gid3;
          boolean val_49 = !(val_42!=val_37[((10*_gid1)+_gid2)]);
          val_55[(((100*_gid1)+(10*_gid0))+_gid2)] = (val_55[(((100*_gid1)+(10*_gid0))+_gid2)]+(val_49 ? val_31[((10*_gid3)+_gid2)] : 0.0));
        }
      }
    }
  }
}
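For intuition, the masked reduction above is a one-hot formulation of a row gather: for each output position, only the _gid3 equal to the stored index in val_37 contributes, so the inner sum picks out exactly one row of val_31. A minimal, hypothetical C sketch of the equivalent direct gather (names and layout simplified; not the scheduler's output):

void embedding_gather(float* out, const float* weight, const float* indices,
                      int tokens, int dim) {
  for (int t = 0; t < tokens; t++) {
    int row = (int)indices[t];               /* the token id selects a weight row */
    for (int d = 0; d < dim; d++) {
      out[(t*dim)+d] = weight[(row*dim)+d];  /* copy that row into the output */
    }
  }
}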

@hikettei commented:

Symbolic (dynamic batch and vocab_size):

void main12360416_e153_k0(uint32_t val_169, uint32_t val_206, uint32_t batch, uint32_t vocab_size, float* val_212, float* val_67, float* val_145, float* val_157) {
  for(int _gid0=0;(_gid0<vocab_size);_gid0+=1) {
    for(int _gid1=0;(_gid1<batch);_gid1+=1) {
      for(int _gid2=0;(_gid2<=102);_gid2+=1) {
        val_212[(((val_206*_gid1)+(103*_gid0))+_gid2)] = 0.0;
        for(int _gid3=0;(_gid3<=100);_gid3+=1) {
          float val_158 = _gid3;
          boolean val_178 = !(val_158!=val_145[((vocab_size*_gid1)+_gid0)]);
          val_212[(((val_206*_gid1)+(103*_gid0))+_gid2)] = (val_212[(((val_206*_gid1)+(103*_gid0))+_gid2)]+(val_178 ? val_67[((103*_gid3)+_gid2)] : 0.0));
        }
      }
    }
  }
}
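Note that batch and vocab_size arrive as plain uint32_t runtime arguments (together with the stride val_206), so a single compiled kernel serves any batch and vocabulary size; only the fixed inner extents (103 and 101) are baked into the loop bounds.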

@hikettei commented on Sep 28, 2024

Writing tests for:

  • Embedding (1 Kernel, 3 Tensors)
  • Threefry2x32 (2 Kernels, 2 Tensors)
  • FeedForward (after implementing WMMA, 2 Kernels)
  • SDotAttn (after implementing WMMA Fusion, 2 Kernels)

@hikettei commented:

Embedding

TEST> (ctx:with-contextvar (:packed 0 :jit 1 :jit_debug 1)
        (with-no-grad
          (caten (call (Embedding 100 100) (make-tensor `(100 100))))))
Compiled[e23]:
Compiled[e24]:

/*
Arrays:
  - val_55[float32]: (100 100 1 100) // OUTPUT, TMP
  - val_31[float32]: (1 1 100 100) // INPUT, TMP
  - val_37[float32]: (100 100 1 1) // INPUT, TMP
*/
void main1638888_e24_k0(float* val_55, float* val_31, float* val_37);
void main1638888_e24_k0(float* val_55, float* val_31, float* val_37) {
  for(int _gid0=0;(_gid0<=99);_gid0+=1) {
    for(int _gid1=0;(_gid1<=99);_gid1+=1) {
      for(int _gid2=0;(_gid2<=99);_gid2+=1) {
        val_55[(((10000*_gid1)+(100*_gid0))+_gid2)] = 0.0;
        for(int _gid3=0;(_gid3<=99);_gid3+=1) {
          float val_42 = _gid3;
          boolean val_49 = !(val_42!=val_37[((100*_gid1)+_gid0)]);
          val_55[(((10000*_gid1)+(100*_gid0))+_gid2)] = (val_55[(((10000*_gid1)+(100*_gid0))+_gid2)]+(val_49 ? val_31[((100*_gid3)+_gid2)] : 0.0));
        }
      }
    }
  }
}
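This matches the "1 Kernel, 3 Tensors" target listed above: the whole Embedding call lands in main1638888_e24_k0, whose arguments are exactly the output val_55 and the two inputs val_31 and val_37 (the C listing itself is printed presumably because :jit_debug 1 is set in the context).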

@hikettei marked this pull request as ready for review on September 28, 2024 at 03:41
@hikettei merged commit 96f6e45 into main on Sep 28, 2024
4 checks passed