Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: DTVB with Swizzling (tensorB) #1562

Merged
merged 4 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions tensilelite/Tensile/Components/GSU.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,12 @@ def computeLoadSrd(self, writer, kernel, tP, stmp, tileStart):
tc = tP["tensorChar"]
depthU = kernel["DepthU"]
depthUDiv = kernel["DepthU"]
# Swizzled for A, TODO- check for SwizzleTensorB
depthUDiv = "%s%s"%(kernel["DepthU"], "*MI_M") if (tP["isSwizzled"] and tc == 'A') else "%s"%kernel["DepthU"]
#
# swizzle
if (tP["isSwizzled"] and tc == 'A'):
depthUDiv = "%s%s"%(kernel["DepthU"], "*MI_M")
elif (tP["isSwizzled"] and tc == 'B'):
depthUDiv = "%s%s"%(kernel["DepthU"], "*MI_N")

gsuOffsetStr = "gsuOffset = DepthU*bpeGR*GSUSumIdx"
Comment on lines 216 to 223
Copy link
Contributor Author

@solaslin solaslin Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, all the renaming from MI_M to (MI_MN or MI_MorN) is done so that the code covers both A and B.

divider = 1
if kernel["ProblemType"]["Sparse"]:
Expand Down Expand Up @@ -271,11 +274,15 @@ def graIncrements(self, writer, kernel, loopIdx, tP):

tcGR = tc if tc == "Metadata" else (tc + "GR")

# DVTA + SwA
mult_MI_M = "*MI_M" if tc == "A" and kernel["ProblemType"]["SwizzleTensorA"] else ""
# swizzle
mult_MI_Dim = ""
if tc == "A" and kernel["ProblemType"]["SwizzleTensorA"]:
mult_MI_Dim = "*MI_M"
elif tc == "B" and kernel["ProblemType"]["SwizzleTensorB"]:
mult_MI_Dim = "*MI_N"

module.add(SAndB32(dst=sgpr(gsuSgpr), src0=sgpr("GSU"), src1=hex(0x3FFF), comment="Restore GSU"))
module.add(SMulI32(dst=sgpr(gsuSgpr), src0=sgpr(gsuSgpr), src1="DepthU*Bpe%s%s"%(tcGR, mult_MI_M), comment="GSU*DepthU*Bpe%s"%(mult_MI_M)))
module.add(SMulI32(dst=sgpr(gsuSgpr), src0=sgpr(gsuSgpr), src1="DepthU*Bpe%s%s"%(tcGR, mult_MI_Dim), comment="GSU*DepthU*Bpe%s"%(mult_MI_Dim)))
module.add(SAndB32(dst=sgpr(tmpSgpr), src0=sgpr("GSU"), src1=hex(0x8000), comment="SCC = (GSUC == 1) ?"))

m = sgpr(gsuSgpr)
Expand All @@ -284,7 +291,7 @@ def graIncrements(self, writer, kernel, loopIdx, tP):
m.setMinus(True)

incr = sgpr("GlobalReadIncs%s+%u"%(tc, loopIdx))
duBpe = "DepthU*Bpe%s%s"%(tcGR, mult_MI_M)
duBpe = "DepthU*Bpe%s%s"%(tcGR, mult_MI_Dim)
# multiply by stride, optimizing if unit stride
if writer.isConstUnitStride(stride):
module.add(SCSelectB32(dst=incr, src0=duBpe, src1=m, comment="incr%s (unrollIdx)"%(tc)))
Expand Down
78 changes: 54 additions & 24 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,9 +814,11 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module:
module.add(ValueSet("BpeAGRLog2", log2(tPA["bpeGR"])))
module.add(ValueSet("BpeBGR", tPB["bpeGR"]))
module.add(ValueSet("BpeBGRLog2", log2(tPB["bpeGR"])))
# TODO- get real value from MI-InstM
# TODO- get real value from MI-InstM/N
if kernel["ProblemType"]["SwizzleTensorA"]:
module.add(ValueSet("MI_M", 16))
if kernel["ProblemType"]["SwizzleTensorB"]:
module.add(ValueSet("MI_N", 16))
#
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
module.add(ValueSet("BpeMetadata", tPM["bpe"]))
Expand Down Expand Up @@ -853,6 +855,8 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module:
GOList.append(("Metadata", kernel["ProblemType"]["IndexAssignmentsMetadata"], kernel["BufferLoad"], tPM, False))
if kernel["ProblemType"]["SwizzleTensorA"]:
GOList.append(("A", kernel["ProblemType"]["IndexAssignmentsA"], kernel["BufferLoad"], tPA, True))
if kernel["ProblemType"]["SwizzleTensorB"]:
GOList.append(("B", kernel["ProblemType"]["IndexAssignmentsB"], kernel["BufferLoad"], tPB, True))

for (tc, indices, justOffset32, tP, isSwizzled) in GOList:

Expand Down Expand Up @@ -1016,7 +1020,7 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module:
macro.add(VMulLOU32(dst=destLo,
src0="s[\\sgprStrideSW%s]" % tc, \
src1=offset, \
comment="SWZ: mul d%u lower"%i))
comment="SWZ-%s: mul d%u lower"%(tc, i)))
else:
macro.add(VMulLOU32(dst=destLo,
src0=self.strideRef(tc, indices[i]), \
Expand Down Expand Up @@ -2314,16 +2318,21 @@ def graTileOffsets(self, kernel, tP, margin=-1):
module.add(VLShiftLeftB32(dst=vgpr(v), shiftHex=hex(log2(margin)), src=vgpr(v), comment="gro%s%s_%u *= %d"%(tP["tensorChar"], tP["tileChar"], 0, margin)))
else:
module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["tReg"]), comment="gro%s%s_%u"%(tP["tensorChar"], tP["tileChar"], 0) ))
# TODO- current for A only
# swizzle
if tP["isSwizzled"]:
swizzleStridePerWave = self.sgprPool.checkOut(1)
WvG_M = kernel["MIWaveGroup"][0]
# Calc stride = numKr * WaveG[0]
# Calc stride = numKr * WaveG
swizzleStrideVal = tP["swizzleK"]
module.addComment(f"Align to {swizzleStrideVal}")
module.add(SAddU32(sgpr(swizzleStridePerWave), sgpr("SizesSum"), swizzleStrideVal-1))
module.add(SLShiftRightB32(dst=sgpr(swizzleStridePerWave), src=sgpr(swizzleStridePerWave), shiftHex=hex(log2(swizzleStrideVal))))
module.add(SMulI32(dst=sgpr(swizzleStridePerWave), src0=hex(WvG_M), src1=sgpr(swizzleStridePerWave), comment="SWZ: numKr *= MI_WvG[0] (%u), how many wave-M per WG" % WvG_M))
module.add(SLShiftRightB32(dst=sgpr(swizzleStridePerWave), src=sgpr(swizzleStridePerWave), shiftHex=hex(log2(swizzleStrideVal)), comment="SWZ-%s: numKr = DimK / 32"%tc))
if tP["isA"]:
WvG_MorN = kernel["MIWaveGroup"][0]
commentMsg = "SWZ-%s: numKr *= MI_WvG[0] (%u), how many wave-M per WG" %(tc, WvG_MorN)
elif tP["isB"]:
WvG_MorN = kernel["MIWaveGroup"][1]
commentMsg = "SWZ-%s: numKr *= MI_WvG[1] (%u), how many wave-N per WG" %(tc, WvG_MorN)
module.add(SMulI32(dst=sgpr(swizzleStridePerWave), src0=hex(WvG_MorN), src1=sgpr(swizzleStridePerWave), comment=commentMsg))

for l in range(1, tP["nrt"]):
strideValue = stride
Expand All @@ -2335,7 +2344,7 @@ def graTileOffsets(self, kernel, tP, margin=-1):
# swizzle
else:
module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=sgpr(swizzleStridePerWave), \
src1=vgpr(v+l-1), comment="SWZ: gro%s%s_%u += (numKr * WaveG[0])"%(tP["tensorChar"], tP["tileChar"], l) ))
src1=vgpr(v+l-1), comment="SWZ-%s: gro%s%s_%u += (numKr * WaveG[0 or 1])"%(tc, tc, tP["tileChar"], l) ))
if tP["isSwizzled"]:
self.sgprPool.checkIn(swizzleStridePerWave)

Expand Down Expand Up @@ -2377,13 +2386,13 @@ def graUnrollOffsets(self, kernel, tP):
tc = tP["tensorChar"]
if kernel["_UseSgprForGRO"]:
tP["gpr"]["unrollOffsets"] = tP["gpr"]["uReg"]
# TODO- so far only A, check B
# swizzle
elif tP["isSwizzled"]:
numUnrollOffsets = 1
tP["gpr"]["unrollOffsets"] = self.vgprPool.checkOut(numUnrollOffsets, "unrollOffsets", self.states.preventVgprOverflowDuringNewTile)
v = tP["gpr"]["unrollOffsets"]
module.addComment0("SWZ: Unroll increament is calculated in I-Offset since tile-memory is flattened")
module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["uReg"]), comment="SWZ: only one gro%s%s_%u Base"%(tP["tensorChar"], self.states.unrollChar, 0)))
module.addComment0("SWZ-%s: Unroll increament is calculated in I/J-Offset since tile-memory is flattened"%tc)
module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["uReg"]), comment="SWZ-%s: only one gro%s%s_%u Base"%(tc, tc, self.states.unrollChar, 0)))
else:
numUnrollOffsets = tP["nru"]
tP["gpr"]["unrollOffsets"] = self.vgprPool.checkOut(numUnrollOffsets, "unrollOffsets", self.states.preventVgprOverflowDuringNewTile)
Expand Down Expand Up @@ -2581,10 +2590,15 @@ def graFinalOffsets(self, kernel, tP):
graIdx = 0
swapPerpPara = (((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) and (not tP["tlu"]) and tP["nrp"] > 1)

# swizzle
if tP["isSwizzled"]:
swizzleK = tP["swizzleK"]
tP["swizzledBlockSize"] = self.sgprPool.checkOut(1)
module.add(SMovB32(dst=sgpr(tP["swizzledBlockSize"]), src=hex(16*swizzleK), comment="SWZ: swizzled block = MI_M(%u) * MI_K(%u) * pack-K(%u)" %(16, kernel["MatrixInstK"], tP["swizzlePackK"])))
if tc == "A":
commentMsg = "SWZ-%s: swizzled block = MI_M(%u) * MI_K(%u) * pack-K(%u)" %(tc, 16, kernel["MatrixInstK"], tP["swizzlePackK"])
elif tc == "B":
commentMsg = "SWZ-%s: swizzled block = MI_N(%u) * MI_K(%u) * pack-K(%u)" %(tc, 16, kernel["MatrixInstK"], tP["swizzlePackK"])
module.add(SMovB32(dst=sgpr(tP["swizzledBlockSize"]), src=hex(16*swizzleK), comment=commentMsg))

# both UseSgprForGRO and DTVA/B are enabled
if ((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) and kernel["_UseSgprForGRO"]:
Expand Down Expand Up @@ -2727,12 +2741,16 @@ def graFinalOffsetsSingleLoop(self, kernel, tP, tc, tmp, graIdx, perp, sPerp, pa

# Swizzled version:
# 1. passing swizzled block size as stride to macro
# 2. If moving unroll = advance I-Offset for one swizzled block
# 2. If moving unroll = advance I/J-Offset (A/B) for one swizzled block
if tP["isSwizzled"]:
if tc == "A":
commentMsg = "SWZ-%s: groAL = groA0I + 1 Block of flattened mem"%tc
elif tc == "B":
commentMsg = "SWZ-%s: groBL = groB0J + 1 Block of flattened mem"%tc

bfArgs.append(tP["swizzledBlockSize"])
if para > 0:
module.add(VAddCOU32(dst=vgpr(vgprTile), dst1=VCC(), src0=hex(1), src1=vgpr(vgprTile), \
comment="SWZ: groAL = groA0I + 1 Block of flattened mem"))
module.add(VAddCOU32(dst=vgpr(vgprTile), dst1=VCC(), src0=hex(1), src1=vgpr(vgprTile), comment=commentMsg))

bfArgs.append( "%u" % tmp )
bfComment = "gRO%s_%u_%u_%u_%u" % (tP["tensorChar"], para, sPara, perp, sPerp)
Expand Down Expand Up @@ -3036,13 +3054,16 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe):
module.add(SLShiftRightB32(dst=sgpr(stmp), src=size, shiftHex=0x1, comment="(size/2)"))
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=0x1, comment="(size/2-1)"))
else:
if tP["isA"] and tP["isSwizzled"]:
if tP["isSwizzled"]:
if idx in kernel["ProblemType"]["IndicesSummation"]:
module.addModuleAsFlatItems(self.alignTo(stmp, "SizeL", tP["swizzleK"]))
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="(size-1)"))
elif idx == kernel["ProblemType"]["Index0"]:
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (size-1)"%tc))
elif tP["isA"] and idx == kernel["ProblemType"]["Index0"]:
module.addModuleAsFlatItems(self.alignTo(stmp, "SizeI", 16))
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="(size-1)"))
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (size-1)"%tc))
elif tP["isB"] and idx == kernel["ProblemType"]["Index1"]:
module.addModuleAsFlatItems(self.alignTo(stmp, "SizeJ", 16))
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (size-1)"%tc))
else:
module.add(SSubU32(dst=sgpr(stmp), src0=size, src1=0x1, comment="(size-1)"))
else:
Expand Down Expand Up @@ -3145,8 +3166,10 @@ def graAddresses(self, kernel, tP):
tc = tP["tensorChar"]
graIdx = 0

if tP["isA"] and tP["isSwizzled"]:
module.addModuleAsFlatItems(self.alignTo("StrideA0I", "StrideA0I", tP["swizzleK"]))
if tP["isSwizzled"]:
# "StrideA0I" or "StrideB1J"
strideName = "Stride%s%s"%(tc,self.states.indexChars[tP["idx"]])
module.addModuleAsFlatItems(self.alignTo(strideName, strideName, tP["swizzleK"]))

Comment on lines 3168 to 3173
Copy link
Contributor Author

@solaslin solaslin Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strideName = "Stride%s%s"%(tc,self.states.indexChars[tP["idx"]])
This evaluates to "StrideA0I" or "StrideB1J".
Thanks to Jimmy for providing this info.

if kernel["BufferLoad"]:
# maxAddrSgpr = size[n] * stride[n-1]
Expand Down Expand Up @@ -3379,6 +3402,7 @@ def lwaTileAssignment(self, kernel, tP):
# store DirectToVgpr K interval for later use
dtvKInterval = 1

# swizzle
if tP["isSwizzled"]:
module.addComment0("TileAssignment for DirectToVgpr%s and SwizzleTensor%s" % (tc, tc))
module.add(vectorStaticDivideAndRemainder(qReg, rReg, dividendReg, kernel["WavefrontSize"], tmpVgprRes))
Expand All @@ -3388,10 +3412,16 @@ def lwaTileAssignment(self, kernel, tP):
swizzleStrideVal = tP["swizzleK"]
module.addComment(f"Align to {swizzleStrideVal}")
module.add(SAddU32(sgpr(tmpSgpr), sgpr("SizesSum"), swizzleStrideVal-1))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(swizzleStrideVal)), src=sgpr(tmpSgpr), comment="SWZ: numKr = DimK / %s"%swizzleStrideVal))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(swizzleStrideVal)), src=sgpr(tmpSgpr), comment="SWZ-%s: numKr = DimK / %s"%(tc, swizzleStrideVal)))
WvG_M = kernel["MIWaveGroup"][0]
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]"))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr"))
if tP["isA"]:
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) mod MIWG[0]"%tc))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) *= numKr"%tc))
elif tP["isB"]:
# NB:
# Calc of w_id is: /= MIWG[0], not %= MIWG[1]
module.add(VLShiftRightB32(dst=vgpr(qReg), shiftHex=log2(WvG_M), src=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) /= MIWG[0]"%tc))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) *= numKr"%tc))
elif isDTVAB:
Comment on lines 3414 to 3425
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly the most important note:
For swizzling A, the wave order is wave_id = wave_id % MI_WaveG[0].
For swizzling B, the wave order is wave_id = wave_id / MI_WaveG[0] (not wave_id % MI_WaveG[1]).

# offset calculation for DirectToVgpr
# call function from LraTileAssignmentMFMA for DirectToVgpr
Expand Down
5 changes: 2 additions & 3 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,9 +2874,8 @@ def calSwizzleK(state, tc):
if state["ProblemType"]["SwizzleTensorB"]:
if not state["DirectToVgprB"]:
reject(state, f"Tensor B swizzling requires DirectToVgprB")
# TODO- NN fails validation due to DTVB + Tail-Loop is not working correctly
if not (state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]):
reject(state, f"Tensor B swizzling supports TN only")
if state["ProblemType"]["TransposeB"]:
reject(state, f"Tensor B swizzling supports TN or NN only")

# Force GRVW the same when UnrollLoopSwapGlobalReadOrder = 1.
if genGRVWA and state["UnrollLoopSwapGlobalReadOrder"] == 1:
Expand Down
Loading