Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[OPT] Tail Loop Optimization
Browse files Browse the repository at this point in the history
details:
1. Separate tailLoopOpt for A / B: tailLoopOptA / tailLoopOptB.
2. Not supported: DTV, SparseGemm.
3. Reorder load instructions with more vgprs.
briannwu committed Jan 20, 2025
1 parent 167eb6b commit ef4242e
Showing 3 changed files with 720 additions and 182 deletions.
44 changes: 35 additions & 9 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
@@ -2627,33 +2627,59 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
# if swapGlobalRead is true, swap the order of global read (B->A)
tensorParameters1st = tensorParametersA
tensorParameters2nd = tensorParametersB
tailLoopOpt1st = kernel["tailLoopOptA"]
tailLoopOpt2nd = kernel["tailLoopOptB"]

tc1 = 'A'
tc2 = 'B'
if self.isSwapGlobalReadOrderForDtvOrDtl(kernel):
tensorParameters1st, tensorParameters2nd = tensorParameters2nd, tensorParameters1st
tailLoopOpt1st, tailLoopOpt2nd = tailLoopOpt2nd, tailLoopOpt1st
tc1, tc2 = tc2, tc1

globalReadMode1st = 2 if (((tensorParameters1st["glvw"] * tensorParameters1st["bpeGR"]) < 4) or \
kernel["tailLoopOpt"] == False) else 0
tailLoopOpt1st == False) else 3
globalReadMode2nd = 2 if (((tensorParameters2nd["glvw"] * tensorParameters2nd["bpeGR"]) < 4) or \
kernel["tailLoopOpt"] == False) else 0
globalReadMode1st = 0 if tensorParameters1st["isSwizzled"] else globalReadMode1st
globalReadMode2nd = 0 if tensorParameters2nd["isSwizzled"] else globalReadMode2nd

globalReadMode1st = 3 if tensorParameters1st["isSwizzled"] else globalReadMode1st
globalReadMode2nd = 3 if tensorParameters2nd["isSwizzled"] else globalReadMode2nd

module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters1st)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc1)
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))
if tailLoopOpt1st and (globalReadMode1st == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters1st))
else:
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("Tail global read %s"%tc2)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))
if kernel["tailLoopOpt"] and \
(((tensorParameters1st["glvw"] * tensorParameters1st["bpeGR"]) >= 4) or \
((tensorParameters2nd["glvw"] * tensorParameters2nd["bpeGR"]) >= 4)):
module.add(self.tailLoopGlobalRead(kernel, tensorParameters1st, tensorParameters2nd))
if tailLoopOpt2nd and (globalReadMode2nd == 2):
module.add(self.doTailLoopOpt(kernel, tensorParameters2nd))
else:
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))

doA = False
doB = False
if globalReadMode1st == 3:
if tc1 == 'A':
doA = True if (tensorParameters1st["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc1)]) else False
else:
doB = True if (tensorParameters1st["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc1)]) else False
if globalReadMode2nd == 3:
if tc2 == 'A':
doA = True if (tensorParameters2nd["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc2)]) else False
else:
doB = True if (tensorParameters2nd["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc2)]) else False

if doA or doB:
if tc1 == 'A':
module.add(self.tailLoopGlobalRead(kernel, tensorParameters1st, tensorParameters2nd, doA, doB))
else:
module.add(self.tailLoopGlobalRead(kernel, tensorParameters2nd, tensorParameters1st, doA, doB))
module.add(self._wait(kernel, tensorParameters1st, tensorParameters2nd, 0, -1, -1, "2wait for global read"))
module.add(self._syncThreads(kernel))

840 changes: 675 additions & 165 deletions tensilelite/Tensile/KernelWriterAssembly.py

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
@@ -1386,15 +1386,17 @@ def assignProblemIndependentDerivedParameters(state):
reject(state, "MacroTile mismatch")

# tail loop optimization
state["tailLoopOptA"] = True
state["tailLoopOptB"] = True

if (tuple(state["ISA"]) != (9, 4, 2)) or \
(state["ProblemType"]["Sparse"]) or \
(state["LocalSplitU"] > 1) or \
(state["WaveSeparateGlobalReadA"] != 0) or \
(state["WaveSeparateGlobalReadB"] != 0) or \
(state["DirectToVgprA"] or state["DirectToVgprB"]):
state["tailLoopOpt"] = False
else:
state["tailLoopOpt"] = True
(state["ProblemType"]["Sparse"]):
state["tailLoopOptA"] = False
state["tailLoopOptB"] = False
if (state["DirectToVgprA"]):
state["tailLoopOptA"] = False
if (state["DirectToVgprB"]):
state["tailLoopOptB"] = False

# done
state["AssignedProblemIndependentDerivedParameters"] = True

0 comments on commit ef4242e

Please sign in to comment.