Skip to content

Commit

Permalink
Use B64 instead of B32
Browse files Browse the repository at this point in the history
  • Loading branch information
KKyang committed Jan 15, 2025
1 parent ce603f2 commit 571844c
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 27 deletions.
27 changes: 18 additions & 9 deletions tensilelite/Tensile/Components/GlobalWriteBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2097,23 +2097,32 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV

def copyData(computeDataType, elementSumIdx, gwvw, vgprStart, direction=0):
module = Module("Copy Data")
for vi in range(0, gwvw):
vi = 0
while vi < gwvw:
sumIdxV = elementSumIdx + vi
if computeDataType.isHalf() or computeDataType.isBFloat16():
if (sumIdxV % 2 != 0):
vi += 1
continue
vgprIdx = elementSumIdx + vi // 2
module.add(VMovB32(dst=vgpr(vgprStart + (vi // 2)), src=vgpr(vgprIdx)))
elif computeDataType.isSingle():
if (vi + 1 < gwvw) and ((vgprStart + (vi // 2)) % 2 == 0) and (vgprIdx % 2 == 0):
module.add(VMovB64(dst=vgpr(vgprStart + (vi // 2), 2), src=vgpr(vgprIdx, 2)))
vi += 2
else:
module.add(VMovB32(dst=vgpr(vgprStart + (vi // 2)), src=vgpr(vgprIdx)))
vi += 1
elif computeDataType.isSingle() or computeDataType.isInt32():
vgprIdx = sumIdxV
module.add(VMovB32(dst=vgpr(vgprStart + vi), src=vgpr(vgprIdx)))
if (vi + 1 < gwvw) and ((vgprStart + vi) % 2 == 0) and (vgprIdx % 2 == 0):
module.add(VMovB64(dst=vgpr(vgprStart + vi, 2), src=vgpr(vgprIdx, 2)))
vi += 2
else:
module.add(VMovB32(dst=vgpr(vgprStart + vi), src=vgpr(vgprIdx)))
vi += 1
elif computeDataType.isDouble():
vgprIdx = elementSumIdx + vi * 2
module.add(VMovB32(dst=vgpr(vgprStart + vi * 2), src=vgpr(vgprIdx)))
module.add(VMovB32(dst=vgpr(vgprStart + vi * 2 + 1), src=vgpr(vgprIdx+1)))
elif computeDataType.isInt32():
vgprIdx = sumIdxV
module.add(VMovB32(dst=vgpr(vgprStart + vi), src=vgpr(vgprIdx)))
module.add(VMovB64(dst=vgpr(vgprStart + vi * 2, 2), src=vgpr(vgprIdx, 2)))
vi += 1
else:
assert 0

Expand Down
7 changes: 3 additions & 4 deletions tensilelite/Tensile/Components/StreamK.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
################################################################################

from ..TensileInstructions import Module, Label, SAddU32, RegisterPoolResource, sgpr, scalarStaticDivideAndRemainder, \
SCmpLtU32, SCSelectB32, sMagicDivAlg2, SMulI32, SSubU32, SMinU32, SMovB32, SCBranchSCC1, SCmpLeU32, VMovB32, vgpr, \
SAddCU32, SCmpGtU32, SCMovB32, SAddI32, SCmpEQU32, SCBranchSCC0, SLShiftLeftB32, SLoadB32, SWaitCnt, SMEMModifiers, \
SCmpLtU32, SCSelectB32, sMagicDivAlg2, SMulI32, SSubU32, SMinU32, SMovB32, SMovB64, SCBranchSCC1, SCmpLeU32, VMovB32, \
vgpr, SAddCU32, SCmpGtU32, SCMovB32, SAddI32, SCmpEQU32, SCBranchSCC0, SLShiftLeftB32, SLoadB32, SWaitCnt, SMEMModifiers, \
log2, SBarrier, SStoreB32, SLongBranchPositive, SBranch, ceilDivide, replaceHolder, SNop, staticMultiply, SSleep, \
VAddF32, VAddF64, SAndB32, SLShiftRightB32, VReadfirstlaneB32, SBranchIfNotZero
from ..Common import print2
Expand Down Expand Up @@ -379,8 +379,7 @@ def computeWorkspaceSrd(self, writer, kernel, sCtaIdx, tmpSgpr = None):
module = Module("StreamK Common computeWorkspaceSrd")

# Base Address
module.add(SMovB32(dst=sgpr("SrdWS+0"), src=sgpr("AddressWS+0"), comment="init SRD base address (lower)"))
module.add(SMovB32(dst=sgpr("SrdWS+1"), src=sgpr("AddressWS+1"), comment="init SRD base address (upper) + other fields"))
module.add(SMovB64(dst=sgpr("SrdWS+0", 2), src=sgpr("AddressWS+0", 2), comment="init SRD base address"))
module.add(SMovB32(dst=sgpr("SrdWS+2"), src="BufferOOB", comment=""))
module.add(SMovB32(dst=sgpr("SrdWS+3"), src="Srd127_96", comment="Set bits 127_96 in post-loop SRD"))

Expand Down
33 changes: 21 additions & 12 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,8 +1532,16 @@ def defineAndResources(self, kernel, tPA, tPB, tPM):
moduleArgs.add(SAddCU32(dst=sgpr("KernArgAddress+1"), src0=sgpr("KernArgAddress+1"), src1=hex(0)))
self.argLoader.resetOffset()
moduleArgs.addModuleAsFlatItems(self.getKernelArgLoadModule(kernel, sgprStart, load, self.states.numSgprPreload - self.states.userArgsInfo.commonArgsNum))
for i in range(self.states.userArgsInfo.commonArgsNum, self.states.numSgprPreload):
moduleArgs.add(SMovB32(dst=sgpr(sgprStart+i-self.states.userArgsInfo.commonArgsNum), src=sgpr(preloadSgprStartIdx+i), comment="move preload data to correct sgpr"))
i = self.states.userArgsInfo.commonArgsNum
while i < self.states.numSgprPreload:
dstIndex = sgprStart+i-self.states.userArgsInfo.commonArgsNum
srcIndex = preloadSgprStartIdx+i
if (i + 1 < self.states.numSgprPreload) and (dstIndex % 2 == 0) and (srcIndex % 2 == 0):
moduleArgs.add(SMovB64(dst=sgpr(sgprStart+i-self.states.userArgsInfo.commonArgsNum, 2), src=sgpr(preloadSgprStartIdx+i, 2), comment="move preload data to correct sgpr"))
i += 2
else:
moduleArgs.add(SMovB32(dst=sgpr(sgprStart+i-self.states.userArgsInfo.commonArgsNum), src=sgpr(preloadSgprStartIdx+i), comment="move preload data to correct sgpr"))
i += 1
moduleArgs.add(SBranch(labelName=perloadLabelLoadEnd.getLabelName()))
moduleArgs.add(preloadLabelHBM)
moduleArgs.add(SMovB64(dst=sgpr("KernArgAddress", 2), src=sgpr(preloadSgprStartIdx+4, 2), comment="Load address of kernel arguments"))
Expand Down Expand Up @@ -2934,8 +2942,7 @@ def computeMetaDataSrd(self, kernel, tP, tc, indices):
module.add(SAddU32(sgpr("SrdMetadata+0"), sgpr("AddressMetadata+0"), sgpr(tileStart+0), "SRD base = Address+ tileStart0"))
module.add(SAddCU32(sgpr("SrdMetadata+1"), sgpr("AddressMetadata+1"), sgpr(tileStart+1), "SRD base = Address+ tileStart1"))
else:
module.add(SMovB32(sgpr("SrdMetadata+0"), sgpr("AddressMetadata+0"), "init SRD base address (lower )" ))
module.add(SMovB32(sgpr("SrdMetadata+1"), sgpr("AddressMetadata+1"), "init SRD base address (upper) + other fields" ))
module.add(SMovB64(sgpr("SrdMetadata+0", 2), sgpr("AddressMetadata+0", 2), "init SRD base address" ))

module.add(SMovB32(sgpr("SrdMetadata+3"), "Srd127_96", "Set bits 127_96 in SRD"))
return module
Expand Down Expand Up @@ -2996,8 +3003,12 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe):
module.add(SMovB64(dst=sgpr(tileStart, 2), src=0, comment="set default tileStart"))

#Calculate tensor 2d size
module.add(SMovB32(dst=sgpr(tensor2dSize0), src=0x1, comment="Init tensor size"))
module.add(SMovB32(dst=sgpr(tensor2dSize1), src=0x0, comment="init tensor size"))
if tensor2dSize0 % 2 == 0:
module.add(SMovB64(dst=sgpr(tensor2dSize0, 2), src=0x1, comment="Init tensor size"))
else:
module.add(SMovB32(dst=sgpr(tensor2dSize0), src=0x1, comment="Init tensor size"))
module.add(SMovB32(dst=sgpr(tensor2dSize1), src=0x0, comment="init tensor size"))


numDim = len(indices)
for i in range(0, numDim):
Expand Down Expand Up @@ -8153,7 +8164,7 @@ def SrdTDinit(self, kernel):
module.addSpaceLine()

module.add(SMovB32(dst=sgpr("SrdTD+3"), src="Srd127_96", comment="Set bits 127_96 in post-loop SRD"))
module.add(SMovB32(dst=sgpr("SrdTD+2"), src=hex(0x80000000)))
module.add(SMovB32(dst=sgpr("SrdTD+2"), src="BufferOOB"))

module.add(SMulI32(dst=sgpr(tmpspgr0), src0="MT1", src1=sgpr("WorkGroup1"), comment=""))
module.add(SMulHIU32(dst=sgpr(tmpspgr+1), src0=sgpr(tmpspgr0), src1=sgpr("StrideC1J"), comment=""))
Expand Down Expand Up @@ -8354,9 +8365,8 @@ def localSplitUGlobalWriteIndices(self, kernel):
def allocPostLoopSrd(self, ch: str):
module = Module("allocPostLoopSrd")
# Buffer-load uses one base read pointer stored in the SRD - set it here:
module.add(SMovB32(dst=sgpr("Srd%s+0"%ch), src=sgpr("Address%s+0"%ch), comment="init SRD base address (lower)" ))
module.add(SMovB32(dst=sgpr("Srd%s+1"%ch), src=sgpr("Address%s+1"%ch), comment="init SRD base address (upper) + other fields" ))
module.add(SMovB32(dst=sgpr("Srd%s+2"%ch), src=hex(0x80000000)))
module.add(SMovB64(dst=sgpr("Srd%s+0"%ch, 2), src=sgpr("Address%s+0"%ch, 2), comment="init SRD base address" ))
module.add(SMovB32(dst=sgpr("Srd%s+2"%ch), src="BufferOOB"))
module.add(SMovB32(dst=sgpr("Srd%s+3"%ch), src="Srd127_96", comment="Set bits 127_96 in post-loop SRD"))
module.addSpaceLine()
return module
Expand Down Expand Up @@ -11008,8 +11018,7 @@ def writeBiasToGlobal(self, biasDataType, kernel, tP, gwvw, offsetVgpr, tmpSgprR
assert tmpSgprRes.size >= 4
tmpSgpr = tmpSgprRes.idx
#Calculate tensor 2d size
module.add(SMovB32(dst=sgpr(tmpSgpr+0), src=0x1, comment="Init tensor size"))
module.add(SMovB32(dst=sgpr(tmpSgpr+1), src=0x0, comment="Init tensor size"))
module.add(SMovB64(dst=sgpr(tmpSgpr+0, 2), src=0x1, comment="Init tensor size"))
indices = [i for i in range(kernel["ProblemType"]["NumIndicesC"])]
numDim = len(indices)
for i in range(0, numDim):
Expand Down
3 changes: 1 addition & 2 deletions tensilelite/Tensile/KernelWriterModules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def allocPostLoopSrdSuppressRaw(ch: str, chAddress: str, labelStr: str, sgprLeng
label = Label("%sAddrValid"%labelStr, "")
label2 = Label("%sAddrValid_End"%labelStr, "")
# Buffer-load uses one base read pointer stored in the SRD - set it here:
module.add(SMovB32(dst=sgpr("Srd%s+0"%ch), src=sgpr("Address%s+0"%chAddress), comment="init SRD base address (lower)" ))
module.add(SMovB32(dst=sgpr("Srd%s+1"%ch), src=sgpr("Address%s+1"%chAddress), comment="init SRD base address (upper) + other fields" ))
module.add(SMovB64(dst=sgpr("Srd%s+0"%ch, 2), src=sgpr("Address%s+0"%chAddress, 2), comment="init SRD base address" ))
module.add(SMovB32(dst=sgpr("Srd%s+3"%ch), src="Srd127_96", comment="Set bits 127_96 in post-loop SRD"))
module.add(SBranchIfNotZero("Address%s"%chAddress, DataType('int64'), label))
module.add(SMovB32(dst=sgpr("Srd%s+2"%ch), src=0))
Expand Down
2 changes: 2 additions & 0 deletions tensilelite/Tensile/TensileInstructions/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def _initAsmCaps(isaVersion, assemblerPath, isDebug) -> dict:

rv["v_fma_f64"] = _tryAssembler(isaVersion, assemblerPath, "v_fma_f64 v[20:21], v[22:23], v[24:25], v[20:21]", isDebug)

rv["v_mov_b64"] = _tryAssembler(isaVersion, assemblerPath, "v_mov_b64 v[0:1], v[2:3]", isDebug)

rv["HasAtomicAdd"] = _tryAssembler(isaVersion, assemblerPath, "buffer_atomic_add_f32 v0, v1, s[0:3], 0 offen offset:0", isDebug) \
or _tryAssembler(isaVersion, assemblerPath, "buffer_atomic_add_f32 v0, v1, s[0:3], null offen offset:0", isDebug)
rv["HasGLCModifier"] = _tryAssembler(isaVersion, assemblerPath, "buffer_load_dwordx4 v[10:13], v[0], s[0:3], 0, offen offset:0, glc", isDebug) \
Expand Down
29 changes: 29 additions & 0 deletions tensilelite/Tensile/TensileInstructions/Instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,6 +2769,35 @@ def __init__(self, dst, src, comment="") -> None:
super().__init__(InstType.INST_B32, dst, [src], None, None, comment)
self.setInst("v_mov_b32")

class _VMovB64(CommonInstruction):
def __init__(self, dst, src, comment="") -> None:
super().__init__(InstType.INST_B64, dst, [src], None, None, comment)
self.setInst("v_mov_b64")

class VMovB64(CompositeInstruction):
def __init__(self, dst, src, comment="") -> None:
super().__init__(InstType.INST_B64, dst, [src], comment)
self.setInst("v_mov_b64")

def toList(self) -> list:
assert 0 and "Not supported."
return []

def setupInstructions(self):
super().setupInstructions()
assert isinstance(self.srcs, List)
if self.asmCaps["v_mov_b64"]:
self.instructions = [_VMovB64(self.dst, self.srcs[0], self.comment)]
else:
dst1, dst2 = self.dst.splitRegContainer()
if isinstance(self.srcs[0], RegisterContainer) or isinstance(self.srcs[0], HolderContainer):
src1, src2 = self.srcs[0].splitRegContainer()
else:
srcs1 = (self.srcs[0] and 0xFFFFFFFF)
srcs2 = self.srcs[0] >> 32
self.instructions = [VMovB32(dst1, src1, self.comment),
VMovB32(dst2, src2, self.comment)]

# V Bfe
class VBfeI32(CommonInstruction):
def __init__(self, dst, src0, src1, src2, comment="") -> None:
Expand Down

0 comments on commit 571844c

Please sign in to comment.