-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
code-gen: Allowed WaveGroups be distributed along n-dim for DTVA/SwizzledA #1493
Conversation
e86a5d9
to
8ae1a10
Compare
8ae1a10
to
3eb6f57
Compare
3eb6f57
to
2a2663e
Compare
# This is to limit the number of Vgpr | ||
if tc == 'A' and not (state['MIWaveGroup'][1] == 1 and state['MatrixInstBN'] == 1): | ||
reject(state, "MIWaveGroup[1] and MatrixInstBN should be 1 for DirectToVgprA. Current value is [%d, %d]"%(state['MIWaveGroup'][1], state['MatrixInstBN'])) | ||
if tc == 'A' and not (state['MatrixInstBN'] == 1): | ||
reject(state, "MatrixInstBN should be 1 for DirectToVgprA. Current value is %d"%(state['MatrixInstBN'])) | ||
return False | ||
if tc == 'B' and not (state['MIWaveGroup'][0] == 1 and state['MatrixInstBM'] == 1): | ||
reject(state, "MIWaveGroup[0] and MatrixInstBM should be 1 for DirectToVgprB. Current value is [%d, %d]"%(state['MIWaveGroup'][0], state['MatrixInstBM'])) | ||
if tc == 'B' and not (state['MatrixInstBM'] == 1): | ||
reject(state, "MatrixInstBM should be 1 for DirectToVgprB. Current value is %d"%(state['MatrixInstBM'])) | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the restrictions
totalElementsPerpA = state["MacroTileA"] | ||
if state["DirectToVgprA"]: | ||
totalElementsPerpA *= state["MIWaveGroup"][1] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Eventually, other than modifying the NumLoadPerp in the last step, modifying the total-Elem-Perp here is the best way. Same for DTVB
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(32)), src=sgpr("SizesSum"), comment="SWZ: numKr = DimK / 32")) | ||
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave-id *= numKr")) | ||
WvG_M = kernel["MIWaveGroup"][0] | ||
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]")) | ||
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr")) | ||
elif isDTVAB: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor modification for swizzled-A
- [16, 16, 16, 1, 1, 5, 8, 2,1 ] # MT = 160x128 | ||
|
||
- [16, 16, 16, 1, 1, 2, 4, 2,2 ] # MT = 64x128 | ||
- [16, 16, 16, 1, 1, 2, 8, 2,2 ] # MT = 64x256 | ||
- [16, 16, 16, 1, 1, 4, 4, 2,2 ] # MT = 128x128 | ||
- [16, 16, 16, 1, 1, 4, 8, 2,2 ] # MT = 128x256 | ||
- [16, 16, 16, 1, 1, 8, 4, 2,2 ] # MT = 256x128 | ||
- [16, 16, 16, 1, 1, 8, 8, 2,2 ] # MT = 256x256 | ||
- [16, 16, 16, 1, 1, 16, 4, 2,2 ] # MT = 512x128 | ||
- [16, 16, 16, 1, 1, 16, 8, 2,2 ] # MT = 512x256 | ||
- [16, 16, 16, 1, 1, 5, 4, 2,2 ] # MT = 160x128 | ||
|
||
- [16, 16, 16, 1, 1, 4, 2, 1,4 ] # MT = 64x128 | ||
- [16, 16, 16, 1, 1, 4, 4, 1,4 ] # MT = 64x256 | ||
- [16, 16, 16, 1, 1, 8, 2, 1,4 ] # MT = 128x128 | ||
- [16, 16, 16, 1, 1, 8, 4, 1,4 ] # MT = 128x256 | ||
- [16, 16, 16, 1, 1, 16, 2, 1,4 ] # MT = 256x128 | ||
- [16, 16, 16, 1, 1, 16, 4, 1,4 ] # MT = 256x256 | ||
- [16, 16, 16, 1, 1, 10, 2, 1,4 ] # MT = 160x128 | ||
- AssertFree0ElementMultiple: [16] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SwizzledA: Added WaveG = 2x2 and 1x4
- [32, 32, 8, 1, 1, 4, 1, 4,1 ] | ||
|
||
- [16, 16, 16, 1, 1, 8,1, 1,4 ] # 128x64 | ||
- [16, 16, 16, 1, 1, 8,1, 2,2 ] # 256x32 | ||
- [16, 16, 16, 1, 1, 16,1, 2,2 ] # 512x32 | ||
- [16, 16, 16, 1, 1, 10,1, 2,2 ] # 160x32 | ||
- WorkGroup: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DTVA: added a few WaveG = 2x2 and 1x4
DTVB: added a few WaveG = 2x2 and 4x1
2a2663e
to
9c9fe65
Compare
if state["ProblemType"]["TLUB"]: | ||
if state["ProblemType"]["TLUB"]: # NT/TT | ||
totalElementsCoalescedB = state["MacroTileB"] | ||
totalElementsPerpB = depthUB | ||
else: | ||
if state["DirectToVgprB"]: | ||
totalElementsCoalescedB *= state["MIWaveGroup"][0] | ||
else: # TN/NN | ||
totalElementsCoalescedB = depthUB | ||
totalElementsPerpB = state["MacroTileB"] | ||
if state["DirectToVgprB"]: | ||
totalElementsPerpB *= state["MIWaveGroup"][0] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed: handled different TLU cases (otherwise, NT ProblemType failed)
Resolved for SWDEV-504181
[gw2] [ 33%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtvA_swizzleA.yaml]
[gw0] [ 66%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtl.yaml]
[gw1] [100%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtv.yaml]
=== 3 passed, 4 warnings in 2395.52s (0:39:55) ===
py310: OK (2540.72=setup[10.65]+cmd[0.57,133.75,2395.76] seconds)
congratulations :) (2540.77 seconds)