Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various enhancements and refactorings on caten/ajit #110

Merged
merged 72 commits into from
Sep 25, 2024
Merged

Conversation

hikettei
Copy link
Owner

@hikettei hikettei commented Sep 20, 2024

  • Fix: Optimize gemm (no extra copy by !contiguous)
  • Fix: pass all tests
  • Scheduler: Fuse nested loops
    • Merge inner (:LOCAL) Loop to get an optimal solution of Embedding
    • 同一のタスクにScheduleする条件をゆるくした方がいいかもしれない (MultiExpr or buffer-intersect-p)
  • Fix: 4D Tensor ScaledDotProductAttention
  • Add: ConvND Testing
  • say goodbye to double corruption
  • Scheduler: ISLのASTでなるべくFuseしておく
  • IR: merge viewを考える
  • segv/randn w/ 2~3d inputs are failing
  • INDEX-COMPONENT is always scalar
  • stあるからone-shotでカーネル求めれないかのぅ
  • よくわかんなくなってきた
  • すぐ治したい: transpose matmulのtransposeがzero_costじゃない (Broadcastも...)
  • Improve the debugger visualization

@hikettei
Copy link
Owner Author

hikettei commented Sep 20, 2024

Embedding in a single kernel!!! (If I propagate :INDEX_COMPONENTS, it should be an optimal solution)
(Plus, fuse _gid2 and _gid3)

void main111178_e24_k0(float* val_35, float* val_54, boolean* val_48, float* val_31, float* val_37, float* val_41) {
  for(int _gid0=0;(_gid0<=9);_gid0+=1) {
    val_41[10*0+10*0+_gid0+0] = (10*0+10*0+1*_gid0+1*0);
    for(int _gid1=0;(_gid1<=25);_gid1+=4) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        float val_45_0 = val_41[0+0+_gid2+0];
        float val_45_1 = val_41[0+0+_gid2+0];
        float val_45_2 = val_41[0+0+_gid2+0];
        float val_45_3 = val_41[0+0+_gid2+0];
        boolean val_47_0 = (val_45_0!=val_37[10*(_gid1+0)+_gid0+0+0]);
        boolean val_47_1 = (val_45_1!=val_37[10*(_gid1+1)+_gid0+0+0]);
        boolean val_47_2 = (val_45_2!=val_37[10*(_gid1+2)+_gid0+0+0]);
        boolean val_47_3 = (val_45_3!=val_37[10*(_gid1+3)+_gid0+0+0]);
        val_48[100*(_gid1+0)+10*_gid0+_gid2+0] = !val_47_0;
        val_48[100*(_gid1+1)+10*_gid0+_gid2+0] = !val_47_1;
        val_48[100*(_gid1+2)+10*_gid0+_gid2+0] = !val_47_2;
        val_48[100*(_gid1+3)+10*_gid0+_gid2+0] = !val_47_3;
      }
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_54[100*(_gid1+0)+10*_gid0+10*0+_gid2] = 0.0;
        val_54[100*(_gid1+1)+10*_gid0+10*0+_gid2] = 0.0;
        val_54[100*(_gid1+2)+10*_gid0+10*0+_gid2] = 0.0;
        val_54[100*(_gid1+3)+10*_gid0+10*0+_gid2] = 0.0;
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          val_35[1000*(_gid1+0)+100*_gid0+10*_gid3+_gid2] = (val_48[100*(_gid1+0)+10*_gid0+_gid3+0] ? val_31[0+0+10*_gid3+_gid2] : 0.0);
          val_35[1000*(_gid1+1)+100*_gid0+10*_gid3+_gid2] = (val_48[100*(_gid1+1)+10*_gid0+_gid3+0] ? val_31[0+0+10*_gid3+_gid2] : 0.0);
          val_35[1000*(_gid1+2)+100*_gid0+10*_gid3+_gid2] = (val_48[100*(_gid1+2)+10*_gid0+_gid3+0] ? val_31[0+0+10*_gid3+_gid2] : 0.0);
          val_35[1000*(_gid1+3)+100*_gid0+10*_gid3+_gid2] = (val_48[100*(_gid1+3)+10*_gid0+_gid3+0] ? val_31[0+0+10*_gid3+_gid2] : 0.0);
        }
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          val_54[100*(_gid1+0)+10*_gid0+0+_gid2] = (val_54[100*(_gid1+0)+10*_gid0+0+_gid2]+val_35[1000*(_gid1+0)+100*_gid0+10*_gid3+_gid2]);
          val_54[100*(_gid1+1)+10*_gid0+0+_gid2] = (val_54[100*(_gid1+1)+10*_gid0+0+_gid2]+val_35[1000*(_gid1+1)+100*_gid0+10*_gid3+_gid2]);
          val_54[100*(_gid1+2)+10*_gid0+0+_gid2] = (val_54[100*(_gid1+2)+10*_gid0+0+_gid2]+val_35[1000*(_gid1+2)+100*_gid0+10*_gid3+_gid2]);
          val_54[100*(_gid1+3)+10*_gid0+0+_gid2] = (val_54[100*(_gid1+3)+10*_gid0+0+_gid2]+val_35[1000*(_gid1+3)+100*_gid0+10*_gid3+_gid2]);
        }
      }
    }
    for(int _gid1=28;(_gid1<=29);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        float val_45 = val_41[0+0+_gid2+0];
        boolean val_47 = (val_45!=val_37[10*_gid1+_gid0+0+0]);
        val_48[100*_gid1+10*_gid0+_gid2+0] = !val_47;
      }
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_54[100*_gid1+10*_gid0+10*0+_gid2] = 0.0;
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          val_35[1000*_gid1+100*_gid0+10*_gid3+_gid2] = (val_48[100*_gid1+10*_gid0+_gid3+0] ? val_31[0+0+10*_gid3+_gid2] : 0.0);
        }
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          val_54[100*_gid1+10*_gid0+0+_gid2] = (val_54[100*_gid1+10*_gid0+0+_gid2]+val_35[1000*_gid1+100*_gid0+10*_gid3+_gid2]);
        }
      }
    }
  }
}

@hikettei hikettei changed the title [WIP] Fix JIT Fix Scheduler Sep 20, 2024
@hikettei
Copy link
Owner Author

we have to revise the semantic of !reshape ...

(caten (!sin (!reshape (!sin (!reshape (make-tensor `(3 3)) `(9))) `(3 3))))

@hikettei
Copy link
Owner Author

hikettei commented Sep 21, 2024

JIT in Caten

(文章化してちゃんと考える ...)

  • リファクタしたい,一つのPolyhedral IRにつき一つCLOS Classを用意する

Polyhedral IR in Caten of Embedding, (gained by SERIALIZE=1).

(with-no-grad
    (caten (call (Embedding 10 10) (make-tensor `(10 30)))))
Compiled[e23]:
Compiled[e24]:

/*
Arrays:
  - val_54[float32]: (10 30 1 10) // OUTPUT, TMP
*/
void main2228060_e24_k0(float* val_54);
void main2228060_e24_k0(float* val_54) {
  for(int _gid0=0;(_gid0<=29);_gid0+=1) {
    for(int _gid1=0;(_gid1<=9);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_54[300*_gid1+10*_gid0+10*0+_gid2] = 0.0;
      }
    }
  }
}

/*
Arrays:
  - val_41[float32]: (1 1 10 1) // IO, TMP
*/
void main2228060_e24_k1(float* val_41);
void main2228060_e24_k1(float* val_41) {
  for(int _gid0=0;(_gid0<=9);_gid0+=1) {
    val_41[10*0+10*0+_gid0+0] = (10*0+10*0+1*_gid0+1*0);
  }
}

/*
Arrays:
  - val_45[float32]: (10 30 10 1) // OUTPUT, TMP
  - val_41[float32]: (1 1 10 1) // INPUT, TMP
*/
void main2228060_e24_k2(float* val_45, float* val_41);
void main2228060_e24_k2(float* val_45, float* val_41) {
  for(int _gid0=0;(_gid0<=29);_gid0+=1) {
    for(int _gid1=0;(_gid1<=9);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_45[300*_gid2+10*_gid0+_gid1+0] = val_41[0+0+_gid1+0];
      }
    }
  }
}

/*
Arrays:
  - val_47[bool]: (10 30 10 1) // OUTPUT, TMP
  - val_45[float32]: (10 30 10 1) // INPUT, TMP
  - val_37[float32]: (10 30 1 1) // INPUT, TMP
*/
void main2228060_e24_k3(boolean* val_47, float* val_45, float* val_37);
void main2228060_e24_k3(boolean* val_47, float* val_45, float* val_37) {
  for(int _gid0=0;(_gid0<=29);_gid0+=1) {
    for(int _gid1=0;(_gid1<=9);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_47[300*_gid2+10*_gid0+_gid1+0] = (val_45[300*_gid2+10*_gid0+_gid1+0]!=val_37[30*_gid2+_gid0+0+0]);
      }
    }
  }
}

/*
Arrays:
  - val_48[bool]: (10 30 10 1) // OUTPUT, TMP
  - val_47[bool]: (10 30 10 1) // INPUT, TMP
*/
void main2228060_e24_k4(boolean* val_48, boolean* val_47);
void main2228060_e24_k4(boolean* val_48, boolean* val_47) {
  for(int _gid0=0;(_gid0<=29);_gid0+=1) {
    for(int _gid1=0;(_gid1<=9);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        val_48[300*_gid2+10*_gid0+_gid1+0] = !val_47[300*_gid2+10*_gid0+_gid1+0];
      }
    }
  }
}

/*
Arrays:
  - val_35[float32]: (10 30 10 10) // OUTPUT, TMP
  - val_48[bool]: (10 30 10 1) // INPUT, TMP
  - val_31[float32]: (1 1 10 10) // INPUT, TMP
*/
void main2228060_e24_k5(float* val_35, boolean* val_48, float* val_31);
void main2228060_e24_k5(float* val_35, boolean* val_48, float* val_31) {
  for(int _gid0=0;(_gid0<=9);_gid0+=1) {
    for(int _gid1=0;(_gid1<=29);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          val_35[3000*_gid0+100*_gid1+10*_gid2+_gid3] = (val_48[300*_gid0+10*_gid1+_gid2+0] ? val_31[0+0+10*_gid2+_gid3] : 0.0);
        }
      }
    }
  }
}

/*
Arrays:
  - val_54[float32]: (10 30 1 10) // IO, TMP
  - val_35[float32]: (10 30 10 10) // INPUT, TMP
*/
void main2228060_e24_k6(float* val_54, float* val_35);
void main2228060_e24_k6(float* val_54, float* val_35) {
  for(int _gid0=0;(_gid0<=9);_gid0+=1) {
    for(int _gid1=0;(_gid1<=29);_gid1+=1) {
      for(int _gid2=0;(_gid2<=9);_gid2+=1) {
        for(int _gid3=0;(_gid3<=9);_gid3+=1) {
          val_54[300*_gid0+10*_gid1+0+_gid2] = (val_54[300*_gid0+10*_gid1+0+_gid2]+val_35[3000*_gid0+100*_gid1+10*_gid3+_gid2]);
        }
      }
    }
  }
}

@hikettei
Copy link
Owner Author

hikettei commented Sep 24, 2024

  • !randnは修正するべき (OK)
    • Dynamicだと通る,Staticだと通らない (OK)
  • mainとコンパイル結果が一致するように,特に: (OK)
    • !softmax unrolling (packed-funcall, Loopが2つ以上でもUnroll)
    • !log-softmax in-place (Post-MultiExpr) !matmul packing (:LOCAL :GLOBALの付与がおかしい?)
    • !randn (Initial Scheduleが悪い?)(nth=199, is scheduled to T0 -> T2 -> T1)
      • TODO: Fuse val_14
  • everything is an initial schedule issue.
  • Complete Embedding Folding to pass randn test

TODO: Remaining Task

  • fix randn compiler
  • The way determining coincident (to unroll !softmax, !matmul properly)
  • Implement Post-MultiExpr to fuse !log-softmax and Embedding
  • Implement Index-Component Fusion
  • Delete Unused Node After PostMultiExpr

@hikettei
Copy link
Owner Author

Softmax (old vs new)

/*
Arrays:
  - B[uint32]: NIL // INPUT, SHAPE
  - A[uint32]: NIL // INPUT, SHAPE
  - val_34[float32]: (A B) // IO, TMP
*/
void main862852_e30_k0(uint32_t b, uint32_t a, float* val_34);
void main862852_e30_k0(uint32_t b, uint32_t a, float* val_34) {
  for(int _gid0=0;(_gid0<a);_gid0+=1) {
    float val_25 = 0.0;
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_25 = max(val_25, val_34[b*_gid0+(_gid1+0)]);
      val_25 = max(val_25, val_34[b*_gid0+(_gid1+1)]);
      val_25 = max(val_25, val_34[b*_gid0+(_gid1+2)]);
      val_25 = max(val_25, val_34[b*_gid0+(_gid1+3)]);
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_25 = max(val_25, val_34[b*_gid0+_gid1]);
    }
    float val_26 = -(val_25);
    val_25 = 0.0;
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_34[b*_gid0+(_gid1+0)] = exp2(((val_34[b*_gid0+(_gid1+0)]+val_26)*1.442695));
      val_34[b*_gid0+(_gid1+1)] = exp2(((val_34[b*_gid0+(_gid1+1)]+val_26)*1.442695));
      val_34[b*_gid0+(_gid1+2)] = exp2(((val_34[b*_gid0+(_gid1+2)]+val_26)*1.442695));
      val_34[b*_gid0+(_gid1+3)] = exp2(((val_34[b*_gid0+(_gid1+3)]+val_26)*1.442695));
      val_25 = (val_25+val_34[b*_gid0+(_gid1+0)]);
      val_25 = (val_25+val_34[b*_gid0+(_gid1+1)]);
      val_25 = (val_25+val_34[b*_gid0+(_gid1+2)]);
      val_25 = (val_25+val_34[b*_gid0+(_gid1+3)]);
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_34[b*_gid0+_gid1] = exp2(((val_34[b*_gid0+_gid1]+val_26)*1.442695));
      val_25 = (val_25+val_34[b*_gid0+_gid1]);
    }
    float val_11 = 1/(val_25);
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_34[b*_gid0+(_gid1+0)] = (val_34[b*_gid0+(_gid1+0)]*val_11);
      val_34[b*_gid0+(_gid1+1)] = (val_34[b*_gid0+(_gid1+1)]*val_11);
      val_34[b*_gid0+(_gid1+2)] = (val_34[b*_gid0+(_gid1+2)]*val_11);
      val_34[b*_gid0+(_gid1+3)] = (val_34[b*_gid0+(_gid1+3)]*val_11);
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_34[b*_gid0+_gid1] = (val_34[b*_gid0+_gid1]*val_11);
    }
  }
}
/*
Arrays:
  - B[uint32]: NIL // INPUT, SHAPE
  - A[uint32]: NIL // INPUT, SHAPE
  - val_34[float32]: (A B) // IO, TMP
*/
void main736348_e30_k0(uint32_t b, uint32_t a, float* val_34);
void main736348_e30_k0(uint32_t b, uint32_t a, float* val_34) {
  for(int _gid0=0;(_gid0<a);_gid0+=1) {
    float val_11 = 0.0;
    float val_26 = 0.0;
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_26 = max(val_26, val_34[((b*_gid0)+(_gid1+0))]);
      val_26 = max(val_26, val_34[((b*_gid0)+(_gid1+1))]);
      val_26 = max(val_26, val_34[((b*_gid0)+(_gid1+2))]);
      val_26 = max(val_26, val_34[((b*_gid0)+(_gid1+3))]);
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_26 = max(val_26, val_34[((b*_gid0)+_gid1)]);
    }
    val_26 = -(val_26);
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_34[((b*_gid0)+(_gid1+0))] = exp2(((val_34[((b*_gid0)+(_gid1+0))]+val_26)*1.442695));
      val_34[((b*_gid0)+(_gid1+1))] = exp2(((val_34[((b*_gid0)+(_gid1+1))]+val_26)*1.442695));
      val_34[((b*_gid0)+(_gid1+2))] = exp2(((val_34[((b*_gid0)+(_gid1+2))]+val_26)*1.442695));
      val_34[((b*_gid0)+(_gid1+3))] = exp2(((val_34[((b*_gid0)+(_gid1+3))]+val_26)*1.442695));
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_34[((b*_gid0)+_gid1)] = exp2(((val_34[((b*_gid0)+_gid1)]+val_26)*1.442695));
    }
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_11 = (val_11+val_34[((b*_gid0)+(_gid1+0))]);
      val_11 = (val_11+val_34[((b*_gid0)+(_gid1+1))]);
      val_11 = (val_11+val_34[((b*_gid0)+(_gid1+2))]);
      val_11 = (val_11+val_34[((b*_gid0)+(_gid1+3))]);
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_11 = (val_11+val_34[((b*_gid0)+_gid1)]);
    }
    val_11 = 1/(val_11);
    for(int _gid1=0;((_gid1+4)<=b);_gid1+=4) {
      val_34[((b*_gid0)+(_gid1+0))] = (val_34[((b*_gid0)+(_gid1+0))]*val_11);
      val_34[((b*_gid0)+(_gid1+1))] = (val_34[((b*_gid0)+(_gid1+1))]*val_11);
      val_34[((b*_gid0)+(_gid1+2))] = (val_34[((b*_gid0)+(_gid1+2))]*val_11);
      val_34[((b*_gid0)+(_gid1+3))] = (val_34[((b*_gid0)+(_gid1+3))]*val_11);
    }
    for(int _gid1=(b-(b%4));(_gid1<b);_gid1+=1) {
      val_34[((b*_gid0)+_gid1)] = (val_34[((b*_gid0)+_gid1)]*val_11);
    }
  }
}

@hikettei hikettei marked this pull request as ready for review September 25, 2024 08:59
@hikettei hikettei merged commit d3f9fdf into main Sep 25, 2024
2 of 4 checks passed
hikettei added a commit that referenced this pull request Sep 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant