[OPT] Tail Loop Optimization #1567

Open
wants to merge 1 commit into
base: develop

briannwu
Contributor

@briannwu briannwu commented Jan 17, 2025

details:

  1. Separate tailLoopOpt for A / B: tailLoopOptA / tailLoopOptB.
  2. Not supported: DTV, SparseGemm.
  3. Reorder load instructions with more vgprs.

Compare:

globalReadMode = 3 -> use more VGPRs to reorder the global read (GR), s_waitcnt, and v_or_b32 instructions

Before:
/* g2l=0, load component 0 */
buffer_load_ubyte_d16 v[vgprG2LA+0+0], ..., 0 offen offset:0 // load one buffer value
/* g2l=0, load component 1 */
buffer_load_ubyte_d16 v0, ..., 0 offen offset:1 // load one buffer value
s_waitcnt vmcnt(0)
v_lshlrev_b32 v0, 0x8, v0 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+0], v[vgprG2LA+0+0], v0 // pack a sub 8-bit with dest
/* g2l=0, load component 0 */
buffer_load_ubyte_d16 v[vgprG2LA+0+4], ... offen offset:0 // load one buffer value
/* g2l=0, load component 1 */
buffer_load_ubyte_d16 v0, ... offen offset:1 // load one buffer value
s_waitcnt vmcnt(0)
v_lshlrev_b32 v0, 0x8, v0 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+4], v[vgprG2LA+0+4], v0 // pack a sub 8-bit with dest
...

After:
buffer_load_ubyte_d16 v[vgprG2LA+0+0], ... offen offset:0 // load one buffer value
buffer_load_ubyte_d16 v0, ..., 0 offen offset:1 // load one buffer value
buffer_load_ubyte_d16 v[vgprG2LA+0+4], ... offen offset:0 // load one buffer value
buffer_load_ubyte_d16 v1, ... offen offset:1 // load one buffer value
buffer_load_ubyte_d16 v[vgprG2LA+1+0], offen offset:0 // load one buffer value
...
s_waitcnt vmcnt(10)
v_lshlrev_b32 v0, 0x8, v0 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+0], v[vgprG2LA+0+0], v0 // pack a sub 8-bit with dest
s_waitcnt vmcnt(8)
v_lshlrev_b32 v1, 0x8, v1 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+4], v[vgprG2LA+0+4], v1 // pack a sub 8-bit with dest
...
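
The vmcnt values in the reordered listing can be reproduced with a small Python sketch. It assumes that the loads complete in issue order, that each destination needs one low-byte and one high-byte load, and that 12 loads are issued in total. The count of 12 is inferred from vmcnt(10) and vmcnt(8) above and is not stated in the PR; pack_waitcnts is a hypothetical helper, not a Tensile function.

```python
def pack_waitcnts(num_pairs: int) -> list[int]:
    """Return the vmcnt value to wait on before packing each byte pair.

    Assumption: loads are issued as (low byte, high byte) pairs, all up
    front, and complete in issue order. vmcnt(N) blocks until at most N
    loads are still outstanding.
    """
    total_loads = 2 * num_pairs
    waits = []
    for pair in range(num_pairs):
        # Index of the high-byte load; the pack needs it (and all earlier loads).
        last_needed = 2 * pair + 1
        # Loads issued after last_needed may still be in flight.
        waits.append(total_loads - 1 - last_needed)
    return waits


if __name__ == "__main__":
    print(pack_waitcnts(6))  # six pairs, i.e. 12 loads issued before the first pack
    print(pack_waitcnts(1))  # one pair, i.e. the "Before" pattern
```

With one pair per wait the only legal value is vmcnt(0), which is the stall the Before listing pays on every pair. Issuing all loads first lets each pack wait only on the loads it actually depends on.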

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
globalReadMode = 2 -> use wider global load instructions
Before:
/* g2l=0, load component 0 */
buffer_load_ubyte_d16 v[vgprG2LB+0+0], ..., 0 offen offset:0 // load one buffer value
/* g2l=0, load component 1 */
buffer_load_ubyte_d16 v51, ..., 0 offen offset:1 // load one buffer value
/* g2l=0, load component 2 */
buffer_load_ubyte_d16_hi v52, ..., 0 offen offset:2 // load one buffer value
/* g2l=0, load component 3 */
buffer_load_ubyte_d16_hi v53, ..., 0 offen offset:3 // load one buffer value
...
s_waitcnt vmcnt(14)
v_lshlrev_b32 v51, 0x8, v51 // shift left to higher 8 bits
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v51 // pack a sub 8-bit with dest
s_waitcnt vmcnt(13)
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v52 // pack a sub 8-bit with dest
s_waitcnt vmcnt(12)
v_lshlrev_b32 v53, 0x8, v53 // shift left to higher 8 bits
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v53 // pack a sub 8-bit with dest
...
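
As a rough illustration of what the shift and v_or_b32 sequence above computes, the following Python sketch emulates the packing of four byte loads into one 32-bit register. It assumes that buffer_load_ubyte_d16 places the zero-extended byte in bits 0-7, that buffer_load_ubyte_d16_hi places it in bits 16-23, and that the other half of each temporary register is zero. pack_four_bytes is a hypothetical name used only for this example.

```python
MASK32 = 0xFFFFFFFF


def pack_four_bytes(b0: int, b1: int, b2: int, b3: int) -> int:
    """Emulate the per-byte load and pack sequence for one 32-bit register."""
    dest = b0 & 0xFF          # buffer_load_ubyte_d16    -> bits 0-7 of dest
    v51 = b1 & 0xFF           # buffer_load_ubyte_d16    -> bits 0-7 of v51
    v52 = (b2 & 0xFF) << 16   # buffer_load_ubyte_d16_hi -> bits 16-23 of v52
    v53 = (b3 & 0xFF) << 16   # buffer_load_ubyte_d16_hi -> bits 16-23 of v53

    v51 = (v51 << 8) & MASK32  # v_lshlrev_b32 v51, 0x8, v51
    dest |= v51                # v_or_b32 dest, dest, v51
    dest |= v52                # v_or_b32 dest, dest, v52 (no shift needed)
    v53 = (v53 << 8) & MASK32  # v_lshlrev_b32 v53, 0x8, v53
    dest |= v53                # v_or_b32 dest, dest, v53
    return dest


if __name__ == "__main__":
    print(hex(pack_four_bytes(0x11, 0x22, 0x33, 0x44)))
```

Under these assumptions the result is the same little-endian dword that a single wider load of the four bytes would produce, which is why the wider-load mode can skip the pack steps for fully in-bounds data.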

After:
buffer_load_dwordx4 v[vgprG2LB+0:vgprG2LB+0+3], v[vgprGlobalReadOffsetB+0], s[sgprSrdB:sgprSrdB+3], 0 offen offset:0 // G -> Reg 0_0_0_0
... (compute the values that determine how the last data is loaded)
label_LoadB:
... (jump to specified load tile)
label_LOAD_B0:
label_LOAD_B0_K1:
s_cmp_ge_u32 s11, 1
s_cbranch_scc0 label_MergeB
/* g2l=0, load component 0 */
buffer_load_ubyte_d16 v54, ... 0 offen offset:0 // load one buffer value
label_LOAD_B0_K2:
...
label_LOAD_B0_K15:
... (load code)
s_branch label_MergeB
label_MergeB:
... (jump to specified load tile)
label_MERGE_B0:
label_MERGE_B0_K1:
s_cmp_ge_u32 s11, 1
s_cbranch_scc0 label_CheckB_OOB
s_waitcnt vmcnt(0)
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v54 // pack a sub 8-bit with dest
label_MERGE_B0_K2:
...
label_MERGE_B0_K15:
... (pack code)
s_branch label_CheckB_OOB
label_CheckB_OOB:
...
label_CheckLoopBeginB:
... (calculate the size to be loaded and the size that can be loaded)
label_B0:
... (check whether another tile must be loaded again due to OOB)
s_cbranch_scc1 label_LoadB // Reload
s_branch label_CheckLoopBeginB // Re-check
label_TailGlobalLoadEnd:
s_waitcnt vmcnt(0)
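
The label_LOAD_B0_K1 ... label_LOAD_B0_K15 ladder above can be read as a fall-through chain: each step compares the remaining count against its index and exits to the merge label as soon as the comparison fails. The Python sketch below models only that control flow. It assumes that step Kk compares s11 against k (only the K1 comparison is shown in the listing) and that s11 holds the number of remaining single-byte loads; ladder_steps is a hypothetical helper.

```python
def ladder_steps(remaining: int, max_k: int = 15) -> list[int]:
    """Return the ladder steps K1..Kmax_k that execute for a given count.

    Models: s_cmp_ge_u32 s11, k ; s_cbranch_scc0 <exit label>
    i.e. step k runs only if remaining >= k, otherwise the ladder exits.
    """
    executed = []
    for k in range(1, max_k + 1):
        if not remaining >= k:  # SCC == 0: branch out of the ladder
            break
        executed.append(k)      # fall through into the load (or pack) for step k
    return executed


if __name__ == "__main__":
    print(ladder_steps(3))   # steps executed when three loads remain
    print(ladder_steps(0))   # nothing to load: branch straight to the merge label
```

The same shape would apply to the label_MERGE_B0_K* ladder, with the pack instruction in place of the load.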

@hcman2
Contributor

hcman2 commented Jan 20, 2025

Any brief before/after comparison of the tail loop asm code?

@briannwu
Contributor Author

image

image

@briannwu briannwu force-pushed the tail_opt branch 4 times, most recently from ef4242e to 4b4f883 Compare January 20, 2025 07:58
hcman2
hcman2 previously approved these changes Jan 20, 2025
Contributor

@hcman2 hcman2 left a comment


Good optimization. It would be even better if you could share the performance gain for the sensitive sizes.
