Skip to content

Commit

Permalink
code-gen: Allowed WaveGroups be distributed along n-dim for DTVA/Swiz…
Browse files Browse the repository at this point in the history
…zledA

* Allow WaveGroup in N-dim for swizzledA

* directly modify totalElementsCoal/PerpA

* restore assertion and rejection

* Added DTVB & updated pytests

* Fix TLU=True case
  • Loading branch information
solaslin authored Jan 17, 2025
1 parent 0357bb4 commit 38efb62
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 35 deletions.
4 changes: 3 additions & 1 deletion tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3363,7 +3363,9 @@ def lwaTileAssignment(self, kernel, tP):
tmpSgpr = tmpSgprInfo.idx
# Calc numKr, TODO- 32 should be MI_K * 2
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(32)), src=sgpr("SizesSum"), comment="SWZ: numKr = DimK / 32"))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave-id *= numKr"))
WvG_M = kernel["MIWaveGroup"][0]
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]"))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr"))
elif isDTVAB:
# offset calculation for DirectToVgpr
# call function from LraTileAssignmentMFMA for DirectToVgpr
Expand Down
51 changes: 35 additions & 16 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,8 +1458,10 @@ def setGlobalLoadTileDimClassic(state, tc, numLoads, totalVectorsCoalesced, tota
and totalElementsPerp % nlp == 0:
state["NumLoadsCoalesced%s"%tc] = nlc
state["NumLoadsPerpendicular%s"%tc] = nlp
#print("NumLoadsCoalesced",state["NumLoadsCoalesced%s"%tc])
#print("NumLoadsPerpendicular",state["NumLoadsPerpendicular%s"%tc])
# print("NumLoads%s:"%tc,state["NumLoads%s"%tc])
# print("NumLoadsCoalesced%s:"%tc,state["NumLoadsCoalesced%s"%tc])
# print("NumLoadsPerpendicular%s:"%tc,state["NumLoadsPerpendicular%s"%tc])
# print("\n")
foundValid = True
break
if not foundValid:
Expand Down Expand Up @@ -1892,15 +1894,15 @@ def isDirectToVgprDoable(state, tc):
reject(state, "DirectToVgpr%c does not support TLU%c+ numByte >= 4 + MIInputPerThread > 1"%(tc, tc))
return False

# MIWaveGroup, MatrixInstBM,BN check
# for A, MIWaveGroup[1] and MatrixInstBN should be 1
# for B, MIWaveGroup[0] and MatrixInstBM should be 1
# MatrixInstBM,BN check
# for A, MatrixInstBN should be 1
# for B, MatrixInstBM should be 1
# This is to limit the number of Vgpr
if tc == 'A' and not (state['MIWaveGroup'][1] == 1 and state['MatrixInstBN'] == 1):
reject(state, "MIWaveGroup[1] and MatrixInstBN should be 1 for DirectToVgprA. Current value is [%d, %d]"%(state['MIWaveGroup'][1], state['MatrixInstBN']))
if tc == 'A' and not (state['MatrixInstBN'] == 1):
reject(state, "MatrixInstBN should be 1 for DirectToVgprA. Current value is %d"%(state['MatrixInstBN']))
return False
if tc == 'B' and not (state['MIWaveGroup'][0] == 1 and state['MatrixInstBM'] == 1):
reject(state, "MIWaveGroup[0] and MatrixInstBM should be 1 for DirectToVgprB. Current value is [%d, %d]"%(state['MIWaveGroup'][0], state['MatrixInstBM']))
if tc == 'B' and not (state['MatrixInstBM'] == 1):
reject(state, "MatrixInstBM should be 1 for DirectToVgprB. Current value is %d"%(state['MatrixInstBM']))
return False

# Does not work with WaveSeparateGlobalRead
Expand Down Expand Up @@ -1959,7 +1961,7 @@ def isDirectToVgprDoable(state, tc):
if state["PrefetchGlobalRead"] == 0:
reject(state, "DirectToVgpr%c does not supports PrefetchGlobalRead == 0."%(tc))
return False

# for DTVA, does not work with NN and TLDS0
if tc == 'A' and state["TransposeLDS"] == 0 and (not state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]):
reject(state, "DirectToVgpr%c does not supports NN case with TransposeLDS == 0."%(tc))
Expand All @@ -1974,7 +1976,7 @@ def isDirectToVgprDoable(state, tc):
if tc == 'B' and (not state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]):
# Use AssertSummationElementMultiple (BoundSizeMultiple in predicates) to exclude failed tail-loop cases
state["AssertSummationElementMultiple"] = max(state["AssertSummationElementMultiple"], state["DepthU"])

# Does not work with DirectToLDS
# -> this will be checked after DirectToLDS doable check is done

Expand Down Expand Up @@ -2985,19 +2987,27 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int:
validDepthU = True

# how many elements to load
if state["ProblemType"]["TLUA"]:
if state["ProblemType"]["TLUA"]: # NT/NN
totalElementsCoalescedA = state["MacroTileA"]
totalElementsPerpA = depthUA
else:
if state["DirectToVgprA"]:
totalElementsCoalescedA *= state["MIWaveGroup"][1]
else: # TN/TT
totalElementsCoalescedA = depthUA
totalElementsPerpA = state["MacroTileA"]
if state["DirectToVgprA"]:
totalElementsPerpA *= state["MIWaveGroup"][1]

if state["ProblemType"]["TLUB"]:
if state["ProblemType"]["TLUB"]: # NT/TT
totalElementsCoalescedB = state["MacroTileB"]
totalElementsPerpB = depthUB
else:
if state["DirectToVgprB"]:
totalElementsCoalescedB *= state["MIWaveGroup"][0]
else: # TN/NN
totalElementsCoalescedB = depthUB
totalElementsPerpB = state["MacroTileB"]
if state["DirectToVgprB"]:
totalElementsPerpB *= state["MIWaveGroup"][0]

totalElementsA = totalElementsCoalescedA * totalElementsPerpA
totalElementsB = totalElementsCoalescedB * totalElementsPerpB
Expand Down Expand Up @@ -3249,7 +3259,7 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int:
if not Solution.isDirectToVgprDoable(state, 'A'):
return # rejected
if state["DirectToVgprB"]:
if not Solution.isDirectToVgprDoable(state, 'B'):
if not Solution.isDirectToVgprDoable(state, 'B'):
return # rejected

########################################
Expand Down Expand Up @@ -3478,6 +3488,15 @@ def subCheckLdsBlockSizePerPad(tc, idx):
#1LDS buffer must be 0 for DirectToLdsA
state["1LDSBuffer"] = 0

# Re-check DTV + WaveGroup after DTL is confirmed
if state["DirectToLds"]:
if state["DirectToVgprA"] and state['MIWaveGroup'][1] > 1:
reject(state, "DirectToLds + (DirectToVgprA + WaveGroups along N-Dim) is not supported yet")
return False
if state["DirectToVgprB"] and state['MIWaveGroup'][0] > 1:
reject(state, "DirectToLds + (DirectToVgprB + WaveGroups along M-Dim) is not supported yet")
return False

# set NoLdsWriteCode if (DirectToVgpr or DirectToLds)A+B is enabled
state["NoLdsWriteCode"] = False
if (state["DirectToVgprA"] or state["DirectToLdsA"]) and (state["DirectToVgprB"] or state["DirectToLdsB"]):
Expand Down
28 changes: 14 additions & 14 deletions tensilelite/Tensile/Tests/common/gemm/dtl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ GlobalParameters:
NumElementsToValidate: -1
MinimumRequiredVersion: 4.14.0
PrintLevel: 1
PrintSolutionRejectionReason: True
# PrintSolutionRejectionReason: True
Device: 0
CMakeBuildType: Release
KernelTime: True
Expand Down Expand Up @@ -82,7 +82,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ]
Expand All @@ -104,7 +104,7 @@ BenchmarkProblems:
- Exact: [255, 255, 1, 126]
- Exact: [255, 255, 1, 190]
- Exact: [255, 255, 1, 256]

########################################
# HHS NT DTL
########################################
Expand Down Expand Up @@ -170,7 +170,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ]
Expand All @@ -192,7 +192,7 @@ BenchmarkProblems:
- Exact: [255, 255, 1, 127]
- Exact: [255, 255, 1, 191]
- Exact: [255, 255, 1, 256]

########################################
# HHS TN DTL
########################################
Expand Down Expand Up @@ -258,7 +258,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ]
Expand Down Expand Up @@ -345,7 +345,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ]
Expand All @@ -367,7 +367,7 @@ BenchmarkProblems:
- Exact: [255, 255, 1, 126]
- Exact: [255, 255, 1, 190]
- Exact: [255, 255, 1, 256]

########################################
# SGEMM NT DTL
########################################
Expand Down Expand Up @@ -426,7 +426,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 1,4 ]
Expand Down Expand Up @@ -507,7 +507,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 1, 4, 1,4 ]
Expand All @@ -528,7 +528,7 @@ BenchmarkProblems:
- Exact: [255, 255, 1, 111]
- Exact: [255, 255, 1, 127]
- Exact: [255, 255, 1, 128]

########################################
# SGEMM TN DTL
########################################
Expand Down Expand Up @@ -588,7 +588,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [32, 32, 2, 1, 1, 2, 1, 2,2 ]
Expand All @@ -608,7 +608,7 @@ BenchmarkProblems:
- Exact: [255, 255, 1, 111]
- Exact: [255, 255, 1, 127]
- Exact: [255, 255, 1, 128]

########################################
# SGEMM TT DTL
########################################
Expand Down Expand Up @@ -668,7 +668,7 @@ BenchmarkProblems:
- DirectToVgprA: [0,1]
- DirectToVgprB: [0,1]
- Groups:
-
-
- MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ]
- MatrixInstruction: [16, 16, 4, 1, 1, 1, 4, 1,4 ]
Expand Down
23 changes: 22 additions & 1 deletion tensilelite/Tensile/Tests/common/gemm/dtv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ GlobalParameters:
CMakeBuildType: Release
KernelTime: True
MaxWorkspaceSize: 13421772800
DataInitTypeA: 13
DataInitTypeB: 12
DataInitTypeAlpha: 1
DataInitTypeBeta: 1
BoundsCheck: 2
DataInitTypeBias: 13
DataInitTypeScaleAlphaVec: 12
#MaxFileName: 256

BenchmarkProblems:
Expand Down Expand Up @@ -439,6 +442,11 @@ BenchmarkProblems:
- [16, 16, 16, 1, 1, 5, 1, 2,1 ] # MT = 160x16. Case to check kernel writer works
- [32, 32, 8, 1, 1, 2, 1, 4,1 ]
- [32, 32, 8, 1, 1, 4, 1, 4,1 ]

- [16, 16, 16, 1, 1, 8,1, 1,4 ] # 128x64
- [16, 16, 16, 1, 1, 8,1, 2,2 ] # 256x32
- [16, 16, 16, 1, 1, 16,1, 2,2 ] # 512x32
- [16, 16, 16, 1, 1, 10,1, 2,2 ] # 160x32
- WorkGroup:
- [16,16,1]
- GlobalReadVectorWidthA: [2,4,8]
Expand Down Expand Up @@ -613,6 +621,11 @@ BenchmarkProblems:
- [16, 16, 16, 1, 1, 1, 5, 1,2 ] # MT = 16x160. Case to check kernel writer works
- [32, 32, 8, 1, 1, 1, 2, 1,4 ]
- [32, 32, 8, 1, 1, 1, 4, 1,4 ]

- [16, 16, 16, 1, 1, 1,8, 4,1 ] # 64x128
- [16, 16, 16, 1, 1, 1,8, 2,2 ] # 32x256
- [16, 16, 16, 1, 1, 1,16, 2,2 ] # 32x512
- [16, 16, 16, 1, 1, 1,10, 2,2 ] # 32x160
- WorkGroup:
- [16,16,1]
- GlobalReadVectorWidthA: [2,4,8]
Expand Down Expand Up @@ -1985,6 +1998,10 @@ BenchmarkProblems:
- [16, 16, 16, 1, 1, 1,8, 4,1 ] # 64x128
- [16, 16, 16, 1, 1, 2,8, 4,1 ] # 128x128
- [16, 16, 16, 1, 1, 4,16, 4,1 ] # 256x256

- [16, 16, 16, 1, 1, 4,2, 1,4 ] # 64x128
- [16, 16, 16, 1, 1, 8,2, 1,4 ] # 128x128
- [16, 16, 16, 1, 1, 16,4, 1,4 ] # 256x256
- AssertFree0ElementMultiple: [16]
- AssertFree1ElementMultiple: [16]
- AssertSummationElementMultiple: [32]
Expand Down Expand Up @@ -2050,6 +2067,10 @@ BenchmarkProblems:
- [16, 16, 16, 1, 1, 4,2, 1,4 ] # 64x128
- [16, 16, 16, 1, 1, 8,2, 1,4 ] # 128x128
- [16, 16, 16, 1, 1, 16,4, 1,4 ] # 256x256

- [16, 16, 16, 1, 1, 2,4, 4,1 ] # 128x64
- [16, 16, 16, 1, 1, 2,8, 4,1 ] # 128x128
- [16, 16, 16, 1, 1, 4,16, 4,1 ] # 256x256
- AssertFree0ElementMultiple: [16]
- AssertFree1ElementMultiple: [16]
- AssertSummationElementMultiple: [32]
Expand Down
27 changes: 24 additions & 3 deletions tensilelite/Tensile/Tests/common/gemm/dtvA_swizzleA.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ GlobalParameters:
KernelTime: True
MaxWorkspaceSize: 13421772800
DataInitTypeA: 13
DataInitTypeB: 13
DataInitTypeB: 12
DataInitTypeAlpha: 1
DataInitTypeBeta: 1
DataInitTypeBias: 13
DataInitTypeScaleAlphaVec: 12
BoundsCheck: 2
#MaxFileName: 256

BenchmarkProblems:
########################################
# HHS TN DTVA + SWIZZLED_A + BIAS + Activation
# HHS TN DTVA + SWIZZLED_A + BIAS + Activation + SAV
########################################
-
- # ProblemType
Expand All @@ -36,6 +38,7 @@ BenchmarkProblems:
UseBias: 1
Activation: True
BiasDataTypeList: ['h']
UseScaleAlphaVec: 1
- # BenchmarkProblemSizeGroup - Standard - All problem
InitialSolutionParameters:
BenchmarkCommonParameters:
Expand All @@ -57,6 +60,24 @@ BenchmarkProblems:
- [16, 16, 16, 1, 1, 8, 8, 4,1 ] # MT = 512x128
- [16, 16, 16, 1, 1, 8, 16, 4,1 ] # MT = 512x256
- [16, 16, 16, 1, 1, 5, 8, 2,1 ] # MT = 160x128

- [16, 16, 16, 1, 1, 2, 4, 2,2 ] # MT = 64x128
- [16, 16, 16, 1, 1, 2, 8, 2,2 ] # MT = 64x256
- [16, 16, 16, 1, 1, 4, 4, 2,2 ] # MT = 128x128
- [16, 16, 16, 1, 1, 4, 8, 2,2 ] # MT = 128x256
- [16, 16, 16, 1, 1, 8, 4, 2,2 ] # MT = 256x128
- [16, 16, 16, 1, 1, 8, 8, 2,2 ] # MT = 256x256
- [16, 16, 16, 1, 1, 16, 4, 2,2 ] # MT = 512x128
- [16, 16, 16, 1, 1, 16, 8, 2,2 ] # MT = 512x256
- [16, 16, 16, 1, 1, 5, 4, 2,2 ] # MT = 160x128

- [16, 16, 16, 1, 1, 4, 2, 1,4 ] # MT = 64x128
- [16, 16, 16, 1, 1, 4, 4, 1,4 ] # MT = 64x256
- [16, 16, 16, 1, 1, 8, 2, 1,4 ] # MT = 128x128
- [16, 16, 16, 1, 1, 8, 4, 1,4 ] # MT = 128x256
- [16, 16, 16, 1, 1, 16, 2, 1,4 ] # MT = 256x128
- [16, 16, 16, 1, 1, 16, 4, 1,4 ] # MT = 256x256
- [16, 16, 16, 1, 1, 10, 2, 1,4 ] # MT = 160x128
- AssertFree0ElementMultiple: [16]
- AssertSummationElementMultiple: [32]
- GlobalReadVectorWidthA: [8]
Expand Down Expand Up @@ -102,4 +123,4 @@ BenchmarkProblems:
- BiasTypeArgs: ['h']
- ActivationArgs:
- [Enum: none]
- [Enum: relu]
- [Enum: relu]

0 comments on commit 38efb62

Please sign in to comment.