-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: DTVB with Swizzling (tensorB) #1562
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -814,9 +814,11 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: | |
module.add(ValueSet("BpeAGRLog2", log2(tPA["bpeGR"]))) | ||
module.add(ValueSet("BpeBGR", tPB["bpeGR"])) | ||
module.add(ValueSet("BpeBGRLog2", log2(tPB["bpeGR"]))) | ||
# TODO- get real value from MI-InstM | ||
# TODO- get real value from MI-InstM/N | ||
if kernel["ProblemType"]["SwizzleTensorA"]: | ||
module.add(ValueSet("MI_M", 16)) | ||
if kernel["ProblemType"]["SwizzleTensorB"]: | ||
module.add(ValueSet("MI_N", 16)) | ||
# | ||
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: | ||
module.add(ValueSet("BpeMetadata", tPM["bpe"])) | ||
|
@@ -853,6 +855,8 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: | |
GOList.append(("Metadata", kernel["ProblemType"]["IndexAssignmentsMetadata"], kernel["BufferLoad"], tPM, False)) | ||
if kernel["ProblemType"]["SwizzleTensorA"]: | ||
GOList.append(("A", kernel["ProblemType"]["IndexAssignmentsA"], kernel["BufferLoad"], tPA, True)) | ||
if kernel["ProblemType"]["SwizzleTensorB"]: | ||
GOList.append(("B", kernel["ProblemType"]["IndexAssignmentsB"], kernel["BufferLoad"], tPB, True)) | ||
|
||
for (tc, indices, justOffset32, tP, isSwizzled) in GOList: | ||
|
||
|
@@ -1016,7 +1020,7 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: | |
macro.add(VMulLOU32(dst=destLo, | ||
src0="s[\\sgprStrideSW%s]" % tc, \ | ||
src1=offset, \ | ||
comment="SWZ: mul d%u lower"%i)) | ||
comment="SWZ-%s: mul d%u lower"%(tc, i))) | ||
else: | ||
macro.add(VMulLOU32(dst=destLo, | ||
src0=self.strideRef(tc, indices[i]), \ | ||
|
@@ -2314,16 +2318,21 @@ def graTileOffsets(self, kernel, tP, margin=-1): | |
module.add(VLShiftLeftB32(dst=vgpr(v), shiftHex=hex(log2(margin)), src=vgpr(v), comment="gro%s%s_%u *= %d"%(tP["tensorChar"], tP["tileChar"], 0, margin))) | ||
else: | ||
module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["tReg"]), comment="gro%s%s_%u"%(tP["tensorChar"], tP["tileChar"], 0) )) | ||
# TODO- current for A only | ||
# swizzle | ||
if tP["isSwizzled"]: | ||
swizzleStridePerWave = self.sgprPool.checkOut(1) | ||
WvG_M = kernel["MIWaveGroup"][0] | ||
# Calc stride = numKr * WaveG[0] | ||
# Calc stride = numKr * WaveG | ||
swizzleStrideVal = tP["swizzleK"] | ||
module.addComment(f"Align to {swizzleStrideVal}") | ||
module.add(SAddU32(sgpr(swizzleStridePerWave), sgpr("SizesSum"), swizzleStrideVal-1)) | ||
module.add(SLShiftRightB32(dst=sgpr(swizzleStridePerWave), src=sgpr(swizzleStridePerWave), shiftHex=hex(log2(swizzleStrideVal)))) | ||
module.add(SMulI32(dst=sgpr(swizzleStridePerWave), src0=hex(WvG_M), src1=sgpr(swizzleStridePerWave), comment="SWZ: numKr *= MI_WvG[0] (%u), how many wave-M per WG" % WvG_M)) | ||
module.add(SLShiftRightB32(dst=sgpr(swizzleStridePerWave), src=sgpr(swizzleStridePerWave), shiftHex=hex(log2(swizzleStrideVal)), comment="SWZ-%s: numKr = DimK / 32"%tc)) | ||
if tP["isA"]: | ||
WvG_MorN = kernel["MIWaveGroup"][0] | ||
commentMsg = "SWZ-%s: numKr *= MI_WvG[0] (%u), how many wave-M per WG" %(tc, WvG_MorN) | ||
elif tP["isB"]: | ||
WvG_MorN = kernel["MIWaveGroup"][1] | ||
commentMsg = "SWZ-%s: numKr *= MI_WvG[1] (%u), how many wave-N per WG" %(tc, WvG_MorN) | ||
module.add(SMulI32(dst=sgpr(swizzleStridePerWave), src0=hex(WvG_MorN), src1=sgpr(swizzleStridePerWave), comment=commentMsg)) | ||
|
||
for l in range(1, tP["nrt"]): | ||
strideValue = stride | ||
|
@@ -2335,7 +2344,7 @@ def graTileOffsets(self, kernel, tP, margin=-1): | |
# swizzle | ||
else: | ||
module.add(VAddCOU32(dst=vgpr(v+l), dst1=VCC(), src0=sgpr(swizzleStridePerWave), \ | ||
src1=vgpr(v+l-1), comment="SWZ: gro%s%s_%u += (numKr * WaveG[0])"%(tP["tensorChar"], tP["tileChar"], l) )) | ||
src1=vgpr(v+l-1), comment="SWZ-%s: gro%s%s_%u += (numKr * WaveG[0 or 1])"%(tc, tc, tP["tileChar"], l) )) | ||
if tP["isSwizzled"]: | ||
self.sgprPool.checkIn(swizzleStridePerWave) | ||
|
||
|
@@ -2377,13 +2386,13 @@ def graUnrollOffsets(self, kernel, tP): | |
tc = tP["tensorChar"] | ||
if kernel["_UseSgprForGRO"]: | ||
tP["gpr"]["unrollOffsets"] = tP["gpr"]["uReg"] | ||
# TODO- so far only A, check B | ||
# swizzle | ||
elif tP["isSwizzled"]: | ||
numUnrollOffsets = 1 | ||
tP["gpr"]["unrollOffsets"] = self.vgprPool.checkOut(numUnrollOffsets, "unrollOffsets", self.states.preventVgprOverflowDuringNewTile) | ||
v = tP["gpr"]["unrollOffsets"] | ||
module.addComment0("SWZ: Unroll increament is calculated in I-Offset since tile-memory is flattened") | ||
module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["uReg"]), comment="SWZ: only one gro%s%s_%u Base"%(tP["tensorChar"], self.states.unrollChar, 0))) | ||
module.addComment0("SWZ-%s: Unroll increament is calculated in I/J-Offset since tile-memory is flattened"%tc) | ||
module.add(VMovB32(dst=vgpr(v), src=vgpr(tP["gpr"]["uReg"]), comment="SWZ-%s: only one gro%s%s_%u Base"%(tc, tc, self.states.unrollChar, 0))) | ||
else: | ||
numUnrollOffsets = tP["nru"] | ||
tP["gpr"]["unrollOffsets"] = self.vgprPool.checkOut(numUnrollOffsets, "unrollOffsets", self.states.preventVgprOverflowDuringNewTile) | ||
|
@@ -2581,10 +2590,15 @@ def graFinalOffsets(self, kernel, tP): | |
graIdx = 0 | ||
swapPerpPara = (((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) and (not tP["tlu"]) and tP["nrp"] > 1) | ||
|
||
# swizzle | ||
if tP["isSwizzled"]: | ||
swizzleK = tP["swizzleK"] | ||
tP["swizzledBlockSize"] = self.sgprPool.checkOut(1) | ||
module.add(SMovB32(dst=sgpr(tP["swizzledBlockSize"]), src=hex(16*swizzleK), comment="SWZ: swizzled block = MI_M(%u) * MI_K(%u) * pack-K(%u)" %(16, kernel["MatrixInstK"], tP["swizzlePackK"]))) | ||
if tc == "A": | ||
commentMsg = "SWZ-%s: swizzled block = MI_M(%u) * MI_K(%u) * pack-K(%u)" %(tc, 16, kernel["MatrixInstK"], tP["swizzlePackK"]) | ||
elif tc == "B": | ||
commentMsg = "SWZ-%s: swizzled block = MI_N(%u) * MI_K(%u) * pack-K(%u)" %(tc, 16, kernel["MatrixInstK"], tP["swizzlePackK"]) | ||
module.add(SMovB32(dst=sgpr(tP["swizzledBlockSize"]), src=hex(16*swizzleK), comment=commentMsg)) | ||
|
||
# both UseSgprForGRO and DTVA/B are enabled | ||
if ((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) and kernel["_UseSgprForGRO"]: | ||
|
@@ -2727,12 +2741,16 @@ def graFinalOffsetsSingleLoop(self, kernel, tP, tc, tmp, graIdx, perp, sPerp, pa | |
|
||
# Swizzled version: | ||
# 1. passing swizzled block size as stride to macro | ||
# 2. If moving unroll = advance I-Offset for one swizzled block | ||
# 2. If moving unroll = advance I/J-Offset (A/B) for one swizzled block | ||
if tP["isSwizzled"]: | ||
if tc == "A": | ||
commentMsg = "SWZ-%s: groAL = groA0I + 1 Block of flattened mem"%tc | ||
elif tc == "B": | ||
commentMsg = "SWZ-%s: groBL = groB0J + 1 Block of flattened mem"%tc | ||
|
||
bfArgs.append(tP["swizzledBlockSize"]) | ||
if para > 0: | ||
module.add(VAddCOU32(dst=vgpr(vgprTile), dst1=VCC(), src0=hex(1), src1=vgpr(vgprTile), \ | ||
comment="SWZ: groAL = groA0I + 1 Block of flattened mem")) | ||
module.add(VAddCOU32(dst=vgpr(vgprTile), dst1=VCC(), src0=hex(1), src1=vgpr(vgprTile), comment=commentMsg)) | ||
|
||
bfArgs.append( "%u" % tmp ) | ||
bfComment = "gRO%s_%u_%u_%u_%u" % (tP["tensorChar"], para, sPara, perp, sPerp) | ||
|
@@ -3036,13 +3054,16 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): | |
module.add(SLShiftRightB32(dst=sgpr(stmp), src=size, shiftHex=0x1, comment="(size/2)")) | ||
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=0x1, comment="(size/2-1)")) | ||
else: | ||
if tP["isA"] and tP["isSwizzled"]: | ||
if tP["isSwizzled"]: | ||
if idx in kernel["ProblemType"]["IndicesSummation"]: | ||
module.addModuleAsFlatItems(self.alignTo(stmp, "SizeL", tP["swizzleK"])) | ||
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="(size-1)")) | ||
elif idx == kernel["ProblemType"]["Index0"]: | ||
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (size-1)"%tc)) | ||
elif tP["isA"] and idx == kernel["ProblemType"]["Index0"]: | ||
module.addModuleAsFlatItems(self.alignTo(stmp, "SizeI", 16)) | ||
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="(size-1)")) | ||
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (size-1)"%tc)) | ||
elif tP["isB"] and idx == kernel["ProblemType"]["Index1"]: | ||
module.addModuleAsFlatItems(self.alignTo(stmp, "SizeJ", 16)) | ||
module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (size-1)"%tc)) | ||
else: | ||
module.add(SSubU32(dst=sgpr(stmp), src0=size, src1=0x1, comment="(size-1)")) | ||
else: | ||
|
@@ -3145,8 +3166,10 @@ def graAddresses(self, kernel, tP): | |
tc = tP["tensorChar"] | ||
graIdx = 0 | ||
|
||
if tP["isA"] and tP["isSwizzled"]: | ||
module.addModuleAsFlatItems(self.alignTo("StrideA0I", "StrideA0I", tP["swizzleK"])) | ||
if tP["isSwizzled"]: | ||
# "StrideA0I" or "StrideB1J" | ||
strideName = "Stride%s%s"%(tc,self.states.indexChars[tP["idx"]]) | ||
module.addModuleAsFlatItems(self.alignTo(strideName, strideName, tP["swizzleK"])) | ||
|
||
Comment on lines
3168
to
3173
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. strideName = "Stride%s%s"%(tc,self.states.indexChars[tP["idx"]]) |
||
if kernel["BufferLoad"]: | ||
# maxAddrSgpr = size[n] * stride[n-1] | ||
|
@@ -3379,6 +3402,7 @@ def lwaTileAssignment(self, kernel, tP): | |
# store DirectToVgpr K interval for later use | ||
dtvKInterval = 1 | ||
|
||
# swizzle | ||
if tP["isSwizzled"]: | ||
module.addComment0("TileAssignment for DirectToVgpr%s and SwizzleTensor%s" % (tc, tc)) | ||
module.add(vectorStaticDivideAndRemainder(qReg, rReg, dividendReg, kernel["WavefrontSize"], tmpVgprRes)) | ||
|
@@ -3388,10 +3412,16 @@ def lwaTileAssignment(self, kernel, tP): | |
swizzleStrideVal = tP["swizzleK"] | ||
module.addComment(f"Align to {swizzleStrideVal}") | ||
module.add(SAddU32(sgpr(tmpSgpr), sgpr("SizesSum"), swizzleStrideVal-1)) | ||
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(swizzleStrideVal)), src=sgpr(tmpSgpr), comment="SWZ: numKr = DimK / %s"%swizzleStrideVal)) | ||
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(swizzleStrideVal)), src=sgpr(tmpSgpr), comment="SWZ-%s: numKr = DimK / %s"%(tc, swizzleStrideVal))) | ||
WvG_M = kernel["MIWaveGroup"][0] | ||
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]")) | ||
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr")) | ||
if tP["isA"]: | ||
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) mod MIWG[0]"%tc)) | ||
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) *= numKr"%tc)) | ||
elif tP["isB"]: | ||
# NB: | ||
# Calc of w_id is: /= MIWG[0], not %= MIWG[1] | ||
module.add(VLShiftRightB32(dst=vgpr(qReg), shiftHex=log2(WvG_M), src=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) /= MIWG[0]"%tc)) | ||
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) *= numKr"%tc)) | ||
elif isDTVAB: | ||
Comment on lines
3414
to
3425
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Might be the most important note: |
||
# offset calculation for DirectToVgpr | ||
# call function from LraTileAssignmentMFMA for DirectToVgpr | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this PR, all changes that rename MI_M to MI_MN (or MI_MorN) apply to both A and B.