diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 4f9188de54..f616b16466 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -3352,7 +3352,9 @@ def lwaTileAssignment(self, kernel, tP): tmpSgpr = tmpSgprInfo.idx # Calc numKr, TODO- 32 should be MI_K * 2 module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(32)), src=sgpr("SizesSum"), comment="SWZ: numKr = DimK / 32")) - module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave-id *= numKr")) + WvG_M = kernel["MIWaveGroup"][0] + module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]")) + module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr")) elif isDTVAB: # offset calculation for DirectToVgpr # call function from LraTileAssignmentMFMA for DirectToVgpr diff --git a/tensilelite/Tensile/SolutionStructs.py b/tensilelite/Tensile/SolutionStructs.py index fb8ad4e7b9..2d5ec86fdc 100644 --- a/tensilelite/Tensile/SolutionStructs.py +++ b/tensilelite/Tensile/SolutionStructs.py @@ -1458,8 +1458,10 @@ def setGlobalLoadTileDimClassic(state, tc, numLoads, totalVectorsCoalesced, tota and totalElementsPerp % nlp == 0: state["NumLoadsCoalesced%s"%tc] = nlc state["NumLoadsPerpendicular%s"%tc] = nlp - #print("NumLoadsCoalesced",state["NumLoadsCoalesced%s"%tc]) - #print("NumLoadsPerpendicular",state["NumLoadsPerpendicular%s"%tc]) + # print("NumLoads%s:"%tc,state["NumLoads%s"%tc]) + # print("NumLoadsCoalesced%s:"%tc,state["NumLoadsCoalesced%s"%tc]) + # print("NumLoadsPerpendicular%s:"%tc,state["NumLoadsPerpendicular%s"%tc]) + # print("\n") foundValid = True break if not foundValid: @@ -1892,15 +1894,15 @@ def isDirectToVgprDoable(state, tc): reject(state, "DirectToVgpr%c does not support TLU%c+ numByte >= 4 + MIInputPerThread > 1"%(tc, tc)) return False - # MIWaveGroup, MatrixInstBM,BN check - # for A, MIWaveGroup[1] and MatrixInstBN should be 1 - # for B, MIWaveGroup[0] and MatrixInstBM should be 1 + # MatrixInstBM,BN check + # for A, MatrixInstBN should be 1 + # for B, MatrixInstBM should be 1 # This is to limit the number of Vgpr - if tc == 'A' and not (state['MIWaveGroup'][1] == 1 and state['MatrixInstBN'] == 1): - reject(state, "MIWaveGroup[1] and MatrixInstBN should be 1 for DirectToVgprA. Current value is [%d, %d]"%(state['MIWaveGroup'][1], state['MatrixInstBN'])) + if tc == 'A' and not (state['MatrixInstBN'] == 1): + reject(state, "MatrixInstBN should be 1 for DirectToVgprA. Current value is %d"%(state['MatrixInstBN'])) return False - if tc == 'B' and not (state['MIWaveGroup'][0] == 1 and state['MatrixInstBM'] == 1): - reject(state, "MIWaveGroup[0] and MatrixInstBM should be 1 for DirectToVgprB. Current value is [%d, %d]"%(state['MIWaveGroup'][0], state['MatrixInstBM'])) + if tc == 'B' and not (state['MatrixInstBM'] == 1): + reject(state, "MatrixInstBM should be 1 for DirectToVgprB. Current value is %d"%(state['MatrixInstBM'])) return False # Does not work with WaveSeparateGlobalRead @@ -1959,7 +1961,7 @@ def isDirectToVgprDoable(state, tc): if state["PrefetchGlobalRead"] == 0: reject(state, "DirectToVgpr%c does not supports PrefetchGlobalRead == 0."%(tc)) return False - + # for DTVA, does not work with NN and TLDS0 if tc == 'A' and state["TransposeLDS"] == 0 and (not state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]): reject(state, "DirectToVgpr%c does not supports NN case with TransposeLDS == 0."%(tc)) @@ -1974,7 +1976,7 @@ def isDirectToVgprDoable(state, tc): if tc == 'B' and (not state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]): # Use AssertSummationElementMultiple (BoundSizeMultiple in predicates) to exclude failed tail-loop cases state["AssertSummationElementMultiple"] = max(state["AssertSummationElementMultiple"], state["DepthU"]) - + # Does not work with DirectToLDS # -> this will be checked after DirectToLDS doable check is done @@ -2985,19 +2987,27 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int: validDepthU = True # how many elements to load - if state["ProblemType"]["TLUA"]: + if state["ProblemType"]["TLUA"]: # NT/NN totalElementsCoalescedA = state["MacroTileA"] totalElementsPerpA = depthUA - else: + if state["DirectToVgprA"]: + totalElementsCoalescedA *= state["MIWaveGroup"][1] + else: # TN/TT totalElementsCoalescedA = depthUA totalElementsPerpA = state["MacroTileA"] + if state["DirectToVgprA"]: + totalElementsPerpA *= state["MIWaveGroup"][1] - if state["ProblemType"]["TLUB"]: + if state["ProblemType"]["TLUB"]: # NT/TT totalElementsCoalescedB = state["MacroTileB"] totalElementsPerpB = depthUB - else: + if state["DirectToVgprB"]: + totalElementsCoalescedB *= state["MIWaveGroup"][0] + else: # TN/NN totalElementsCoalescedB = depthUB totalElementsPerpB = state["MacroTileB"] + if state["DirectToVgprB"]: + totalElementsPerpB *= state["MIWaveGroup"][0] totalElementsA = totalElementsCoalescedA * totalElementsPerpA totalElementsB = totalElementsCoalescedB * totalElementsPerpB @@ -3249,7 +3259,7 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int: if not Solution.isDirectToVgprDoable(state, 'A'): return # rejected if state["DirectToVgprB"]: - if not Solution.isDirectToVgprDoable(state, 'B'): + if not Solution.isDirectToVgprDoable(state, 'B'): return # rejected ######################################## @@ -3478,6 +3488,15 @@ def subCheckLdsBlockSizePerPad(tc, idx): #1LDS buffer must be 0 for DirectToLdsA state["1LDSBuffer"] = 0 + # Re-check DTV + WaveGroup after DTL is confirmed + if state["DirectToLds"]: + if state["DirectToVgprA"] and state['MIWaveGroup'][1] > 1: + reject(state, "DirectToLds + (DirectToVgprA + WaveGroups along N-Dim) is not supported yet") + return False + if state["DirectToVgprB"] and state['MIWaveGroup'][0] > 1: + reject(state, "DirectToLds + (DirectToVgprB + WaveGroups along M-Dim) is not supported yet") + return False + # set NoLdsWriteCode if (DirectToVgpr or DirectToLds)A+B is enabled state["NoLdsWriteCode"] = False if (state["DirectToVgprA"] or state["DirectToLdsA"]) and (state["DirectToVgprB"] or state["DirectToLdsB"]): diff --git a/tensilelite/Tensile/Tests/common/gemm/dtl.yaml b/tensilelite/Tensile/Tests/common/gemm/dtl.yaml index 0eb027d21d..69f9486f38 100644 --- a/tensilelite/Tensile/Tests/common/gemm/dtl.yaml +++ b/tensilelite/Tensile/Tests/common/gemm/dtl.yaml @@ -5,7 +5,7 @@ GlobalParameters: NumElementsToValidate: -1 MinimumRequiredVersion: 4.14.0 PrintLevel: 1 - PrintSolutionRejectionReason: True + # PrintSolutionRejectionReason: True Device: 0 CMakeBuildType: Release KernelTime: True @@ -82,7 +82,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ] - MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ] @@ -104,7 +104,7 @@ BenchmarkProblems: - Exact: [255, 255, 1, 126] - Exact: [255, 255, 1, 190] - Exact: [255, 255, 1, 256] - + ######################################## # HHS NT DTL ######################################## @@ -170,7 +170,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ] - MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ] @@ -192,7 +192,7 @@ BenchmarkProblems: - Exact: [255, 255, 1, 127] - Exact: [255, 255, 1, 191] - Exact: [255, 255, 1, 256] - + ######################################## # HHS TN DTL ######################################## @@ -258,7 +258,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ] - MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ] @@ -345,7 +345,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 16, 1, 1, 2, 2, 2,2 ] - MatrixInstruction: [16, 16, 16, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 16, 1, 1, 1, 4, 1,4 ] @@ -367,7 +367,7 @@ BenchmarkProblems: - Exact: [255, 255, 1, 126] - Exact: [255, 255, 1, 190] - Exact: [255, 255, 1, 256] - + ######################################## # SGEMM NT DTL ######################################## @@ -426,7 +426,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ] - MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 1,4 ] @@ -507,7 +507,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ] - MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 4, 1, 1, 1, 4, 1,4 ] @@ -528,7 +528,7 @@ BenchmarkProblems: - Exact: [255, 255, 1, 111] - Exact: [255, 255, 1, 127] - Exact: [255, 255, 1, 128] - + ######################################## # SGEMM TN DTL ######################################## @@ -588,7 +588,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ] - MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [32, 32, 2, 1, 1, 2, 1, 2,2 ] @@ -608,7 +608,7 @@ BenchmarkProblems: - Exact: [255, 255, 1, 111] - Exact: [255, 255, 1, 127] - Exact: [255, 255, 1, 128] - + ######################################## # SGEMM TT DTL ######################################## @@ -668,7 +668,7 @@ BenchmarkProblems: - DirectToVgprA: [0,1] - DirectToVgprB: [0,1] - Groups: - - + - - MatrixInstruction: [16, 16, 4, 1, 1, 2, 1, 2,2 ] - MatrixInstruction: [16, 16, 4, 1, 1, 4, 1, 4,1 ] - MatrixInstruction: [16, 16, 4, 1, 1, 1, 4, 1,4 ] diff --git a/tensilelite/Tensile/Tests/common/gemm/dtv.yaml b/tensilelite/Tensile/Tests/common/gemm/dtv.yaml index 65960f948a..1c0a729edb 100644 --- a/tensilelite/Tensile/Tests/common/gemm/dtv.yaml +++ b/tensilelite/Tensile/Tests/common/gemm/dtv.yaml @@ -10,9 +10,12 @@ GlobalParameters: CMakeBuildType: Release KernelTime: True MaxWorkspaceSize: 13421772800 + DataInitTypeA: 13 + DataInitTypeB: 12 DataInitTypeAlpha: 1 DataInitTypeBeta: 1 - BoundsCheck: 2 + DataInitTypeBias: 13 + DataInitTypeScaleAlphaVec: 12 #MaxFileName: 256 BenchmarkProblems: @@ -439,6 +442,11 @@ BenchmarkProblems: - [16, 16, 16, 1, 1, 5, 1, 2,1 ] # MT = 160x16. Case to check kernel writer works - [32, 32, 8, 1, 1, 2, 1, 4,1 ] - [32, 32, 8, 1, 1, 4, 1, 4,1 ] + + - [16, 16, 16, 1, 1, 8,1, 1,4 ] # 128x64 + - [16, 16, 16, 1, 1, 8,1, 2,2 ] # 256x32 + - [16, 16, 16, 1, 1, 16,1, 2,2 ] # 512x32 + - [16, 16, 16, 1, 1, 10,1, 2,2 ] # 160x32 - WorkGroup: - [16,16,1] - GlobalReadVectorWidthA: [2,4,8] @@ -613,6 +621,11 @@ BenchmarkProblems: - [16, 16, 16, 1, 1, 1, 5, 1,2 ] # MT = 16x160. Case to check kernel writer works - [32, 32, 8, 1, 1, 1, 2, 1,4 ] - [32, 32, 8, 1, 1, 1, 4, 1,4 ] + + - [16, 16, 16, 1, 1, 1,8, 4,1 ] # 64x128 + - [16, 16, 16, 1, 1, 1,8, 2,2 ] # 32x256 + - [16, 16, 16, 1, 1, 1,16, 2,2 ] # 32x512 + - [16, 16, 16, 1, 1, 1,10, 2,2 ] # 32x160 - WorkGroup: - [16,16,1] - GlobalReadVectorWidthA: [2,4,8] @@ -1985,6 +1998,10 @@ BenchmarkProblems: - [16, 16, 16, 1, 1, 1,8, 4,1 ] # 64x128 - [16, 16, 16, 1, 1, 2,8, 4,1 ] # 128x128 - [16, 16, 16, 1, 1, 4,16, 4,1 ] # 256x256 + + - [16, 16, 16, 1, 1, 4,2, 1,4 ] # 64x128 + - [16, 16, 16, 1, 1, 8,2, 1,4 ] # 128x128 + - [16, 16, 16, 1, 1, 16,4, 1,4 ] # 256x256 - AssertFree0ElementMultiple: [16] - AssertFree1ElementMultiple: [16] - AssertSummationElementMultiple: [32] @@ -2050,6 +2067,10 @@ BenchmarkProblems: - [16, 16, 16, 1, 1, 4,2, 1,4 ] # 64x128 - [16, 16, 16, 1, 1, 8,2, 1,4 ] # 128x128 - [16, 16, 16, 1, 1, 16,4, 1,4 ] # 256x256 + + - [16, 16, 16, 1, 1, 2,4, 4,1 ] # 128x64 + - [16, 16, 16, 1, 1, 2,8, 4,1 ] # 128x128 + - [16, 16, 16, 1, 1, 4,16, 4,1 ] # 256x256 - AssertFree0ElementMultiple: [16] - AssertFree1ElementMultiple: [16] - AssertSummationElementMultiple: [32] diff --git a/tensilelite/Tensile/Tests/common/gemm/dtvA_swizzleA.yaml b/tensilelite/Tensile/Tests/common/gemm/dtvA_swizzleA.yaml index 3b77fa22bc..ff085984b7 100644 --- a/tensilelite/Tensile/Tests/common/gemm/dtvA_swizzleA.yaml +++ b/tensilelite/Tensile/Tests/common/gemm/dtvA_swizzleA.yaml @@ -11,15 +11,17 @@ GlobalParameters: KernelTime: True MaxWorkspaceSize: 13421772800 DataInitTypeA: 13 - DataInitTypeB: 13 + DataInitTypeB: 12 DataInitTypeAlpha: 1 DataInitTypeBeta: 1 + DataInitTypeBias: 13 + DataInitTypeScaleAlphaVec: 12 BoundsCheck: 2 #MaxFileName: 256 BenchmarkProblems: ######################################## - # HHS TN DTVA + SWIZZLED_A + BIAS + Activation + # HHS TN DTVA + SWIZZLED_A + BIAS + Activation + SAV ######################################## - - # ProblemType @@ -36,6 +38,7 @@ BenchmarkProblems: UseBias: 1 Activation: True BiasDataTypeList: ['h'] + UseScaleAlphaVec: 1 - # BenchmarkProblemSizeGroup - Standard - All problem InitialSolutionParameters: BenchmarkCommonParameters: @@ -57,6 +60,24 @@ BenchmarkProblems: - [16, 16, 16, 1, 1, 8, 8, 4,1 ] # MT = 512x128 - [16, 16, 16, 1, 1, 8, 16, 4,1 ] # MT = 512x256 - [16, 16, 16, 1, 1, 5, 8, 2,1 ] # MT = 160x128 + + - [16, 16, 16, 1, 1, 2, 4, 2,2 ] # MT = 64x128 + - [16, 16, 16, 1, 1, 2, 8, 2,2 ] # MT = 64x256 + - [16, 16, 16, 1, 1, 4, 4, 2,2 ] # MT = 128x128 + - [16, 16, 16, 1, 1, 4, 8, 2,2 ] # MT = 128x256 + - [16, 16, 16, 1, 1, 8, 4, 2,2 ] # MT = 256x128 + - [16, 16, 16, 1, 1, 8, 8, 2,2 ] # MT = 256x256 + - [16, 16, 16, 1, 1, 16, 4, 2,2 ] # MT = 512x128 + - [16, 16, 16, 1, 1, 16, 8, 2,2 ] # MT = 512x256 + - [16, 16, 16, 1, 1, 5, 4, 2,2 ] # MT = 160x128 + + - [16, 16, 16, 1, 1, 4, 2, 1,4 ] # MT = 64x128 + - [16, 16, 16, 1, 1, 4, 4, 1,4 ] # MT = 64x256 + - [16, 16, 16, 1, 1, 8, 2, 1,4 ] # MT = 128x128 + - [16, 16, 16, 1, 1, 8, 4, 1,4 ] # MT = 128x256 + - [16, 16, 16, 1, 1, 16, 2, 1,4 ] # MT = 256x128 + - [16, 16, 16, 1, 1, 16, 4, 1,4 ] # MT = 256x256 + - [16, 16, 16, 1, 1, 10, 2, 1,4 ] # MT = 160x128 - AssertFree0ElementMultiple: [16] - AssertSummationElementMultiple: [32] - GlobalReadVectorWidthA: [8] @@ -102,4 +123,4 @@ BenchmarkProblems: - BiasTypeArgs: ['h'] - ActivationArgs: - [Enum: none] - - [Enum: relu] + - [Enum: relu] \ No newline at end of file