Skip to content

Commit

Permalink
fixed wave_id distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
solaslin committed Jan 21, 2025
1 parent f3f7232 commit 0d524ab
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3382,13 +3382,14 @@ def lwaTileAssignment(self, kernel, tP):
tmpSgpr = tmpSgprInfo.idx
# Calc numKr, TODO- 32 should be MI_K * 2
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(32)), src=sgpr("SizesSum"), comment="SWZ-%s: numKr = DimK / 32"%tc))
WvG_M = kernel["MIWaveGroup"][0]
if tP["isA"]:
WvG_M = kernel["MIWaveGroup"][0]
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) mod MIWG[0]"%tc))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) *= numKr"%tc))
else:
WvG_N = kernel["MIWaveGroup"][1]
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_N-1), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) mod MIWG[1]"%tc))
elif tP["isB"]:
# NB:
# Calc of w_id is: /= MIWG[0], not %= MIWG[1]
module.add(VLShiftRightB32(dst=vgpr(qReg), shiftHex=log2(WvG_M), src=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) /= MIWG[0]"%tc))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) *= numKr"%tc))
elif isDTVAB:
# offset calculation for DirectToVgpr
Expand Down

0 comments on commit 0d524ab

Please sign in to comment.