[WIP] Fix invalid shape inference in composed gemm and index-components #51

Merged
merged 22 commits into main on Sep 5, 2024

Conversation

@hikettei hikettei (Owner) commented Sep 4, 2024

TODO in this PR

  • Reenable (propagate-rebundant-loadp (avm-graph avm)) to get gelu fused
  • Want to use schedule_whole_component? (confirmed softmax works)
  • !padding cannot be fused into a single kernel (but that should be OK, as there is no extra memory allocation)
  • Composed !tan cannot be fused into a single kernel (needs fold-scalar simplification)
  • Fix: (!softmax (make-tensor (list 'a 'b))) (do not allow nodes starting with IF)
  • threefry2x32 was broken
    • [FIX] With :separate, the per-kernel ordering of the kernels generated from the AST seems wrong?
    • Minimal repro: (proceed (!tan (!tan (!tan (make-tensor (list 3 3) :initial-element 1.0)))))
    • eliminate: broadcast [0+0] (due to if c0 == 0 w/ atomic)
    • (caten (call (caten/nn:Embedding 512 1024) (make-tensor `(10 30))))
  • Remove the duplicated index computation by using expr-eq (bring over the implementation from AbstractTensor.lisp).
    • If a tmpvar is only used inside the kernel, don't allocate it; use a scalar instead (see the sketch after this list).
  • Implement Loop Collapse at the transform.lisp level (judge validity based on permutable and coincidence).
  • Unroll the outermost iteration
  • Pass gemm-test (by fixing invalid shape inference)
  • Eliminate guard nodes (e.g., c0 == 0)
    • as well as in symbolic mode
    • Appears in fused !tan, LayerNorm, and ConvND
  • float4 accumulation
  • optimize for :reduction; create float _acc_0
  • Optimize for metal: use float4
  • Simplifier: Fold Scalars
  • Make Symbolic Compilation completely equivalent to the static one.
    • Involve shape computation in a VM level op.
  • BEAM Search for determining global_size/local_size
  • Lisp JIT
  • Fuse Conv < 2 Kernels
    • delete guard: c0==0
    • relocated_to_most_nearest
  • Add NN Ops test with restricted accuracy
    • make all tests green
  • BugFix:
    (let ((a (make-tensor `(5 5))))
      (caten (!add (!view a `(0 2) `(0 2)) (!view a `(2 4) `(2 4)))))
  • Test w/ 1024x1024 gemm
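As a minimal hand-written sketch (not Caten output) of the tmpvar-to-scalar item above: when a temporary like val_6 is only read and written inside one kernel, it does not need its own buffer; a per-row scalar is enough. The row-max/shift step of softmax is used as the example, and the function name and loop bounds here are illustrative assumptions.

#include <math.h>

/* before: val_6 is a (256 1) buffer allocated outside the kernel and passed in;
   after: the same reduction keeps its partial result in a scalar register and
   consumes it inside the same kernel, so no allocation is needed. */
void softmax_shift_sketch(float* val_2) {
  for (int c0 = 0; c0 <= 255; c0 += 1) {
    float row_max = 0.0f;                          /* replaces the val_6[c0] buffer cell */
    for (int c1 = 0; c1 <= 255; c1 += 1)
      row_max = fmaxf(row_max, val_2[256*c0 + c1]);
    for (int c1 = 0; c1 <= 255; c1 += 1)
      val_2[256*c0 + c1] -= row_max;               /* consumed in the same kernel */
  }
}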

@hikettei hikettei changed the title from [WIP] to [WIP] Fix invalid shape inference in composed gemm and index-components Sep 4, 2024
@hikettei hikettei marked this pull request as ready for review September 4, 2024 09:11
@hikettei hikettei (Owner, Author) commented Sep 4, 2024

(Unrolling can be replaced with Vectorize, 8x8 gemm, Tiling, etc. at the renderer level; the kernels are still polyhedral.)
Softmax in this PR looks like:

CATEN> (caten (!softmax (ax+b `(256 256) 0.01 0.0)))
Compiled[e5]:

/*
Arrays:
  - val_6[float32]: (256 1) // OUTPUT, TMP
  - val_2[float32]: (256 256) // OUTPUT, TMP
  - val_16[float32]: (256 1) // OUTPUT, TMP
*/
void main77636_e5_k0(float* val_6, float* val_2, float* val_16);
void main77636_e5_k0(float* val_6, float* val_2, float* val_16) {
  for(int c0=0;(c0<=255);c0+=1) {
    val_6[c0+0] = 0.0;
    for(int c1=0;(c1<=255);c1+=4) {
      val_2[256*c0+(c1+0)] = ((256*c0+1*(c1+0))*0.01);
      val_2[256*c0+(c1+1)] = ((256*c0+1*(c1+1))*0.01);
      val_2[256*c0+(c1+2)] = ((256*c0+1*(c1+2))*0.01);
      val_2[256*c0+(c1+3)] = ((256*c0+1*(c1+3))*0.01);
      val_2[256*c0+(c1+0)] = val_2[256*c0+(c1+0)];
      val_2[256*c0+(c1+1)] = val_2[256*c0+(c1+1)];
      val_2[256*c0+(c1+2)] = val_2[256*c0+(c1+2)];
      val_2[256*c0+(c1+3)] = val_2[256*c0+(c1+3)];
      val_6[c0+0] = max(val_6[c0+0], val_2[256*c0+(c1+0)]);
      val_6[c0+0] = max(val_6[c0+0], val_2[256*c0+(c1+1)]);
      val_6[c0+0] = max(val_6[c0+0], val_2[256*c0+(c1+2)]);
      val_6[c0+0] = max(val_6[c0+0], val_2[256*c0+(c1+3)]);
    }
    val_6[c0+0] = -(val_6[c0+0]);
    val_16[c0+0] = 0.0;
    for(int c1=0;(c1<=255);c1+=4) {
      val_2[256*c0+(c1+0)] = exp2(((val_2[256*c0+(c1+0)]+val_6[c0+0])*1.442695));
      val_2[256*c0+(c1+1)] = exp2(((val_2[256*c0+(c1+1)]+val_6[c0+0])*1.442695));
      val_2[256*c0+(c1+2)] = exp2(((val_2[256*c0+(c1+2)]+val_6[c0+0])*1.442695));
      val_2[256*c0+(c1+3)] = exp2(((val_2[256*c0+(c1+3)]+val_6[c0+0])*1.442695));
      val_16[c0+0] = (val_16[c0+0]+val_2[256*c0+(c1+0)]);
      val_16[c0+0] = (val_16[c0+0]+val_2[256*c0+(c1+1)]);
      val_16[c0+0] = (val_16[c0+0]+val_2[256*c0+(c1+2)]);
      val_16[c0+0] = (val_16[c0+0]+val_2[256*c0+(c1+3)]);
    }
    val_16[c0+0] = 1/(val_16[c0+0]);
    for(int c1=0;(c1<=255);c1+=4) {
      val_2[256*c0+(c1+0)] = (val_2[256*c0+(c1+0)]*val_16[c0+0]);
      val_2[256*c0+(c1+1)] = (val_2[256*c0+(c1+1)]*val_16[c0+0]);
      val_2[256*c0+(c1+2)] = (val_2[256*c0+(c1+2)]*val_16[c0+0]);
      val_2[256*c0+(c1+3)] = (val_2[256*c0+(c1+3)]*val_16[c0+0]);
    }
  }
}
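As a concrete example of the vectorization mentioned above (and of the "Optimize for metal: use float4" TODO item), here is a hand-written sketch, not Caten output, of how the final scaling loop could be emitted with a 4-wide vector type instead of a 4-way scalar unroll. The float4 typedef uses the GCC/Clang vector extension as a stand-in for Metal's float4; the function name is an assumption.

#include <string.h>

typedef float float4 __attribute__((vector_size(16)));  /* stand-in for Metal's float4 */

void softmax_scale_float4(float* val_2, const float* val_16) {
  for (int c0 = 0; c0 <= 255; c0 += 1) {
    float4 s = { val_16[c0], val_16[c0], val_16[c0], val_16[c0] };  /* broadcast 1/sum */
    for (int c1 = 0; c1 <= 255; c1 += 4) {
      float4 v;
      memcpy(&v, &val_2[256*c0 + c1], sizeof v);   /* vector load of 4 contiguous floats */
      v = v * s;                                   /* one vector multiply per 4 lanes    */
      memcpy(&val_2[256*c0 + c1], &v, sizeof v);   /* vector store                       */
    }
  }
}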

Gemm in this PR:

CATEN> (caten (!matmul (make-tensor `(256 512)) (make-tensor `(512 1024))))
Compiled[e13]:

/*
Arrays:
  - val_8[float32]: (1 1024 512) // INPUT, TMP
  - val_4[float32]: (256 1 512) // INPUT, TMP
  - val_12[float32]: (256 1024 1) // OUTPUT, TMP
*/
void main81140_e13_k0(float* val_8, float* val_4, float* val_12);
void main81140_e13_k0(float* val_8, float* val_4, float* val_12) {
  for(int c0=0;(c0<=255);c0+=1) {
    for(int c1=0;(c1<=1023);c1+=4) {
      val_12[1024*c0+(c1+0)+0] = 0.0;
      val_12[1024*c0+(c1+1)+0] = 0.0;
      val_12[1024*c0+(c1+2)+0] = 0.0;
      val_12[1024*c0+(c1+3)+0] = 0.0;
      for(int c2=0;(c2<=511);c2+=4) {
        val_12[1024*c0+(c1+0)+0] += val_4[512*c0+0+(c2+0)] * val_8[0+(c1+0)+1024*(c2+0)];
        val_12[1024*c0+(c1+1)+0] += val_4[512*c0+0+(c2+0)] * val_8[0+(c1+1)+1024*(c2+0)];
        val_12[1024*c0+(c1+2)+0] += val_4[512*c0+0+(c2+0)] * val_8[0+(c1+2)+1024*(c2+0)];
        val_12[1024*c0+(c1+3)+0] += val_4[512*c0+0+(c2+0)] * val_8[0+(c1+3)+1024*(c2+0)];
        val_12[1024*c0+(c1+0)+0] += val_4[512*c0+0+(c2+1)] * val_8[0+(c1+0)+1024*(c2+1)];
        val_12[1024*c0+(c1+1)+0] += val_4[512*c0+0+(c2+1)] * val_8[0+(c1+1)+1024*(c2+1)];
        val_12[1024*c0+(c1+2)+0] += val_4[512*c0+0+(c2+1)] * val_8[0+(c1+2)+1024*(c2+1)];
        val_12[1024*c0+(c1+3)+0] += val_4[512*c0+0+(c2+1)] * val_8[0+(c1+3)+1024*(c2+1)];
        val_12[1024*c0+(c1+0)+0] += val_4[512*c0+0+(c2+2)] * val_8[0+(c1+0)+1024*(c2+2)];
        val_12[1024*c0+(c1+1)+0] += val_4[512*c0+0+(c2+2)] * val_8[0+(c1+1)+1024*(c2+2)];
        val_12[1024*c0+(c1+2)+0] += val_4[512*c0+0+(c2+2)] * val_8[0+(c1+2)+1024*(c2+2)];
        val_12[1024*c0+(c1+3)+0] += val_4[512*c0+0+(c2+2)] * val_8[0+(c1+3)+1024*(c2+2)];
        val_12[1024*c0+(c1+0)+0] += val_4[512*c0+0+(c2+3)] * val_8[0+(c1+0)+1024*(c2+3)];
        val_12[1024*c0+(c1+1)+0] += val_4[512*c0+0+(c2+3)] * val_8[0+(c1+1)+1024*(c2+3)];
        val_12[1024*c0+(c1+2)+0] += val_4[512*c0+0+(c2+3)] * val_8[0+(c1+2)+1024*(c2+3)];
        val_12[1024*c0+(c1+3)+0] += val_4[512*c0+0+(c2+3)] * val_8[0+(c1+3)+1024*(c2+3)];
      }
    }
  }
}
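Relatedly, a hand-written sketch (not Caten output) of the ":reduction / float _acc_0" TODO item applied to the kernel above: keep the partial sums of the c2 reduction in local accumulators and store to val_12 once per output lane instead of re-reading it on every update. Array shapes follow the comment block above; the function name is an assumption.

void gemm_acc(const float* val_4, const float* val_8, float* val_12) {
  for (int c0 = 0; c0 <= 255; c0 += 1) {
    for (int c1 = 0; c1 <= 1023; c1 += 4) {
      float _acc_0 = 0.0f, _acc_1 = 0.0f, _acc_2 = 0.0f, _acc_3 = 0.0f;
      for (int c2 = 0; c2 <= 511; c2 += 1) {
        const float a = val_4[512*c0 + c2];        /* reused across the 4 output lanes */
        _acc_0 += a * val_8[(c1+0) + 1024*c2];
        _acc_1 += a * val_8[(c1+1) + 1024*c2];
        _acc_2 += a * val_8[(c1+2) + 1024*c2];
        _acc_3 += a * val_8[(c1+3) + 1024*c2];
      }
      val_12[1024*c0 + (c1+0)] = _acc_0;           /* single store per lane */
      val_12[1024*c0 + (c1+1)] = _acc_1;
      val_12[1024*c0 + (c1+2)] = _acc_2;
      val_12[1024*c0 + (c1+3)] = _acc_3;
    }
  }
}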

@hikettei hikettei (Owner, Author) commented Sep 4, 2024

Symbolic Softmax

CATEN> (caten (!softmax (ax+b `(a b) 0.01 0.0)))
WARNING: WIP: MaxOp
Compiled[e33]:

/*
Arrays:
  - A[uint32]: NIL // INPUT, SHAPE
  - B[uint32]: NIL // INPUT, SHAPE
  - val_10[uint32]: NIL // INPUT, TMP
  - val_16[float32]: (A 1) // OUTPUT, TMP
  - val_8[float32]: (A B) // OUTPUT, TMP
  - val_36[float32]: (A 1) // OUTPUT, TMP
*/
void main953086_e33_k0(uint32_t a, uint32_t b, uint32_t val_10, float* val_16, float* val_8, float* val_36);
void main953086_e33_k0(uint32_t a, uint32_t b, uint32_t val_10, float* val_16, float* val_8, float* val_36) {
  for(int c0=0;(c0<a);c0+=1) {
    val_16[c0+0] = 0.0;
    for(int c1=0;((c1+4)<=b);c1+=4) {
      val_8[b*c0+(c1+0)] = ((val_10*c0+1*(c1+0))*0.01);
      val_8[b*c0+(c1+1)] = ((val_10*c0+1*(c1+1))*0.01);
      val_8[b*c0+(c1+2)] = ((val_10*c0+1*(c1+2))*0.01);
      val_8[b*c0+(c1+3)] = ((val_10*c0+1*(c1+3))*0.01);
      val_8[b*c0+(c1+0)] = val_8[b*c0+(c1+0)];
      val_8[b*c0+(c1+1)] = val_8[b*c0+(c1+1)];
      val_8[b*c0+(c1+2)] = val_8[b*c0+(c1+2)];
      val_8[b*c0+(c1+3)] = val_8[b*c0+(c1+3)];
      val_16[c0+0] = max(val_16[c0+0], val_8[b*c0+(c1+0)]);
      val_16[c0+0] = max(val_16[c0+0], val_8[b*c0+(c1+1)]);
      val_16[c0+0] = max(val_16[c0+0], val_8[b*c0+(c1+2)]);
      val_16[c0+0] = max(val_16[c0+0], val_8[b*c0+(c1+3)]);
    }
    for(int c1=(b-(b%4));(c1<b);c1+=1) {
      val_8[b*c0+c1] = ((val_10*c0+1*c1)*0.01);
      val_8[b*c0+c1] = val_8[b*c0+c1];
      val_16[c0+0] = max(val_16[c0+0], val_8[b*c0+c1]);
    }
    val_16[c0+0] = -(val_16[c0+0]);
    val_36[c0+0] = 0.0;
    for(int c1=0;((c1+4)<=b);c1+=4) {
      val_8[b*c0+(c1+0)] = exp2(((val_8[b*c0+(c1+0)]+val_16[c0+0])*1.442695));
      val_8[b*c0+(c1+1)] = exp2(((val_8[b*c0+(c1+1)]+val_16[c0+0])*1.442695));
      val_8[b*c0+(c1+2)] = exp2(((val_8[b*c0+(c1+2)]+val_16[c0+0])*1.442695));
      val_8[b*c0+(c1+3)] = exp2(((val_8[b*c0+(c1+3)]+val_16[c0+0])*1.442695));
      val_36[c0+0] = (val_36[c0+0]+val_8[b*c0+(c1+0)]);
      val_36[c0+0] = (val_36[c0+0]+val_8[b*c0+(c1+1)]);
      val_36[c0+0] = (val_36[c0+0]+val_8[b*c0+(c1+2)]);
      val_36[c0+0] = (val_36[c0+0]+val_8[b*c0+(c1+3)]);
    }
    for(int c1=(b-(b%4));(c1<b);c1+=1) {
      val_8[b*c0+c1] = exp2(((val_8[b*c0+c1]+val_16[c0+0])*1.442695));
      val_36[c0+0] = (val_36[c0+0]+val_8[b*c0+c1]);
    }
    val_36[c0+0] = 1/(val_36[c0+0]);
    for(int c1=0;((c1+4)<=b);c1+=4) {
      val_8[b*c0+(c1+0)] = (val_8[b*c0+(c1+0)]*val_36[c0+0]);
      val_8[b*c0+(c1+1)] = (val_8[b*c0+(c1+1)]*val_36[c0+0]);
      val_8[b*c0+(c1+2)] = (val_8[b*c0+(c1+2)]*val_36[c0+0]);
      val_8[b*c0+(c1+3)] = (val_8[b*c0+(c1+3)]*val_36[c0+0]);
    }
    for(int c1=(b-(b%4));(c1<b);c1+=1) {
      val_8[b*c0+c1] = (val_8[b*c0+c1]*val_36[c0+0]);
    }
  }
}
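The symbolic kernel above splits every inner loop into a 4-wide main part and a scalar remainder because b is only known at run time. The same pattern in isolation, as a hand-written sketch (not Caten output; the function name is an assumption):

/* Sums b floats: the main loop covers b - (b % 4) elements four at a time, and
   the tail loop covers the remaining b % 4, mirroring the
   for(int c1=(b-(b%4));(c1<b);c1+=1) loops in the kernel above. */
float sum_split(const float* x, int b) {
  float acc = 0.0f;
  for (int c1 = 0; (c1 + 4) <= b; c1 += 4)
    acc += x[c1 + 0] + x[c1 + 1] + x[c1 + 2] + x[c1 + 3];
  for (int c1 = (b - (b % 4)); c1 < b; c1 += 1)
    acc += x[c1];
  return acc;
}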

@hikettei hikettei (Owner, Author) commented Sep 5, 2024

I want to remove the guards for the GPU kernels

  • Try to get :separate enabled
  • Otherwise, remove them ourselves by hand
  • expr simplification of the loop-for-below bound
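For illustration, a hand-written sketch (not Caten output) of the kind of guard in question, as named in the "Eliminate guard node (e.g., c0 == 0)" and "delete guard: c0==0" items, and the form after it is removed; the function names are assumptions.

/* with guard: the accumulator is initialized by a branch taken only on the
   first iteration, so the branch is evaluated on every iteration */
void sum_guarded(const float* x, float* out, int n) {
  for (int c0 = 0; c0 < n; c0 += 1) {
    if (c0 == 0) out[0] = 0.0f;   /* guard node */
    out[0] = out[0] + x[c0];
  }
}

/* guard eliminated: the initialization is hoisted before the loop */
void sum_hoisted(const float* x, float* out, int n) {
  out[0] = 0.0f;
  for (int c0 = 0; c0 < n; c0 += 1)
    out[0] = out[0] + x[c0];
}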

@hikettei hikettei merged commit b48fb2d into main Sep 5, 2024
2 checks passed
@hikettei hikettei mentioned this pull request Sep 6, 2024