BugFix: MHA in rtol<1e-5 #270

Merged: hikettei merged 21 commits into main from attn-fix on Nov 29, 2024

Conversation

@hikettei (Owner) commented Nov 29, 2024

  • Create a small repro, which only works in JIT=0
  • EXP needs more accuracy? (see the tolerance-check sketch after this list)
  • more clean-up of the scheduler ...
  • Having a reliable scheduler first
  • Having a reliable shape tracker second
  • CI: Benchmark by symbolic
  • Attention = 5 kernels
  • the bug originates from composing multiple views in the same expr
  • the problem is the part that switches to NO_EXPR, lowers it as float x = ...;, and then does the WMMA
  • Next, fix the memory planner, and the failing case works
  • test-attention-large is the shapetracker's bug
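
Not part of this PR, just a minimal sketch of the acceptance criterion in the title: an element-wise relative-tolerance check with rtol=1e-5, in the style of numpy's allclose. The function name, the atol guard, and the buffer names in the usage comment are assumptions, not repo code.

#include <math.h>
#include <stdio.h>

/* Hypothetical helper (not in this repo): element-wise |got - expected| <= atol + rtol*|expected|,
 * mirroring the rtol<1e-5 criterion the MHA test has to meet. atol guards values near zero. */
static int allclose(const float *got, const float *expected, size_t n,
                    float rtol, float atol) {
    for (size_t i = 0; i < n; i++) {
        float diff = fabsf(got[i] - expected[i]);
        if (diff > atol + rtol * fabsf(expected[i])) {
            fprintf(stderr, "mismatch at %zu: got %g, expected %g\n", i, got[i], expected[i]);
            return 0;
        }
    }
    return 1;
}

/* usage (names assumed): allclose(mha_jit_out, mha_ref_out, 4*4*8*8, 1e-5f, 1e-8f) */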

@hikettei (Owner Author) commented Nov 29, 2024

(screenshot attached)

@hikettei (Owner Author) commented Nov 29, 2024

I got a clue:

The discrepancy point is in the third kernel's WMMA?

val_38 should be val_38[(((_gid0*64)+(8*_gid1))+_gid3)]

val_45 = (val_45+(val_38[(((_gid0*64)+(8*_gid1))+_gid2)]*key[(((_gid0*64)+(8*_gid2))+_gid3)]));
void fused_sumnode_matmul_matmul1873006(float* val_27, const float* restrict val_41, const float* restrict val_17) {
  for (int _gid0=0; (_gid0<16); _gid0+=1) {
    for (int _gid1=0; (_gid1<8); _gid1+=1) {
      for (int _gid2=0; (_gid2<8); _gid2+=1) {
        float val_48 = 0.0;
        for (int _gid3=0; (_gid3<8); _gid3+=1) {
          val_48 = (val_48+(val_41[(((_gid0*64)+(8*_gid1))+_gid2)]*val_17[(((_gid0*64)+(8*_gid2))+_gid3)]));
        }
        val_27[(((64*_gid0)+(8*_gid1))+_gid2)] = (val_48*0.35355338);
      }
    }
  }
}
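
For reference, a sketch of what the note above implies the inner loop should read (the function name is mine; this is not the fix that actually landed): the left operand has to advance along the reduction axis _gid3 instead of _gid2, while the right operand keeps its transposed (8*_gid2)+_gid3 access, giving a plain Q·K^T contraction scaled by 1/sqrt(8).

void fused_sumnode_matmul_matmul_sketch(float* val_27, const float* restrict val_41, const float* restrict val_17) {
  for (int _gid0=0; (_gid0<16); _gid0+=1) {      /* fused batch*head blocks of 64 elements */
    for (int _gid1=0; (_gid1<8); _gid1+=1) {     /* query row i */
      for (int _gid2=0; (_gid2<8); _gid2+=1) {   /* key row j */
        float val_48 = 0.0;
        for (int _gid3=0; (_gid3<8); _gid3+=1) { /* reduction axis k */
          /* left operand indexed by _gid3 here, not _gid2 as in the generated kernel above */
          val_48 = (val_48+(val_41[(((_gid0*64)+(8*_gid1))+_gid3)]*val_17[(((_gid0*64)+(8*_gid2))+_gid3)]));
        }
        val_27[(((64*_gid0)+(8*_gid1))+_gid2)] = (val_48*0.35355338); /* scale by 1/sqrt(8) */
      }
    }
  }
}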

@hikettei changed the title from "MHA in rtol<1e-5" to "BugFix: MHA in rtol<1e-5" on Nov 29, 2024
@hikettei (Owner Author) commented Nov 29, 2024

This is what got scheduled: val_40 overwrites the view of val_38

<Node[BINARYOPS] MOVE(NID2910893) : val_40* <- (val_39, val_38) where :_type_relay=#<INFERRED-TYPE ((4 4 8 8 1)) <- ((4 4 8 8 1) (4 4 8 8 1))> :_read_views=(NIL
                                                                                              (<VIEW : val_74 <- (val_38, shape=(4, 4, 8, 8, 1), views=((0, 4, 1, NIL), (0, 4, 1, NIL), (0, 8, 1, NIL), (0, 8, 1, NIL), (0, 1, 1, T)), stride=(256, 64, 8, 1, 1), permute=(0 1 2 3 4))>))>

<Node[BINARYOPS] MOVE(NID2912389) : val_42* <- (val_41, val_40) where :_type_relay=#<INFERRED-TYPE ((4 4 8 8 8)) <- ((4 4 8 8 8) (4 4 8 8 8))> :_read_views=(NIL
                                                                                              (<VIEW : val_73 <- (val_40, shape=(4, 4, 8, 8, 8), views=((0, 4, 1, NIL), (0, 4, 1, NIL), (0, 8, 1, NIL), (0, 8, 1, T), (0, 8, 1, NIL)), stride=(256, 64, 8, 1, 1), permute=(0 1 2 3 4))>))>
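
A minimal sketch of the failure mode described above (the struct and function are mine, not Caten's): each read view owns its stride/broadcast pattern, and a broadcast axis contributes 0 to the flat offset only for the buffer that carries that view. From the dump, val_74 broadcasts axis 4 of val_38 (shape (4, 4, 8, 8, 1)) while val_73 broadcasts axis 3 of val_40 (shape (4, 4, 8, 8, 8)); if fusing the two MOVEs into one expr lets the second view overwrite the first, the same logical coordinate resolves to a different flat offset.

/* Toy 5-D view: stride per axis; broadcast[d]=1 means that axis is stretched (stride treated as 0). */
typedef struct { long stride[5]; int broadcast[5]; } view_t;

/* Flat offset of one logical coordinate under one view. */
static long view_offset(const view_t *v, const long idx[5]) {
  long off = 0;
  for (int d = 0; d < 5; d++)
    off += v->broadcast[d] ? 0 : idx[d] * v->stride[d];
  return off;
}

int main(void) {
  /* strides (256, 64, 8, 1, 1) as in the dump */
  view_t val_74_view = {{256, 64, 8, 1, 1}, {0, 0, 0, 0, 1}}; /* broadcast on axis 4 */
  view_t val_73_view = {{256, 64, 8, 1, 1}, {0, 0, 0, 1, 0}}; /* broadcast on axis 3 */
  long idx[5] = {0, 0, 1, 2, 3};
  /* 8*1 + 2*1 + 0 = 10 vs 8*1 + 0 + 3*1 = 11: overwriting one view with the other
   * makes the fused expression read a neighbouring element of val_38. */
  return view_offset(&val_74_view, idx) == view_offset(&val_73_view, idx); /* 0: they differ */
}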

@hikettei (Owner Author) commented Nov 29, 2024

Here: (screenshot attached)

@hikettei hikettei marked this pull request as ready for review November 29, 2024 11:59
@hikettei hikettei merged commit e076aab into main Nov 29, 2024
5 of 6 checks passed
@hikettei hikettei deleted the attn-fix branch November 29, 2024 12:52