-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
code-gen: Allowed WaveGroups be distributed along n-dim for DTVA/SwizzledA #1493
Changes from all commits
69fad7b
62af081
7978b83
ca4d28e
9c9fe65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1458,8 +1458,10 @@ def setGlobalLoadTileDimClassic(state, tc, numLoads, totalVectorsCoalesced, tota | |
and totalElementsPerp % nlp == 0: | ||
state["NumLoadsCoalesced%s"%tc] = nlc | ||
state["NumLoadsPerpendicular%s"%tc] = nlp | ||
#print("NumLoadsCoalesced",state["NumLoadsCoalesced%s"%tc]) | ||
#print("NumLoadsPerpendicular",state["NumLoadsPerpendicular%s"%tc]) | ||
# print("NumLoads%s:"%tc,state["NumLoads%s"%tc]) | ||
# print("NumLoadsCoalesced%s:"%tc,state["NumLoadsCoalesced%s"%tc]) | ||
# print("NumLoadsPerpendicular%s:"%tc,state["NumLoadsPerpendicular%s"%tc]) | ||
# print("\n") | ||
foundValid = True | ||
break | ||
if not foundValid: | ||
|
@@ -1892,15 +1894,15 @@ def isDirectToVgprDoable(state, tc): | |
reject(state, "DirectToVgpr%c does not support TLU%c+ numByte >= 4 + MIInputPerThread > 1"%(tc, tc)) | ||
return False | ||
|
||
# MIWaveGroup, MatrixInstBM,BN check | ||
# for A, MIWaveGroup[1] and MatrixInstBN should be 1 | ||
# for B, MIWaveGroup[0] and MatrixInstBM should be 1 | ||
# MatrixInstBM,BN check | ||
# for A, MatrixInstBN should be 1 | ||
# for B, MatrixInstBM should be 1 | ||
# This is to limit the number of Vgpr | ||
if tc == 'A' and not (state['MIWaveGroup'][1] == 1 and state['MatrixInstBN'] == 1): | ||
reject(state, "MIWaveGroup[1] and MatrixInstBN should be 1 for DirectToVgprA. Current value is [%d, %d]"%(state['MIWaveGroup'][1], state['MatrixInstBN'])) | ||
if tc == 'A' and not (state['MatrixInstBN'] == 1): | ||
reject(state, "MatrixInstBN should be 1 for DirectToVgprA. Current value is %d"%(state['MatrixInstBN'])) | ||
return False | ||
if tc == 'B' and not (state['MIWaveGroup'][0] == 1 and state['MatrixInstBM'] == 1): | ||
reject(state, "MIWaveGroup[0] and MatrixInstBM should be 1 for DirectToVgprB. Current value is [%d, %d]"%(state['MIWaveGroup'][0], state['MatrixInstBM'])) | ||
if tc == 'B' and not (state['MatrixInstBM'] == 1): | ||
reject(state, "MatrixInstBM should be 1 for DirectToVgprB. Current value is %d"%(state['MatrixInstBM'])) | ||
return False | ||
Comment on lines
1900
to
1906
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed the restrictions |
||
|
||
# Does not work with WaveSeparateGlobalRead | ||
|
@@ -1959,7 +1961,7 @@ def isDirectToVgprDoable(state, tc): | |
if state["PrefetchGlobalRead"] == 0: | ||
reject(state, "DirectToVgpr%c does not supports PrefetchGlobalRead == 0."%(tc)) | ||
return False | ||
|
||
# for DTVA, does not work with NN and TLDS0 | ||
if tc == 'A' and state["TransposeLDS"] == 0 and (not state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]): | ||
reject(state, "DirectToVgpr%c does not supports NN case with TransposeLDS == 0."%(tc)) | ||
|
@@ -1974,7 +1976,7 @@ def isDirectToVgprDoable(state, tc): | |
if tc == 'B' and (not state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]): | ||
# Use AssertSummationElementMultiple (BoundSizeMultiple in predicates) to exclude failed tail-loop cases | ||
state["AssertSummationElementMultiple"] = max(state["AssertSummationElementMultiple"], state["DepthU"]) | ||
|
||
# Does not work with DirectToLDS | ||
# -> this will be checked after DirectToLDS doable check is done | ||
|
||
|
@@ -2985,19 +2987,27 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int: | |
validDepthU = True | ||
|
||
# how many elements to load | ||
if state["ProblemType"]["TLUA"]: | ||
if state["ProblemType"]["TLUA"]: # NT/NN | ||
totalElementsCoalescedA = state["MacroTileA"] | ||
totalElementsPerpA = depthUA | ||
else: | ||
if state["DirectToVgprA"]: | ||
totalElementsCoalescedA *= state["MIWaveGroup"][1] | ||
else: # TN/TT | ||
totalElementsCoalescedA = depthUA | ||
totalElementsPerpA = state["MacroTileA"] | ||
if state["DirectToVgprA"]: | ||
totalElementsPerpA *= state["MIWaveGroup"][1] | ||
|
||
Comment on lines
2997
to
3000
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eventually, other than modifying the NumLoadPerp in the last step, modifying the total-Elem-Perp here is the best way. Same for DTVB |
||
if state["ProblemType"]["TLUB"]: | ||
if state["ProblemType"]["TLUB"]: # NT/TT | ||
totalElementsCoalescedB = state["MacroTileB"] | ||
totalElementsPerpB = depthUB | ||
else: | ||
if state["DirectToVgprB"]: | ||
totalElementsCoalescedB *= state["MIWaveGroup"][0] | ||
else: # TN/NN | ||
totalElementsCoalescedB = depthUB | ||
totalElementsPerpB = state["MacroTileB"] | ||
if state["DirectToVgprB"]: | ||
totalElementsPerpB *= state["MIWaveGroup"][0] | ||
|
||
Comment on lines
-2995
to
3011
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed: handled different TLU cases (otherwise, NT ProblemType failed) |
||
totalElementsA = totalElementsCoalescedA * totalElementsPerpA | ||
totalElementsB = totalElementsCoalescedB * totalElementsPerpB | ||
|
@@ -3249,7 +3259,7 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int: | |
if not Solution.isDirectToVgprDoable(state, 'A'): | ||
return # rejected | ||
if state["DirectToVgprB"]: | ||
if not Solution.isDirectToVgprDoable(state, 'B'): | ||
if not Solution.isDirectToVgprDoable(state, 'B'): | ||
return # rejected | ||
|
||
######################################## | ||
|
@@ -3478,6 +3488,15 @@ def subCheckLdsBlockSizePerPad(tc, idx): | |
#1LDS buffer must be 0 for DirectToLdsA | ||
state["1LDSBuffer"] = 0 | ||
|
||
# Re-check DTV + WaveGroup after DTL is confirmed | ||
if state["DirectToLds"]: | ||
if state["DirectToVgprA"] and state['MIWaveGroup'][1] > 1: | ||
reject(state, "DirectToLds + (DirectToVgprA + WaveGroups along N-Dim) is not supported yet") | ||
return False | ||
if state["DirectToVgprB"] and state['MIWaveGroup'][0] > 1: | ||
reject(state, "DirectToLds + (DirectToVgprB + WaveGroups along M-Dim) is not supported yet") | ||
return False | ||
|
||
# set NoLdsWriteCode if (DirectToVgpr or DirectToLds)A+B is enabled | ||
state["NoLdsWriteCode"] = False | ||
if (state["DirectToVgprA"] or state["DirectToLdsA"]) and (state["DirectToVgprB"] or state["DirectToLdsB"]): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,9 +10,12 @@ GlobalParameters: | |
CMakeBuildType: Release | ||
KernelTime: True | ||
MaxWorkspaceSize: 13421772800 | ||
DataInitTypeA: 13 | ||
DataInitTypeB: 12 | ||
DataInitTypeAlpha: 1 | ||
DataInitTypeBeta: 1 | ||
BoundsCheck: 2 | ||
DataInitTypeBias: 13 | ||
DataInitTypeScaleAlphaVec: 12 | ||
#MaxFileName: 256 | ||
|
||
BenchmarkProblems: | ||
|
@@ -439,6 +442,11 @@ BenchmarkProblems: | |
- [16, 16, 16, 1, 1, 5, 1, 2,1 ] # MT = 160x16. Case to check kernel writer works | ||
- [32, 32, 8, 1, 1, 2, 1, 4,1 ] | ||
- [32, 32, 8, 1, 1, 4, 1, 4,1 ] | ||
|
||
- [16, 16, 16, 1, 1, 8,1, 1,4 ] # 128x64 | ||
- [16, 16, 16, 1, 1, 8,1, 2,2 ] # 256x32 | ||
- [16, 16, 16, 1, 1, 16,1, 2,2 ] # 512x32 | ||
- [16, 16, 16, 1, 1, 10,1, 2,2 ] # 160x32 | ||
- WorkGroup: | ||
Comment on lines
444
to
450
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DTVA: added a few WaveG = 2x2 and 1x4 |
||
- [16,16,1] | ||
- GlobalReadVectorWidthA: [2,4,8] | ||
|
@@ -613,6 +621,11 @@ BenchmarkProblems: | |
- [16, 16, 16, 1, 1, 1, 5, 1,2 ] # MT = 16x160. Case to check kernel writer works | ||
- [32, 32, 8, 1, 1, 1, 2, 1,4 ] | ||
- [32, 32, 8, 1, 1, 1, 4, 1,4 ] | ||
|
||
- [16, 16, 16, 1, 1, 1,8, 4,1 ] # 64x128 | ||
- [16, 16, 16, 1, 1, 1,8, 2,2 ] # 32x256 | ||
- [16, 16, 16, 1, 1, 1,16, 2,2 ] # 32x512 | ||
- [16, 16, 16, 1, 1, 1,10, 2,2 ] # 32x160 | ||
- WorkGroup: | ||
- [16,16,1] | ||
- GlobalReadVectorWidthA: [2,4,8] | ||
|
@@ -1985,6 +1998,10 @@ BenchmarkProblems: | |
- [16, 16, 16, 1, 1, 1,8, 4,1 ] # 64x128 | ||
- [16, 16, 16, 1, 1, 2,8, 4,1 ] # 128x128 | ||
- [16, 16, 16, 1, 1, 4,16, 4,1 ] # 256x256 | ||
|
||
- [16, 16, 16, 1, 1, 4,2, 1,4 ] # 64x128 | ||
- [16, 16, 16, 1, 1, 8,2, 1,4 ] # 128x128 | ||
- [16, 16, 16, 1, 1, 16,4, 1,4 ] # 256x256 | ||
- AssertFree0ElementMultiple: [16] | ||
- AssertFree1ElementMultiple: [16] | ||
- AssertSummationElementMultiple: [32] | ||
|
@@ -2050,6 +2067,10 @@ BenchmarkProblems: | |
- [16, 16, 16, 1, 1, 4,2, 1,4 ] # 64x128 | ||
- [16, 16, 16, 1, 1, 8,2, 1,4 ] # 128x128 | ||
- [16, 16, 16, 1, 1, 16,4, 1,4 ] # 256x256 | ||
|
||
- [16, 16, 16, 1, 1, 2,4, 4,1 ] # 128x64 | ||
- [16, 16, 16, 1, 1, 2,8, 4,1 ] # 128x128 | ||
- [16, 16, 16, 1, 1, 4,16, 4,1 ] # 256x256 | ||
- AssertFree0ElementMultiple: [16] | ||
- AssertFree1ElementMultiple: [16] | ||
- AssertSummationElementMultiple: [32] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,15 +11,17 @@ GlobalParameters: | |
KernelTime: True | ||
MaxWorkspaceSize: 13421772800 | ||
DataInitTypeA: 13 | ||
DataInitTypeB: 13 | ||
DataInitTypeB: 12 | ||
DataInitTypeAlpha: 1 | ||
DataInitTypeBeta: 1 | ||
DataInitTypeBias: 13 | ||
DataInitTypeScaleAlphaVec: 12 | ||
BoundsCheck: 2 | ||
#MaxFileName: 256 | ||
|
||
BenchmarkProblems: | ||
######################################## | ||
# HHS TN DTVA + SWIZZLED_A + BIAS + Activation | ||
# HHS TN DTVA + SWIZZLED_A + BIAS + Activation + SAV | ||
######################################## | ||
- | ||
- # ProblemType | ||
|
@@ -36,6 +38,7 @@ BenchmarkProblems: | |
UseBias: 1 | ||
Activation: True | ||
BiasDataTypeList: ['h'] | ||
UseScaleAlphaVec: 1 | ||
- # BenchmarkProblemSizeGroup - Standard - All problem | ||
InitialSolutionParameters: | ||
BenchmarkCommonParameters: | ||
|
@@ -57,6 +60,24 @@ BenchmarkProblems: | |
- [16, 16, 16, 1, 1, 8, 8, 4,1 ] # MT = 512x128 | ||
- [16, 16, 16, 1, 1, 8, 16, 4,1 ] # MT = 512x256 | ||
- [16, 16, 16, 1, 1, 5, 8, 2,1 ] # MT = 160x128 | ||
|
||
- [16, 16, 16, 1, 1, 2, 4, 2,2 ] # MT = 64x128 | ||
- [16, 16, 16, 1, 1, 2, 8, 2,2 ] # MT = 64x256 | ||
- [16, 16, 16, 1, 1, 4, 4, 2,2 ] # MT = 128x128 | ||
- [16, 16, 16, 1, 1, 4, 8, 2,2 ] # MT = 128x256 | ||
- [16, 16, 16, 1, 1, 8, 4, 2,2 ] # MT = 256x128 | ||
- [16, 16, 16, 1, 1, 8, 8, 2,2 ] # MT = 256x256 | ||
- [16, 16, 16, 1, 1, 16, 4, 2,2 ] # MT = 512x128 | ||
- [16, 16, 16, 1, 1, 16, 8, 2,2 ] # MT = 512x256 | ||
- [16, 16, 16, 1, 1, 5, 4, 2,2 ] # MT = 160x128 | ||
|
||
- [16, 16, 16, 1, 1, 4, 2, 1,4 ] # MT = 64x128 | ||
- [16, 16, 16, 1, 1, 4, 4, 1,4 ] # MT = 64x256 | ||
- [16, 16, 16, 1, 1, 8, 2, 1,4 ] # MT = 128x128 | ||
- [16, 16, 16, 1, 1, 8, 4, 1,4 ] # MT = 128x256 | ||
- [16, 16, 16, 1, 1, 16, 2, 1,4 ] # MT = 256x128 | ||
- [16, 16, 16, 1, 1, 16, 4, 1,4 ] # MT = 256x256 | ||
- [16, 16, 16, 1, 1, 10, 2, 1,4 ] # MT = 160x128 | ||
- AssertFree0ElementMultiple: [16] | ||
Comment on lines
62
to
81
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SwizzledA: Added WaveG = 2x2 and 1x4 |
||
- AssertSummationElementMultiple: [32] | ||
- GlobalReadVectorWidthA: [8] | ||
|
@@ -102,4 +123,4 @@ BenchmarkProblems: | |
- BiasTypeArgs: ['h'] | ||
- ActivationArgs: | ||
- [Enum: none] | ||
- [Enum: relu] | ||
- [Enum: relu] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor modification for swizzled-A