code-gen: Allow WaveGroups to be distributed along N-dim for DTVA/SwizzledA #1493

Merged: 5 commits into ROCm:develop on Jan 17, 2025

Conversation

solaslin

@solaslin solaslin commented Dec 25, 2024

Resolves SWDEV-504181.

  • Implemented support for DTVA/SwizzledA when WaveGroups[1] > 1 (wave groups distributed along the N dimension)
  • Added pytests for DTVA and SwizzledA
  • Added support and pytests for DTVB when WaveGroups[0] > 1 (wave groups distributed along the M dimension)

[gw2] [ 33%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtvA_swizzleA.yaml]
[gw0] [ 66%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtl.yaml]
[gw1] [100%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtv.yaml]
=== 3 passed, 4 warnings in 2395.52s (0:39:55) ===
py310: OK (2540.72=setup[10.65]+cmd[0.57,133.75,2395.76] seconds)
congratulations :) (2540.77 seconds)

@solaslin solaslin self-assigned this Dec 25, 2024
@solaslin solaslin added enhancement New feature or request noCI Disable testing on supported CI systems: math libraries CI has this feature enabled.. labels Dec 25, 2024
@solaslin solaslin removed the noCI Disable testing on supported CI systems: math libraries CI has this feature enabled.. label Jan 14, 2025
@solaslin solaslin marked this pull request as ready for review January 14, 2025 07:18
Comment on lines 1900 to 1906
  # This is to limit the number of Vgpr
- if tc == 'A' and not (state['MIWaveGroup'][1] == 1 and state['MatrixInstBN'] == 1):
-     reject(state, "MIWaveGroup[1] and MatrixInstBN should be 1 for DirectToVgprA. Current value is [%d, %d]"%(state['MIWaveGroup'][1], state['MatrixInstBN']))
+ if tc == 'A' and not (state['MatrixInstBN'] == 1):
+     reject(state, "MatrixInstBN should be 1 for DirectToVgprA. Current value is %d"%(state['MatrixInstBN']))
      return False
- if tc == 'B' and not (state['MIWaveGroup'][0] == 1 and state['MatrixInstBM'] == 1):
-     reject(state, "MIWaveGroup[0] and MatrixInstBM should be 1 for DirectToVgprB. Current value is [%d, %d]"%(state['MIWaveGroup'][0], state['MatrixInstBM']))
+ if tc == 'B' and not (state['MatrixInstBM'] == 1):
+     reject(state, "MatrixInstBM should be 1 for DirectToVgprB. Current value is %d"%(state['MatrixInstBM']))
      return False

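Removed the MIWaveGroup restrictions for DirectToVgprA and DirectToVgprB; only the MatrixInstBN and MatrixInstBM checks remain.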
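The relaxed check can be sketched in isolation as follows. This is a minimal sketch rather than the Tensile code itself: `check_direct_to_vgpr` is a hypothetical helper, `state` is a plain dict, and the `reject` callback only stands in for Tensile's rejection routine, whose real behavior may differ.

```python
def check_direct_to_vgpr(state, tc, reject):
    """Hypothetical stand-alone version of the relaxed DirectToVgpr check.

    Only MatrixInstBN (for A) or MatrixInstBM (for B) must equal 1;
    MIWaveGroup is no longer constrained.
    """
    key = "MatrixInstBN" if tc == "A" else "MatrixInstBM"
    if state[key] != 1:
        reject(state, "%s should be 1 for DirectToVgpr%s. Current value is %d"
               % (key, tc, state[key]))
        return False
    return True


# Record rejection messages instead of printing them (assumption for the demo).
rejections = []


def record(state, message):
    rejections.append(message)


# A 2x2 wave group is now accepted for both A and B.
state_ok = {"MIWaveGroup": [2, 2], "MatrixInstBM": 1, "MatrixInstBN": 1}
ok_a = check_direct_to_vgpr(state_ok, "A", record)
ok_b = check_direct_to_vgpr(state_ok, "B", record)

# MatrixInstBN != 1 is still rejected for A.
state_bad = {"MIWaveGroup": [1, 1], "MatrixInstBM": 1, "MatrixInstBN": 2}
bad_a = check_direct_to_vgpr(state_bad, "A", record)
```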

Comment on lines 2995 to 3000
totalElementsPerpA = state["MacroTileA"]
if state["DirectToVgprA"]:
    totalElementsPerpA *= state["MIWaveGroup"][1]

In the end, rather than modifying NumLoadPerp in the last step, scaling the total perpendicular element count (totalElementsPerpA) here is the cleanest approach. The same applies to DTVB.
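As a worked example of the scaling in the quoted lines, the sketch below wraps them in a hypothetical helper, `total_elements_perp_a`, operating on a plain dict. The numeric values are illustrative assumptions, not taken from the PR.

```python
def total_elements_perp_a(state):
    """Hypothetical helper mirroring the quoted lines for matrix A."""
    total = state["MacroTileA"]
    if state["DirectToVgprA"]:
        # Scale by the number of wave groups along the N dimension.
        total *= state["MIWaveGroup"][1]
    return total


# Illustrative values only: MacroTileA = 128 with a 2x2 wave group.
with_dtva = total_elements_perp_a(
    {"MacroTileA": 128, "DirectToVgprA": True, "MIWaveGroup": [2, 2]})
without_dtva = total_elements_perp_a(
    {"MacroTileA": 128, "DirectToVgprA": False, "MIWaveGroup": [2, 2]})
```

With these inputs, `with_dtva` is 256 and `without_dtva` stays at 128.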

Comment on lines 3354 to 3358
      module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(32)), src=sgpr("SizesSum"), comment="SWZ: numKr = DimK / 32"))
-     module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave-id *= numKr"))
+     WvG_M = kernel["MIWaveGroup"][0]
+     module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]"))
+     module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr"))
  elif isDTVAB:

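Minor modification for SwizzledA: the wave ID is first reduced modulo MIWaveGroup[0] (its index along M) before being multiplied by numKr.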
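The arithmetic that the emitted instructions perform can be sketched in plain Python. This is only an illustration under assumptions: `swizzled_a_wave_offset` is a hypothetical name, and the bit mask matches the modulo in the code comment only when MIWaveGroup[0] is a power of two.

```python
def swizzled_a_wave_offset(wave_id, wave_group_m, size_k):
    """Hypothetical scalar model of the SWZ wave-offset computation."""
    if wave_group_m <= 0 or wave_group_m & (wave_group_m - 1):
        raise ValueError("mask-as-modulo needs a power-of-two wave group")
    num_kr = size_k >> 5                       # numKr = DimK / 32
    wave_id_m = wave_id & (wave_group_m - 1)   # wave_id (along M) %= MIWG[0]
    return wave_id_m * num_kr                  # wave_id (along M) *= numKr


# Four waves in a 2x2 wave group with K = 64, so numKr = 2.
offsets = [swizzled_a_wave_offset(w, 2, 64) for w in range(4)]
```

Waves that share the same index along M get the same offset, which the previous `wave-id *= numKr` did not guarantee once MIWaveGroup[1] > 1.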

Comment on lines 62 to 81
- [16, 16, 16, 1, 1, 5, 8, 2,1 ] # MT = 160x128

- [16, 16, 16, 1, 1, 2, 4, 2,2 ] # MT = 64x128
- [16, 16, 16, 1, 1, 2, 8, 2,2 ] # MT = 64x256
- [16, 16, 16, 1, 1, 4, 4, 2,2 ] # MT = 128x128
- [16, 16, 16, 1, 1, 4, 8, 2,2 ] # MT = 128x256
- [16, 16, 16, 1, 1, 8, 4, 2,2 ] # MT = 256x128
- [16, 16, 16, 1, 1, 8, 8, 2,2 ] # MT = 256x256
- [16, 16, 16, 1, 1, 16, 4, 2,2 ] # MT = 512x128
- [16, 16, 16, 1, 1, 16, 8, 2,2 ] # MT = 512x256
- [16, 16, 16, 1, 1, 5, 4, 2,2 ] # MT = 160x128

- [16, 16, 16, 1, 1, 4, 2, 1,4 ] # MT = 64x128
- [16, 16, 16, 1, 1, 4, 4, 1,4 ] # MT = 64x256
- [16, 16, 16, 1, 1, 8, 2, 1,4 ] # MT = 128x128
- [16, 16, 16, 1, 1, 8, 4, 1,4 ] # MT = 128x256
- [16, 16, 16, 1, 1, 16, 2, 1,4 ] # MT = 256x128
- [16, 16, 16, 1, 1, 16, 4, 1,4 ] # MT = 256x256
- [16, 16, 16, 1, 1, 10, 2, 1,4 ] # MT = 160x128
- AssertFree0ElementMultiple: [16]
@solaslin solaslin Jan 15, 2025

SwizzledA: Added WaveG = 2x2 and 1x4
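The `MT = ...` comments in the config above can be reproduced with a small helper. This is a sketch based on an inference from those inline comments, not on the Tensile source: it assumes the macro tile is the instruction size times the wave tile times the wave group in each dimension, and it ignores the fourth and fifth entries, which are 1 in every row shown. `macro_tile` is a hypothetical name.

```python
def macro_tile(mi):
    """Hypothetical helper: (MT0, MT1) from a nine-element MatrixInstruction.

    Assumed layout, inferred from the inline comments:
    [M, N, K, _, _, WaveTileM, WaveTileN, WaveGroupM, WaveGroupN]
    """
    if len(mi) != 9:
        raise ValueError("expected a nine-element MatrixInstruction list")
    m, n = mi[0], mi[1]
    wave_tile_m, wave_tile_n = mi[5], mi[6]
    wave_group_m, wave_group_n = mi[7], mi[8]
    return m * wave_tile_m * wave_group_m, n * wave_tile_n * wave_group_n


mt_2x1 = macro_tile([16, 16, 16, 1, 1, 5, 8, 2, 1])    # WaveG = 2x1
mt_2x2 = macro_tile([16, 16, 16, 1, 1, 16, 8, 2, 2])   # WaveG = 2x2
mt_1x4 = macro_tile([16, 16, 16, 1, 1, 10, 2, 1, 4])   # WaveG = 1x4
```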

Comment on lines 444 to 450
- [32, 32, 8, 1, 1, 4, 1, 4,1 ]

- [16, 16, 16, 1, 1, 8,1, 1,4 ] # 128x64
- [16, 16, 16, 1, 1, 8,1, 2,2 ] # 256x32
- [16, 16, 16, 1, 1, 16,1, 2,2 ] # 512x32
- [16, 16, 16, 1, 1, 10,1, 2,2 ] # 320x32
- WorkGroup:

DTVA: added a few WaveG = 2x2 and 1x4
DTVB: added a few WaveG = 2x2 and 4x1

Comment on lines -2995 to 3011
- if state["ProblemType"]["TLUB"]:
+ if state["ProblemType"]["TLUB"]: # NT/TT
      totalElementsCoalescedB = state["MacroTileB"]
      totalElementsPerpB = depthUB
- else:
+     if state["DirectToVgprB"]:
+         totalElementsCoalescedB *= state["MIWaveGroup"][0]
+ else: # TN/NN
      totalElementsCoalescedB = depthUB
      totalElementsPerpB = state["MacroTileB"]
      if state["DirectToVgprB"]:
          totalElementsPerpB *= state["MIWaveGroup"][0]

Fixed: the DirectToVgprB scaling now handles both TLUB cases separately (previously, the NT problem type failed).
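The two TLUB branches can be exercised side by side with the sketch below. It is a hypothetical, self-contained restatement of the quoted lines: `total_elements_b` and its `depth_u` parameter (standing in for `depthUB`) are assumptions, and the input values are illustrative.

```python
def total_elements_b(state, depth_u):
    """Hypothetical helper returning (coalesced, perp) element counts for B."""
    wave_group_m = state["MIWaveGroup"][0]
    if state["ProblemType"]["TLUB"]:  # NT/TT: MacroTileB is the coalesced dim
        coalesced, perp = state["MacroTileB"], depth_u
        if state["DirectToVgprB"]:
            coalesced *= wave_group_m
    else:  # TN/NN: MacroTileB is the perpendicular dim
        coalesced, perp = depth_u, state["MacroTileB"]
        if state["DirectToVgprB"]:
            perp *= wave_group_m
    return coalesced, perp


base = {"MacroTileB": 64, "DirectToVgprB": True, "MIWaveGroup": [2, 2]}
nt_case = total_elements_b({**base, "ProblemType": {"TLUB": True}}, 32)
tn_case = total_elements_b({**base, "ProblemType": {"TLUB": False}}, 32)
```

In both cases the wave-group factor is applied to whichever dimension carries MacroTileB, so `nt_case` is `(128, 32)` and `tn_case` is `(32, 128)`.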

@solaslin solaslin merged commit 38efb62 into ROCm:develop Jan 17, 2025
10 of 12 checks passed
Labels
enhancement New feature or request
3 participants