diff --git a/python/perf-kernels/tools/plot-layout/README.md b/python/perf-kernels/tools/plot-layout/README.md index 40de35bdb3aa..6041ec88b786 100644 --- a/python/perf-kernels/tools/plot-layout/README.md +++ b/python/perf-kernels/tools/plot-layout/README.md @@ -5,28 +5,48 @@ Here is the help info from the script. ```bash >$ python3 plot_layout.py -h -usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP] - [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep] +usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] + [-threadsPerWarp THREADSPERWARP THREADSPERWARP] [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}] + [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-mfmaTrans] [-scale] [-banks {32,64}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] + [-mnContig] [-mfma_trans_load] [-swizzleVec {4,8,16,32}] [-padInterval PADINTERVAL] [-padAmount PADAMOUNT] [-wave_size {32,64}] [-o O] [-keep] options: -h, --help show this help message and exit - -shape SHAPE SHAPE SHAPE - Tensor shape in the form of M,N,K + -tensorShape TENSORSHAPE TENSORSHAPE + 2D tensor shape in the form of dim0,dim1 + -dotShape DOTSHAPE DOTSHAPE DOTSHAPE + Dot op shape in the form of M,N,K -plot {blocked,dot,wmma,lds} choose plot mode - -nonKDim {16,32} mfma instruction dim + -dim0 DIM0 tensor dim0 name + -dim1 DIM1 tensor dim1 name -sizePerThread SIZEPERTHREAD SIZEPERTHREAD -threadsPerWarp THREADSPERWARP THREADSPERWARP -warpsPerCTA WARPSPERCTA WARPSPERCTA -order ORDER ORDER - -kWidth {4,8,16} number of elements per thread + -nonKDim {16,32} mfma instruction dim + -kWidth {4,8,16,32} number of contiguous elements per thread + -kGroup {1,2} total number of elements / kWidth per mfma instruction + -dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8} + element type of operand A + -dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8} + element type of operand B + -mfmaTrans If set, then use mfma.trans layout + -scale If set, plot the scale tensor for mfma_f8f6f4 instructions + -banks {32,64} choose the number of banks in LDS -lds_layout {swizzle,padding,none} choose the LDS data layout -lds_access {read,write,none} choose LDS access mode + -mnContig If set, the tensor is K x N and n-contig + -mfma_trans_load If set, use MFMA transpose load instructions + -swizzleVec {4,8,16,32} + number of contiguous elements in a vector to swizzle + -padInterval PADINTERVAL + Add padding for every padInterval bytes + -padAmount PADAMOUNT Pad padAmount bytes for every padInterval bytes -wave_size {32,64} choose the wmma instruction mode -o O output pdf file name (without surfix) - -mfmaTrans If set, then use mfma.trans layout -keep If set, keep the generated .tex file ``` @@ -34,84 +54,111 @@ options: This script does not require torch or triton to be installed. The only package it depends on is latex. On Ubuntu, do ```bash -sudo apt install texlive-full +sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended texlive-fonts-extra + ``` ## Draw blocked layout (`-plot blocked`) Examples: ```bash -python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1 -python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2 -python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1 +python3 plot_layout.py -plot blocked -tensorShape 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1 +python3 plot_layout.py -plot blocked -tensorShape 16 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2 +python3 plot_layout.py -plot blocked -tensorShape 32 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1 ``` Blocked layouts are used during global load. It is used to describe the layout of the tensor for pointers and results. -We can provide tensor shape (`-shape M N K`) and blocked layout parameters ( +We can provide tensor shape (`-tensorShape dim0 dim1`) and blocked layout parameters ( `-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`). We can also provide the order of the tensor as `-order x y` to control which dim is the fastest changing dimension. Notes -- All of the gemm dims (M, N, and K) are needed when providing the shape. But only - M and K will be used to plot the layout of the tensor. - The script does not support the case when threads are loading elements that are out of the boundary of the tensor dimensions. This means - - For M: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= M - - For K: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= K + - For dim0: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= dim0 + - For dim1: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= dim1 ## Draw mfma operand and result layouts (`-plot dot`) Examples: ```bash -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4 -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8 -python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16 +## i8 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a i8 -dtype_b i8 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a i8 -dtype_b i8 +## fp16/bf16 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 4 -dtype_a fp16 -dtype_b fp16 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp16 -dtype_b fp16 +## fp8/bf8 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp8 -dtype_b bf8 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a fp8 -dtype_b bf8 +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp8 -dtype_b bf8 +## f4 and fp6/bf6 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6 +## fp8/bf8 and fp6/bf6/f4 inputs +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 +## mixed precision with scaling +python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 -scale ``` +One can add `-nonKDim [16,32]` and `-mfmaTrans` to all of the above examples. + This mode draws two graphs: -1. The layout of the whole tile for tile A, B, and C +1. The layout of the dot operation, i.e. tile C = tile A x tile B 2. The layout of a single mfma block, operands and results of one or more mfma instructions that share the same accumulating VGPRs. - This view has thread distributions among tensor elements. Knobs -- `-kWidth`: the number of elements that will be loaded into one thread at once -- `-nonKDim`: 16 ot 32, which is used to control the mfma instruction size +- `-kWidth [4,8,16,32]`: the number of elements that will be loaded into one thread at once +- `-kGroup [1,2]`: total number of elements / kWidth for on mfma instruction. + This is 1 for all mfma instructions except for mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4 + with fp8 input types (CBSZ=0 or 1 and/or BLGP=0 or 1) +- `-nonKDim [16,32]`: mfma instruction size. The default is set to 16. - `-mfmaTrans`: if set, the transposed mfma layout will be plotted. +- `-dtype_a` and `-dtype_b`: element types of operand A and B. The default value is fp16. +- `-scale`: plot scale tensors for A and B. This is only supported with f4/f6 and f8 with `kGroup=2`. + If `-scale` is set but not supported, it's ignored. Notes - The layout shows the mapping from the threads/wave to the elements in the - original tensor. It does not care if the elements are arranged in LDS, like - swizzling to avoid bank conflicts. -- The script does not allow settings for data type or k dim of the mfma instruction. - This can be controled by the `-kWidth` flag. - - For example, if we want `mfma_32x32x8xf16`, we can set `-nonKDim 32` and `-kWidth 4`. - - If we want `mfma_32x32x16xf8`, we can set `-nonKDim 32` and `-kWidth 8`. - + original tensor. It does not matter if LDS is used. +- The script does not allow settings for k dim of the mfma instruction. + This can be controled by the `-kWidth` and `-kGroup`. ## Draw LDS access (`-plot lds`) Examples: ```bash -python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8 +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 8 +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 32 -dtype_a f4 +python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64 +python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64 +python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 128 -kWidth 16 -dtype_a bf8 -banks 64 +python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access write -tensorShape 128 128 -kWidth 16 -dtype_a f4 -banks 32 +python3 plot_layout.py -plot lds -lds_layout none -lds_access read -tensorShape 128 32 -kWidth 4 -dtype_a fp16 -banks 64 -mnContig +python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 32 -kWidth 16 -dtype_a fp8 -banks 64 -mnContig -mfma_trans_load +python3 plot_layout.py -plot lds -lds_layout padding -lds_access none -tensorShape 128 32 -kWidth 8 -dtype_a fp16 -banks 32 -padInterval 128 -padAmount 16 ``` Knobs -- `kWidth` here means the vector size when accessing LDS +- `kWidth`: the vector size (in unit of elements) when accessing LDS +- `banks`: the number of banks in LDS. (64 for gfx950, 32 for pre-gfx950) +- `dtype_a`: element data type - Three options for `-lds_layout`: - `none`: no swizzling, no padding - - `padding`: padding at every 128B - - `swizzling`: apply the swizzling pattern, which is derived from tensor shape and kWidth. + - `swizzle`: apply the swizzling pattern, which is derived from tensor shape and kWidth. + - `padding`: pad `padAmount` bytes for every `padInterval` bytes of data + - `padAmount`: default is 0 + - `padInterval`: default is 1 - Three options for `-lds_access`: - `none`: do not plot access pattern - - `read`: plot accessed elements during ds_read - - `write`: plot accessed elements during ds_write. Note that this needs some infomation from - global load. Therefore, we need to provide `-sizePerThread` and `-threadsPerWarp`. - -Notes -- This mode is rarely used. If you have any questions, please contact Lixun Zhang directly. + - `read`: plot accessed elements at the first cycle of ds_read + - `write`: plot accessed elements during ds_write. For global load access, we assume + a fully coalesced dwordx4 access pattern along the K dim. +- `mnContig`: If set, the tile is stored in mn-contig layout. In this layout, elements along + the M/N dim are contiguous in both global memory and LDS. +- `mfma_trans_load`: This flag only works when `mnContig` is set. When set, `ds_read_b64_tr_bx` + instructions are used to read from LDS. Note that current triton LDS layout mechanism will + lead to bank conflicts. diff --git a/python/perf-kernels/tools/plot-layout/blockedLayout.tex b/python/perf-kernels/tools/plot-layout/blockedLayout.tex new file mode 100644 index 000000000000..37aba60f5bf0 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/blockedLayout.tex @@ -0,0 +1,157 @@ +\newcommand{\drawBlockedWave}[5]{ + %% + %% Draw a wave coverage with blocked layout + %% + %% Wave TL: pre defined top-left coordinate of the wave + %% \elem: pre defined variable + %% + %% #1: sizePerThread[0] --> sizePerThreadM + %% #2: sizePerThread[1] --> sizePerThreadN + %% #3: threadsPerWarp[0] --> threadsPerWarpM + %% #4: threadsPerWarp[1] --> threadsPerWarpN + %% #5: fastest changing dim --> order + + \pgfmathsetmacro{\sizePerThreadM}{#1} + \pgfmathsetmacro{\sizePerThreadN}{#2} + \pgfmathsetmacro{\threadsPerWarpM}{#3} + \pgfmathsetmacro{\threadsPerWarpN}{#4} + \pgfmathsetmacro{\order}{#5} + + \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} + \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} + + \foreach \tid in {0,...,63}{ + \pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)} + \pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)} + \coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$); + \pgfmathsetmacro{\ratio}{\tidM*10} + + \ifthenelse{\tid = 0}{ + \draw [line width = 0.01mm, fill=red] (Thread TL) + rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); + }{ + \draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL) + rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); + } + } + \draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem); +} + +\newcommand{\drawBlockedCTA}[7]{ + %% + %% Draw a CTA coverage with blocked layout + %% + %% CTA TL: pre defined top-left coordinate of the CTA + %% \elem: pre defined variable + %% + %% #1: sizePerThread[0] --> sizePerThreadM + %% #2: sizePerThread[1] --> sizePerThreadN + %% #3: threadsPerWarp[0] --> threadsPerWarpM + %% #4: threadsPerWarp[1] --> threadsPerWarpN + %% #5: warpsPerCTA[0] --> warpsPerCTAM + %% #6: warpsPerCTA[1] --> warpsPerCTAN + %% #7: fastest changing dim --> order + + \pgfmathsetmacro{\sizePerThreadM}{#1} + \pgfmathsetmacro{\sizePerThreadN}{#2} + \pgfmathsetmacro{\threadsPerWarpM}{#3} + \pgfmathsetmacro{\threadsPerWarpN}{#4} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\order}{#7} + + \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} + \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} + \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} + \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} + + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1} + + \coordinate (Wave TL) at (CTA TL); + \drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order} + \foreach \waveId in {0,...,\maxWaveId}{ + \ifthenelse{\order=1} + { + \pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)} + \pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)} + \pgfmathsetmacro{\rot}{0} + }{ + \pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)} + \pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)} + \pgfmathsetmacro{\rot}{90} + } + + \coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$); + \draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem) + node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId}; + } + + \draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem); +} + +\newcommand{\drawBlockedTensor}[8]{ + %% + %% Draw a tensor with blocked layout of the following parameters + %% sizePerThread[2] + %% threadsPerWarp[2] + %% warpsPerCTA[2] + %% order[2] + %% + %% TL: pre defined top-left coordinate of the tensor + %% \elem: pre defined variable + %% \dimColName: dim0Name + %% \dimRowName: dim1Name + %% + %% #1: tensorShape[0] --> M + %% #2: tensorShape[1] --> N + %% #3: sizePerThread[0] --> sizePerThreadM + %% #4: sizePerThread[1] --> sizePerThreadN + %% #5: threadsPerWarp[0] --> threadsPerWarpM + %% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0] + %% #6: warpsPerCTA[0] --> warpsPerCTAM + %% #7: warpsPerCTA[1] --> warpsPerCTAN + %% #8: fastest changing dim --> order + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\sizePerThreadM}{#3} + \pgfmathsetmacro{\sizePerThreadN}{#4} + \pgfmathsetmacro{\threadsPerWarpM}{#5} + \pgfmathsetmacro{\warpsPerCTAM}{#6} + \pgfmathsetmacro{\warpsPerCTAN}{#7} + \pgfmathsetmacro{\order}{#8} + + \pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM} + \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} + \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} + \pgfmathsetmacro{\CTARepM}{\M/\CTASizeM} + \pgfmathsetmacro{\CTARepN}{\N/\CTASizeN} + \pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1} + + \foreach \ctaId in {0,...,\maxCTAId}{ + \pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)} + \pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)} + \coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$); + \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order} + } + + \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {\dimColName=\M}; + \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {\dimRowName=\N}; + + \def\zoomR{1.5} + \coordinate (zoomin BL) at ($(TL)+(0, .3)$); + + \foreach \hl in {0,...,\sizePerThreadM}{ + \draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0); + } + \foreach \vl in {0,...,\sizePerThreadN}{ + \draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR); + } + + \node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$}; + \node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN}; + + \draw [densely dotted] (TL) -- (zoomin BL); + \draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$); + \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); +} diff --git a/python/perf-kernels/tools/plot-layout/dotLayout.tex b/python/perf-kernels/tools/plot-layout/dotLayout.tex new file mode 100644 index 000000000000..633d4af01023 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/dotLayout.tex @@ -0,0 +1,444 @@ +\newcommand{\drawBlockMFMALayoutLarge}[3]{ + %% + %% Draw a single block of MFMA_32x32x8xf16 or MFMA_16x16x16xf16 + %% + %% block TL: pre-defined top-left coordinate of the block + %% \elem: pre defined variable + %% + %% #1: 1 for mfma.trans, 0 for normal mfma + %% #2: mfmaNonKDim + %% #3: verbose. 1 means draw tid in each vec; 0 means draw nothing + + \pgfmathsetmacro{\trans}{#1} + \pgfmathsetmacro{\nonTrans}{1-#1} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\maxGID}{\groups-1} + \pgfmathsetmacro{\maxIVec}{\nonKDim*\nonKDim/256-1} + \pgfmathsetmacro{\verbose}{#3} + \foreach \iVec in {0,...,\maxIVec} { + \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); + \foreach \tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\colID}{\tg} + \pgfmathsetmacro{\col}{\Colors[\colID]} + \foreach \tid in {0,...,\maxTID} { + \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col!\ratio!white] + ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) + rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col!\ratio!white] + ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) + rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem) + node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid}; + } + } + } + } + \draw [thick] (block TL) rectangle ++(\nonKDim*\elem, -\nonKDim*\elem); +} + + +\newcommand{\drawTensorMFMALayout}[6]{ + %% + %% Draw a tensor with mfma layout. + %% + %% C TL: pre defined top-left coordinates of the tensor + %% + %% #1: M + %% #2: N + %% #3: MFMA nonKDim + %% #4: warpsPerCTA[0] + %% #5: warpsPerCTA[1] + %% #6: 1 for mfma.trans, 0 for normal mfma + + \pgfmathsetmacro{\tensorShapeH}{#1} + \pgfmathsetmacro{\tensorShapeW}{#2} + \pgfmathsetmacro{\mfmaNonKDim}{#3} + \pgfmathsetmacro{\warpsPerCTAH}{#4} + \pgfmathsetmacro{\warpsPerCTAW}{#5} + \pgfmathsetmacro{\mfmaTrans}{#6} + + \coordinate (old TL) at (TL); + \coordinate (TL) at (C TL); + + + \pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH} + \pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW} + \pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1} + \pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim} + \pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim} + + + \foreach \ctaId in {0,...,\maxCTAId}{ + \pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)} + \pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)} + \coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$); + %% Draw a detailed view of wave0 in each CTA + \coordinate (block TL) at (CTA TL); + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{0} + + \foreach \waveId in {0,...,\maxWaveId}{ + \pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)} + \pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)} + \coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$); + %% Inside the loop, only draw a rectangle + \draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem) + node [scale=.7*\mfmaNonKDim/32*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId}; + } + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem); + } + + \coordinate (TL) at (old TL); +} + +\newcommand{\drawMFMAOperand}[5]{ + %% + %% Draw one mfma operand + %% + %% Pre-defined variables + %% mfma op TL: coordinates of the top-left + %% \elem: vertical element size of operands, element size of output + %% \elemW: honrizontal element size of operands + %% + %% #1: mfmNonKDim + %% #2: kWidth + %% #2: kGroup + %% #3: 0 for opA and 1 for opB + %% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing + + \pgfmathsetmacro{\nonKDim}{#1} + \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} + \pgfmathsetmacro{\maxTID}{\nonKDim-1} + \pgfmathsetmacro{\kWidth}{#2} + \pgfmathsetmacro{\kGroup}{#3} + \pgfmathsetmacro{\maxGroupId}{\kGroup-1} + \pgfmathsetmacro{\opIdxA}{#4} + \pgfmathsetmacro{\opIdxB}{1-\opIdxA} + \pgfmathsetmacro{\verbose}{#5} + + \foreach \gp in {0,...,\maxGroupId}{ + \coordinate (group TL) at ($(mfma op TL)+(\gp*\kWidth*64*\elemW/\nonKDim*\opIdxB, -\gp*\kWidth*64*\elemW/\nonKDim*\opIdxA)$); + \foreach \col/\tg in {0,...,\maxGID}{ + \pgfmathsetmacro{\col}{\Colors[\tg]} + \foreach \tid in {0,...,\maxTID} { + \ifthenelse{\verbose=0}{ + \draw [line width=0.005mm, fill=\col] + ($(group TL)+(\tg*\kWidth*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kWidth*\elem*\opIdxA)$) + rectangle ++(\kWidth*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kWidth*\elem*\opIdxA); + }{ + \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} + \draw [line width=0.005mm, fill=\col] + ($(group TL)+(\tg*\kWidth*\elemW*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kWidth*\elemW*\opIdxA)$) + rectangle ++(\kWidth*\elemW*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kWidth*\elemW*\opIdxA) + node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; + } + } + } + } +} + +\newcommand{\drawWaveOperand}[5]{ + %% + %% Draw the part of the tensor that is one operand of the wave + %% + %% Op TL: pre defined coordinates of the top-left of the operand + %% \elem: pre defined variable + %% + %% #1: K + %% #2: mfmNonKDim + %% #3: kWidth + %% #4: kGroup + %% #5: 0 for opA and 1 for opB + + \pgfmathsetmacro{\K}{#1} + \pgfmathsetmacro{\nonKDim}{#2} + \pgfmathsetmacro{\groups}{64/\nonKDim} + \pgfmathsetmacro{\kWidth}{#3} + \pgfmathsetmacro{\kGroup}{#4} + \pgfmathsetmacro{\opIdx}{#5} + \pgfmathsetmacro{\opIdxOther}{1-\opIdx} + + \coordinate (TL) at (Op TL); + + \pgfmathsetmacro{\numKRep}{\K/\kWidth/\groups/\kGroup} + \pgfmathsetmacro{\maxKRepId}{\numKRep-1} + + \foreach \repId in {0,...,\maxKRepId}{ + \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kWidth*\elem*\kGroup*\opIdxOther, -\repId*\groups*\kWidth*\kGroup*\elem*\opIdx)$); + \drawMFMAOperand{\nonKDim}{\kWidth}{\kGroup}{\opIdx}{0} + \draw [thick] (mfma op TL) rectangle + ++(\groups*\kWidth*\kGroup*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kWidth*\kGroup*\elem*\opIdx); + } +} + +\newcommand{\drawDotOperands}[6]{ + %% + %% Draw operand tensors of dot + %% + %% A TL and B TL: pre defined top-left coordinates of A and B tensor + %% \elem: pre defined variable + %% + %% #1: M + %% #2: N + %% #3: K + %% #4: MFMA nonKDim + %% #5: warpsPerCTA[0] + %% #6: warpsPerCTA[1] + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\K}{#3} + \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + + %% operand A + \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} + \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1} + \foreach \ctaId in {0,...,\maxCTAIdM}{ + \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$); + \foreach \waveId in {0,...,\maxWaveId}{ + \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$); + \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem); + } + %% Only draw the detailed view of the first wave in CTA + \coordinate (Op TL) at (CTA TL); + \drawWaveOperand{\K}{\mfmaNonKDim}{\kWidthA}{\kGroupA}{0} + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); + } + \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem); + + + %% operand B + \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim} + \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1} + \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1} + \foreach \ctaId in {0,...,\maxCTAIdN}{ + \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$); + \foreach \waveId in {0,...,\maxWaveId}{ + \coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$); + \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem); + } + %% Only draw the detailed view of the first wave in CTA + \coordinate (Op TL) at (CTA TL); + \drawWaveOperand{\K}{\mfmaNonKDim}{\kWidthB}{\kGroupB}{1} + + %% Draw the outline of each CTA rep + \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); + } + \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem); +} + + +\newcommand{\drawDot}[7]{ + %% + %% Draw C = dot A, B + %% + %% C TL: pre defined top-left coordinates of the result tensor + %% \elem: pre defined variable + %% + %% #1: M + %% #2: N + %% #3: K + %% #4: MFMA nonKDim + %% #5: warpsPerCTA[0] + %% #6: warpsPerCTA[1] + %% #7: 1 for mfma.trans, 0 for normal mfma + + \pgfmathsetmacro{\M}{#1} + \pgfmathsetmacro{\N}{#2} + \pgfmathsetmacro{\K}{#3} + \pgfmathsetmacro{\mfmaNonKDim}{#4} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\warpsPerCTAM}{#5} + \pgfmathsetmacro{\warpsPerCTAN}{#6} + \pgfmathsetmacro{\mfmaTrans}{#7} + + \pgfmathsetmacro{\gap}{\elem*20} + \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); + \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); + + %% Draw both A and B operands + \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN} + + %% Draw result tensor + \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} + + %% Draw labels + \node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K}; + \node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M}; + + \node [scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K}; + \node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N}; + + \node [scale=\scale, above left] at (A TL) {A}; + \node [scale=\scale, above left] at (B TL) {B}; + \node [scale=\scale, above left] at (C TL) {C}; + + %% label nonKDim + \node [scale=.8*\scale, left] at ($(C TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; + \node [scale=.8*\scale, above] at ($(C TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; +} + +\newcommand{\drawZoomInVec}[3]{ + %% + %% Draw zoomed in view of vector of elements + %% + %% predefined variables + %% vec TL: top-left coordinates of the vector + %% orig TL: top-left coordinates of the original small vector + %% \elem: vertical element size of operands, element size of output + %% \elemW: honrizontal element size of operands + %% \scaleLabel: extra scale applied to labels according to kWidth + %% + %% #1: number of elements + %% #2: 0 for opLeft, 1 for opRight + %% #3: label + + \pgfmathsetmacro{\kWidth}{#1} + \pgfmathsetmacro{\opLeft}{#2} + \pgfmathsetmacro{\opRight}{1-#2} + + \pgfmathsetmacro{\maxVec}{\kWidth-1} + \foreach \vecId in {0,...,\maxVec}{ + \draw ($(vec TL)+(\vecId*\elem*\opRight, -\vecId*\elem*\opLeft)$) rectangle ++(\elem, -\elem); + } + \draw [densely dotted] (orig TL) -- ($(vec TL)+(\elem*\opLeft, -\elem*\opRight)$); + \draw [densely dotted] ($(orig TL)+(\kWidth*\elemW*\opRight, -\kWidth*\elemW*\opLeft)$) -- ($(vec TL)+(\kWidth*\elem*\opRight+\elem*\opLeft, -\elem*\opRight-\kWidth*\elem*\opLeft)$); + \node [scale=.8*\scaleLabel, above, rotate=90*\opLeft] at ($(vec TL)+(.5*\kWidth*\elem*\opRight, -.5*\kWidth*\elem*\opLeft)$) {#3}; +} + +\newcommand{\drawMFMAInstr}[6]{ + %% + %% Draw layout of mfma instructions with tid labeled + %% + %% Pre-defined variables + %% C TL: top-left coordinates of the output matrix + %% \elem: vertical element size of operands, element size of output + %% \elemW: honrizontal element size of operands + %% \scaleLabel: extra scale applied to labels according to kWidth + %% + %% #1: mfmaNonKDim + %% #2: mfmaTrans + %% #3: dtype_a + %% #4: dtype_b + %% #5: outType + %% #6: scaling: if set, draw scaling tensors + + \pgfmathsetmacro{\mfmaNonKDim}{#1} + \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} + \pgfmathsetmacro{\mfmaTrans}{#2} + \pgfmathsetmacro{\nonTrans}{1-#2} + \pgfmathsetmacro{\drawScale}{#6} + + \ifthenelse{\mfmaTrans=0}{ + \pgfmathsetmacro{\kWidthLeft}{\kWidthA} + \pgfmathsetmacro{\kWidthRight}{\kWidthB} + \pgfmathsetmacro{\kGroupLeft}{\kGroupA} + \pgfmathsetmacro{\kGroupRight}{\kGroupB} + }{ + \pgfmathsetmacro{\kWidthLeft}{\kWidthB} + \pgfmathsetmacro{\kWidthRight}{\kWidthA} + \pgfmathsetmacro{\kGroupLeft}{\kGroupB} + \pgfmathsetmacro{\kGroupRight}{\kGroupA} + } + \pgfmathsetmacro{\kDim}{int(\kWidthLeft*\groups*\kGroupLeft)} + + %% Draw operand left + \pgfmathsetmacro{\gap}{\elem*5} + \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kWidthLeft*\elemW*\kGroupLeft, 0)$); + \coordinate (mfma op TL) at (mfma opA TL); + \drawMFMAOperand{\mfmaNonKDim}{\kWidthLeft}{\kGroupLeft}{0}{1} + %% Draw operand right + \coordinate (mfma opB TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kWidthRight*\elemW*\kGroupRight)$); + \coordinate (mfma op TL) at (mfma opB TL); + \drawMFMAOperand{\mfmaNonKDim}{\kWidthRight}{\kGroupRight}{1}{1} + + %% Draw scaling tensors if needed + \ifthenelse{\drawScale=1}{ + \coordinate (left scaling TL) at ($(mfma opA TL)+(-0.3*\gap-\groups*4*\elemW, 0)$); + \coordinate (mfma op TL) at (left scaling TL); + \drawMFMAOperand{\mfmaNonKDim}{4}{1}{0}{1} + + \coordinate (right scaling TL) at ($(mfma opB TL)+(0, 0.3*\gap+\groups*4*\elemW)$); + \coordinate (mfma op TL) at (right scaling TL); + \drawMFMAOperand{\mfmaNonKDim}{4}{1}{1}{1} + }{} + + \coordinate (block TL) at (C TL); + \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} + + %% Draw labels + %% Set data types + \def\opAType{#3} + \def\opBType{#4} + \def\outType{#5} + + %% Draw kWidth vector and label of first operand + \coordinate (vec TL) at ($(mfma opA TL)+(0, 5*\elem)$); + \coordinate (orig TL) at (mfma opA TL); + \drawZoomInVec{\kWidthLeft}{0}{kWidth=\kWidthLeft} + + %% Draw kWidth vector and label of second operand + \coordinate (vec TL) at ($(mfma opB TL)+(-5*\elem, 0)$); + \coordinate (orig TL) at (mfma opB TL); + \drawZoomInVec{\kWidthRight}{1}{kWidth=\kWidthRight} + + \ifthenelse{\drawScale=1}{ + %% Draw vec and label of scalingLeft + \coordinate (vec TL) at ($(left scaling TL)+(0, 5*\elem)$); + \coordinate (orig TL) at (left scaling TL); + \drawZoomInVec{4}{0}{vec=4$\times$e8m0} + %% Draw vec and label of scalingRight + \coordinate (vec TL) at ($(right scaling TL)+(-5*\elem, 0)$); + \coordinate (orig TL) at (right scaling TL); + \drawZoomInVec{4}{1}{vec=4$\times$e8m0} + }{} + + %% Draw labels according to mfma.trans or not + \ifthenelse{\mfmaTrans=0}{ + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kWidthLeft*\elemW*\groups*\kGroupLeft, 0)$) + {inA:$\mfmaNonKDim \times \kDim \times $\opAType}; + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma opB TL)+(0,-\groups*\kWidthRight*\elemW*\kGroupRight)$) + {inB:$\kDim \times \mfmaNonKDim \times $\opBType}; + \ifthenelse{\drawScale=1}{ + \node [scale=\scaleLabel, above] at ($(left scaling TL)+(0.5*4*\elemW*\groups, 0)$) {scaleA}; + \node [scale=\scaleLabel, above, rotate=90] at ($(right scaling TL)+(0,-0.5*\groups*4*\elemW)$) {scaleB}; + }{} + \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem,0)$); + \foreach \vecId in {0,1,2,3}{ + \draw ($(vec TL)+(0, -\vecId*\elem)$) rectangle ++(\elem, -\elem); + } + \draw [densely dotted] (block TL) -- ++(-3*\elem,0); + \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem,0); + \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem)$) {vec=4$\times$\outType}; + \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; + }{ + \node [scale=\scaleLabel, above left] at ($(mfma opA TL)+(\kWidthLeft*\elemW*\groups*\kGroupLeft, 0)$) + {inB:$\kDim \times \mfmaNonKDim^T \times $\opBType}; + \node [scale=\scaleLabel, above right, rotate=90] at ($(mfma opB TL)+(0, -\groups*\kWidthRight*\elemW*\kGroupRight)$) + {inA:$\mfmaNonKDim \times \kDim^T \times $\opAType}; + \ifthenelse{\drawScale=1}{ + \node [scale=\scaleLabel, above] at ($(left scaling TL)+(.5*4*\elemW*\groups, 0)$) {scaleB}; + \node [scale=\scaleLabel, above, rotate=90] at ($(right scaling TL)+(0, -.5*\groups*4*\elemW)$) {scaleA}; + }{} + \coordinate (vec TL) at ($(block TL)+(0, 3*\elem+\elem)$); + \foreach \vecId in {0,1,2,3}{ + \draw ($(vec TL)+(\vecId*\elem, 0)$) rectangle ++(\elem, -\elem); + } + \draw [densely dotted] (block TL) -- ++(0, 3*\elem); + \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(0, 3*\elem); + \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem, 0)$) {vec=4$\times$\outType}; + \node [scale=.6*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=True}; + } +} diff --git a/python/perf-kernels/tools/plot-layout/ldsLayout.tex b/python/perf-kernels/tools/plot-layout/ldsLayout.tex new file mode 100644 index 000000000000..f3a7cd2762b0 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/ldsLayout.tex @@ -0,0 +1,494 @@ +\ExplSyntaxOn +\NewExpandableDocumentCommand{\bitwiseXor}{mm} + { + \recuenco_bitwise_xor:nn { #1 } { #2 } + } + +\cs_new:Nn \recuenco_bitwise_xor:nn + { + \int_from_bin:e + { + \__recuenco_bitwise_xor:ee { \int_to_bin:n { #1 } } { \int_to_bin:n { #2 } } + } + } +\cs_generate_variant:Nn \int_from_bin:n { e } + +\cs_new:Nn \__recuenco_bitwise_xor:nn + { + \__recuenco_bitwise_xor_binary:ee + { + \prg_replicate:nn + { + \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #1 } + } + { 0 } + #1 + } + { + \prg_replicate:nn + { + \int_max:nn { \tl_count:n { #1 } } { \tl_count:n { #2 } } - \tl_count:n { #2 } + } + { 0 } + #2 + } + } +\cs_generate_variant:Nn \__recuenco_bitwise_xor:nn { ee } + +\cs_new:Nn \__recuenco_bitwise_xor_binary:nn + { + \__recuenco_bitwise_xor_binary:w #1;#2; + } +\cs_generate_variant:Nn \__recuenco_bitwise_xor_binary:nn { ee } + +\cs_new:Npn \__recuenco_bitwise_xor_binary:w #1#2;#3#4; + { + \int_abs:n { #1-#3 } + \tl_if_empty:nF { #2 } { \__recuenco_bitwise_xor_binary:w #2;#4; } + } + +\ExplSyntaxOff + +\newcommand{\drawTensorLayoutGlobalMem}[4]{ + %% + %% Draw tensor layout in global memory without any swizzling + %% + %% TL: pre defined top-left coordinates of the tensor in global memory + %% \elemH: The height of each element + %% \bsize: The width of each byte + %% \Colors: a pre defined array of 16 colors + %% \trans: 1 for K x N and 0 for M x K + %% + %% #1: rowName + %% #2: colName + %% #3: rowSize, i.e. number of rows + %% #4: colSize, i.e. number of cols + + \pgfmathsetmacro{\rowSize}{#3} + \pgfmathsetmacro{\colSize}{#4} + %% decide how many rows to draw + \ifthenelse{\trans=0}{ + % non-trans case + \pgfmathsetmacro{\maxRowId}{\mfmaNonKDim-1} + }{ + % trans case + \ifthenelse{\useMfmaTransLD=0} + {\pgfmathsetmacro{\maxRowId}{32/\bytesPerElem-1}} + {\pgfmathsetmacro{\maxRowId}{512/\mfmaNonKDim/\bytesPerElem-1}} + } + + \pgfmathsetmacro{\elemsPerVec}{\vec} + + \pgfmathsetmacro{\vecInCol}{\colSize/\elemsPerVec} + \pgfmathsetmacro{\maxColVecId}{\vecInCol-1} + + \foreach \gp in {0,...,\maxColVecId}{ + \pgfmathsetmacro{\gpCol}{int(mod(\gp, 16))} + \pgfmathsetmacro{\vecColor}{\Colors[\gpCol]} + \pgfmathsetmacro{\colStart}{int(\gp*\elemsPerVec)} + \pgfmathsetmacro{\colEnd}{int(\colStart+\elemsPerVec-1)} + \foreach \row in {0,...,\maxRowId}{ + \coordinate (vec TL) at ($(TL)+(\gp*\vecInBytes*\bsize, -\row*\elemH)$); + \draw [ultra thin, fill=\vecColor] (vec TL) rectangle ++(\vecInBytes*\bsize, -\elemH) + node [pos=.5, scale=.6*\bankLabelScale*\scale, white] {#1\row,#2\colStart:\colEnd}; + } + } + %% draw dims + \def\gap{3} + \pgfmathsetmacro{\drawRow}{\maxRowId*\elemH+\gap*\elemH+\elemH} + \pgfmathsetmacro{\diffRow}{int(\maxRowId+1-\rowSize)} + \ifthenelse{\diffRow = 0}{\pgfmathsetmacro{\drawRow}{\maxRowId*\elemH+\elemH}}{} + \pgfmathsetmacro{\drawCol}{\vecInCol*\vecInBytes*\bsize} + \draw [ultra thick] (TL) rectangle ++(\drawCol, -\drawRow); + \node [scale=\scale, above, rotate=90] at ($(TL)+(0, -0.5*\drawRow)$) {block\_#1 = \rowSize}; + \node [scale=\scale, above] at ($(TL)+(0.5*\drawCol, 0)$) {block\_#2 = \colSize$\times$\dtype}; + \ifthenelse{\diffRow = 0}{}{ + \node [scale=\scale, rotate=90] at ($(TL)+(0.5*\colSize*\bytesPerElem*\bsize, -\drawRow+.5*\gap*\elemH)$) {$\ldots$};} +} + + +\newcommand{\drawLDSDiagram}[3]{ + %% + %% Draw the diagram of LDS without any data + %% + %% Pre-defined variables + %% TL: top-left coordinates of first elements in LDSaccess + %% bsize: size of a byte + %% mfmaNonKDim + %% K + %% bytesPerElem + %% + %% #1: number of banks + %% #2: rows of tensor plotted + %% #3: columns of tensor plotted + + \pgfmathsetmacro{\banks}{#1} + \pgfmathsetmacro{\rows}{#2} + \pgfmathsetmacro{\cols}{#3} + \pgfmathsetmacro{\maxBankId}{\banks-1} + \pgfmathsetmacro{\tensorHeight}{\rows*\cols*\bytesPerElem/4/\banks*\elemH} + \def\gapT{4} + \def\gapB{2} + \pgfmathsetmacro{\LDSHeight}{\tensorHeight+\gapT*\elemH+\gapB*\elemH} + \coordinate (LDS TL) at ($(TL)+(0, \gapT*\elemH)$); + \foreach \bank in {0,...,\maxBankId}{ + \coordinate (bank TL) at ($(LDS TL)+(\bank*4*\bsize, 0)$); + \draw [ultra thick] (bank TL) rectangle ++(4*\bsize, -\LDSHeight); + \node [scale=.6*\bankLabelScale*\scale, below, align=center] at ($(bank TL)+(2*\bsize,0)$) {bank\\\bank}; + \node [scale=0.8*\bankLabelScale*\scale, rotate=90] at ($(TL)+(2*\bsize+\bank*4*\bsize, -\tensorHeight-0.5*\gapB*\elemH)$) {$\ldots$}; + } + \node [scale=\scale, above] at ($(TL)+(0.5*\banks*4*\bsize, 4*\elemH)$) {LDS \banks\ banks}; +} + +\newcommand{\drawHighlightedAccess}[3]{ + %% Highlight the vectors if \tid < \threshold + %% + %% Predefined variables + %% vec TL: top-left of the current vector + %% elemH, bsize, vecInBytes, bankLabelScale + %% + %% #1: tid + %% #2: threshold + %% #3: label in vector + + \pgfmathsetmacro{\tid}{#1} + \pgfmathsetmacro{\threshold}{#2} + \def\bWidth{0.02} + + \ifthenelse{\tid < \threshold}{ + \ifthenelse{\vecInBits=128}{\def\ratio{0.5}}{\def\ratio{1}} + %% Highlight the vector in LDS + \draw [thick, draw=white, fill=\vecColor] ($(vec TL)+(\bWidth, -\bWidth)$) rectangle ++(\vecInBytes*\bsize-2*\bWidth, -\elemH+2*\bWidth); + \ifthenelse{\vecInBits=128}{ + \node [scale=.6*\bankLabelScale*\scale, white] at ($(vec TL)+(\ratio*\vecInBytes*\bsize, -0.5*\elemH)$) {#3}; + }{ + \node [scale=.6*\bankLabelScale*\scale, white, left] at ($(vec TL)+(\ratio*\vecInBytes*\bsize, -0.5*\elemH)$) {#3}; + } + \node [scale=.5*\bankLabelScale*\scale, right] at ($(vec TL)+(0, -0.5*\elemH)$) {\textbf{t\tid}}; + }{} +} + + +\newcommand{\drawHighlightedAccessInTile}[3]{ + %% Highlight the vectors if \tid < \threshold + %% + %% Predefined variables + %% vec TL: top-left of the current vector + %% elemH, bsize, vecInBytes, bankLabelScale + %% + %% #1: tid + %% #2: threshold + %% #3: label in vector + + \pgfmathsetmacro{\tidRaw}{#1} + \pgfmathsetmacro{\threshold}{#2} + \def\bWidth{0.02} + + \ifthenelse{\tidRaw < \threshold}{ + \ifthenelse{\vecInBits=128}{\def\ratio{0.5}}{\def\ratio{1}} + %% Highlight the vector in global memory + \coordinate (tile vec TL) at ($(tile TL)+(\gp*\vecInBytes*\bsize, -\row*\elemH)$); + \draw [thick, draw=white, fill=\vecColor] ($(tile vec TL)+(\bWidth, -\bWidth)$) rectangle ++(\vecInBytes*\bsize-2*\bWidth, -\elemH+2*\bWidth); + \ifthenelse{\vecInBits=128}{ + \node [scale=.6*\bankLabelScale*\scale, white] at ($(tile vec TL)+(\ratio*\vecInBytes*\bsize, -0.5*\elemH)$) {#3}; + }{ + \node [scale=.6*\bankLabelScale*\scale, white, left] at ($(tile vec TL)+(\ratio*\vecInBytes*\bsize, -0.5*\elemH)$) {#3}; + } + \pgfmathsetmacro{\colorRatio}{40*mod(int(\row/\kPerGroup), 2)} + \node [scale=.5*\bankLabelScale*\scale, right, white!\colorRatio!black] at ($(tile vec TL)+(0, -0.5*\elemH)$) {\textbf{t\tidRaw}}; + }{} +} + +\newcommand{\drawCoalescedGRAccess}[3]{ + %% Highlight the vectors in original tile if \tid < \threshold + %% + %% Predefined variables + %% tile TL: top-left of the original tile + %% gp: vector group id along K dim + %% row: row index + %% elemH, bsize, vecInBytes, bankLabelScale + %% + %% #1: tid + %% #2: threshold + %% #3: label in vector + + \pgfmathsetmacro{\tid}{#1} + \pgfmathsetmacro{\threshold}{#2} + \def\bWidth{0.02} + + \ifthenelse{\tid < \threshold}{ + \ifthenelse{\vecInBits=128}{\def\ratio{0.5}}{\def\ratio{1}} + \coordinate (tile vec TL) at ($(tile TL)+(\gp*\vecInBytes*\bsize, -\row*\elemH)$); + \draw [thick, draw=white, fill=\vecColor] ($(tile vec TL)+(\bWidth, -\bWidth)$) rectangle ++(\vecInBytes*\bsize-2*\bWidth, -\elemH+2*\bWidth); + \ifthenelse{\vecInBits=128}{ + \node [scale=.6*\bankLabelScale*\scale, white] at ($(tile vec TL)+(\ratio*\vecInBytes*\bsize, -0.5*\elemH)$) {#3}; + }{ + \node [scale=.6*\bankLabelScale*\scale, white, left] at ($(tile vec TL)+(\ratio*\vecInBytes*\bsize, -0.5*\elemH)$) {#3}; + } + \node [scale=.5*\bankLabelScale*\scale, right] at ($(tile vec TL)+(0, -0.5*\elemH)$) {\textbf{t\tid}}; + }{} +} + +\newcommand{\drawLDSLayoutAndAccess}[6]{ + %% + %% Draw tensor layout in LDS with swizzling + %% + %% TL: pre defined top-left coordinates of the tensor in global memory + %% \elem: per defined variable + %% \Colors: a pre defined array of 16 colors + %% + %% The following three arguments are expected to be pre defined + %% vec: number of elements in a group + %% trans + %% useMfmaTransLD + %% maxRowId: defined in drawTensorLayoutGlobalMem + %% + %% #1: hasSwizzle, 0 means no swizzling and no padding, + %% 1 means optimal swizzling + %% 2 means padding + %% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write + %% #3: number of banks + %% #4: rowLabel + %% #5: colLabel + %% #6: colSize + + \pgfmathsetmacro{\hasSwizzle}{#1} + \pgfmathsetmacro{\accessMode}{#2} + \pgfmathsetmacro{\numVecCol}{\colSize/\vec} + \pgfmathsetmacro{\banks}{#3} + \pgfmathsetmacro{\colSize}{#6} + \pgfmathsetmacro{\rows}{\maxRowId+1} + + \ifthenelse{\trans=0}{ + \drawLDSDiagram{#3}{\rows}{\colSize} + }{ + \drawLDSDiagram{#3}{\rows}{\colSize} + } + + % number of elements per LDS row + \pgfmathsetmacro{\elemsPerLDSRow}{int(\banks*4/\bytesPerElem)} + % number of vectors per LDS row + \pgfmathsetmacro{\vecsPerLDSRow}{\elemsPerLDSRow/\vec} + % max vecId per tile row + \pgfmathsetmacro{\maxColVecId}{\colSize/\vec-1} + + %% Parameters for ds_read + % \vecInBytes: access width in bytes + % \vecInBits: access width in bits (ds_read_b64 or ds_read_b128) + % \numThreadsSameCycle: number of threads that will access LDS at the same cycle (8, 16, or 32) + \pgfmathsetmacro{\vecInBits}{int(\vecInBytes*8)} + \pgfmathsetmacro{\elemInBits}{int(\bytesPerElem*8)} + \pgfmathsetmacro{\numThreadsSameCycle}{int(\banks*4/\vecInBytes)} + \pgfmathsetmacro{\maxTid}{int(\numThreadsSameCycle-1)} + + %% Parameters for swizzling + %% perPhase = ceil(elemsPerLDSRow / K) + %% The number of the rows of the tensor that can share the same swizzling pattern + \pgfmathsetmacro{\perPhase}{ceil(\elemsPerLDSRow/\colSize)} + %% maxPhase: the total number of different swizzling patterns + \ifthenelse{\hasSwizzle=1}{ + %% When vec is small enough, we want 16/perPhase different swizzling patterns + %% When vec is large, we can only have 64 / \vec different swizzling pattern at most + \pgfmathsetmacro{\maxPhase}{min(min(\mfmaNonKDim,\numThreadsSameCycle)/\perPhase,\banks*4/\bytesPerElem/\swizzleVec)} + }{ + %% When swizzling is disabled + \pgfmathsetmacro{\maxPhase}{1} + } + + %% Draw the vectors according to the swizzling pattern + \foreach \gp in {0,...,\maxColVecId}{ + \pgfmathsetmacro{\gpCol}{int(mod(\gp, 16))} + \pgfmathsetmacro{\vecColor}{\Colors[\gpCol]} + \pgfmathsetmacro{\colStart}{int(\gp*\elemsPerVec)} + \pgfmathsetmacro{\colEnd}{int(\colStart+\elemsPerVec-1)} + \foreach \row in {0,...,\maxRowId}{ + %% Compute some info of the current vec + % global offset in unit of vec + \pgfmathsetmacro{\offVec}{\row*\colSize/\vec+\gp} + % which row of LDS + \pgfmathsetmacro{\LDSRow}{int(\offVec/\vecsPerLDSRow)} + % phase + \pgfmathsetmacro{\phaseRaw}{int(\row/\perPhase)} + \pgfmathsetmacro{\phase}{int(mod(\phaseRaw, \maxPhase))} + + % vector ID in the current LDS row + \pgfmathsetmacro{\vecIdInLDSRow}{int(mod(\offVec,\vecsPerLDSRow))} + % number of vec in swizzleVec + \pgfmathsetmacro{\vecsPerSwizzleVec}{int(\swizzleVec/\vec)} + \pgfmathsetmacro{\swizzleVecIdInLDSRow}{int(\vecIdInLDSRow/\vecsPerSwizzleVec)} + \pgfmathsetmacro{\vecIdInSwizzleVec}{int(mod(\vecIdInLDSRow, \vecsPerSwizzleVec))} + \pgfmathsetmacro{\newSwizzleVecId}{\bitwiseXor{\swizzleVecIdInLDSRow}{\phase}} + \pgfmathsetmacro{\LDSVec}{int(\newSwizzleVecId*\vecsPerSwizzleVec+\vecIdInSwizzleVec)} + + %% Padding case needs to recompute \LDSVec and \LDSRow + %% Add padAmount bytes of padding after every padInterval bytes of data + \ifthenelse{\hasSwizzle=2}{ + % global offset in unit of bytes + \pgfmathsetmacro{\offVecStartByte}{\row*\colSize/\vec*\vecInBytes+\gp*\vecInBytes} + \pgfmathsetmacro{\paddedVecStartByte}{int(\offVecStartByte/\padInterval)*\padAmount+\offVecStartByte} + \pgfmathsetmacro{\LDSRow}{int(\paddedVecStartByte/\banks/4)} + \pgfmathsetmacro{\LDSVec}{int(mod(int(\paddedVecStartByte), int(\banks*4))/\vecInBytes)} + }{} + + \coordinate (vec TL) at ($(TL)+(\LDSVec*\vecInBytes*\bsize, -\LDSRow*\elemH)$); + \draw [ultra thin, fill=\vecColor] (vec TL) rectangle ++(\vecInBytes*\bsize, -\elemH) + node [pos=.5, scale=.6*\bankLabelScale*\scale, white] {#4\row,#5\colStart:\colEnd}; + + %% draw phase of each LDS row + \pgfmathsetmacro{\lastVecId}{\vecsPerLDSRow-1} + \ifthenelse{\LDSVec=\lastVecId}{ + \draw [ultra thin] ($(vec TL)+(\vec*\bytesPerElem*\bsize, -.5*\bsize)$) -- ++(\elemH, 0) + node [scale=0.6*\bankLabelScale*\scale, right] {\phase}; + }{} + + %% For ds_read/write access patterns, we first decide the thread id that owns + %% the current vector. And then we decide if the current vector is accessed + %% at the first cycle according to thread id and access width + + %%%%%%%%%%%%%%%% + % Draw ds_read % + %%%%%%%%%%%%%%%% + \ifthenelse{\accessMode=1}{ + \ifthenelse{\trans=0}{ + %%%%%%%%%%%%%%%%%%% + %% K-contig case %% + %%%%%%%%%%%%%%%%%%% + %%% compute thread id for current vec + \pgfmathsetmacro{\tid}{int(\gp*\mfmaNonKDim+\row)} + \pgfmathsetmacro{\kPerGroup}{\mfmaNonKDim} + \drawHighlightedAccessInTile{\tid}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + %%% draw ds_read instruction name + \ifthenelse{\tid=0}{ + \ifthenelse{\vecInBits=128}{ + %%% Special thread access pattern for ds_read_b128 + \ifthenelse{\banks=32}{ + \node [scale=\scale, above right] at ($(TL)+(0, \gapT*\elemH)$) + {ds\_read\_b\vecInBits\ (t0$\sim$t\maxTid\ $\Leftrightarrow$\ t0$\sim$t3, t20$\sim$t23)}; + }{ + \node [scale=\scale, above right] at ($(TL)+(0, \gapT*\elemH)$) + {ds\_read\_b\vecInBits\ (t0$\sim$t\maxTid\ $\Leftrightarrow$\ t0$\sim$t3, t12$\sim$t15, t23$\sim$t27)}; + } + }{ + %%% Normal thread access pattern for ds_read_b64 + \node [scale=\scale, above right] at ($(TL)+(0, \gapT*\elemH)$) + {ds\_read\_b\vecInBits\ (t0$\sim$t\maxTid)}; + } + }{} + %%% highlight vector of the threads that will access LDS at the same cycle + \drawHighlightedAccess{\tid}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + }{ + %%%%%%%%%%%%%%%%%%%% + %% MN-contig case %% + %%%%%%%%%%%%%%%%%%%% + %%% This is further diverging according to whether mfma_transpose_load is used + \ifthenelse{\useMfmaTransLD=0}{ + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + %%% Not use mfma_transpose_load instructions %%% + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + %%%% This is the current triton implementation of MN-contig case. We have + %%%% - Threads can only load one element, i.e. ds_read_b16/b8 are used + %%%% - 32 threads are accessing LDS at the same cycle + %%%% - if mfmaNonKDim == 32, they will only access row 0 + %%%% - if mfmaNonKDim == 16, they will access row 0 and mfmaKWidth + %%%% - no swizzling used for lds layout + %%%% - vec is always 16 bytes + \pgfmathsetmacro{\numGp}{\mfmaNonKDim/\vec} + \pgfmathsetmacro{\maxElId}{\vec-1} + \pgfmathsetmacro{\hasSecondRow}{int((32-\mfmaNonKDim)/16)} + \pgfmathsetmacro{\secondRow}{\hasSecondRow*\mfmaKWidth} + \pgfmathsetmacro{\tStart}{\colStart+int(\row/\mfmaKWidth)*16} + \pgfmathsetmacro{\tEnd}{\tStart+\vec} + \ifthenelse{\gp < \numGp}{ + \ifthenelse{\row = 0 \OR \row = \secondRow}{ + \foreach \el in {0,...,\maxElId}{ + \pgfmathsetmacro{\tid}{int(\tStart+\el)} + \pgfmathsetmacro{\kNewEnd}{int(\row+\mfmaKWidth-1)} + %%%% Draw access in LDS + \draw ($(vec TL)+(\el*\bytesPerElem*\bsize, 0)$) rectangle ++(\bytesPerElem*\bsize, -\elemH) + node[scale=0.4*\bankLabelScale, pos=.5] {t\tid}; + %%%% Draw access in original tile + \coordinate (vertical vec TL) at ($(tile TL)+(\gp*\vecInBytes*\bsize+\el*\bytesPerElem*\bsize, -\row*\elemH)$); + \def\bWidth{0.002} + \draw [thick, draw=white, fill=\vecColor, opacity=0.6] ($(vertical vec TL)+(\bWidth, -\bWidth)$) + rectangle ++(\bytesPerElem*\bsize-2*\bWidth, -\elemH*\mfmaKWidth+2*\bWidth); + \node [scale=0.5*\bankLabelScale, left, rotate=90] at ($(vertical vec TL)+(0.5*\bytesPerElem*\bsize, -\bWidth)$) {t\tid}; + \node [scale=0.5*\bankLabelScale, right, rotate=90, white] at + ($(vertical vec TL)+(0.5*\bytesPerElem*\bsize, -\elemH*\mfmaKWidth+2*\bWidth)$) {\bfseries{n\el,k\row:\kNewEnd}}; + } % End \foreach + \draw [ultra thin, fill=\vecColor, opacity=0.6] (vec TL) rectangle ++(\vecInBytes*\bsize, -\elemH) + node [pos=.5, scale=.6*\bankLabelScale*\scale, white] {#4\row,#5\colStart:\colEnd}; + }{} % End \ifthenelse{\row = 0 \OR \row = \secondRow} + }{} % End \ifthenelse{\gp < \numGp} + \ifthenelse{\gp = 0 \AND \row = 0}{ + \node [scale=\scale, above right] at ($(TL)+(0, \gapT*\elemH)$) + {ds\_read\_b\elemInBits\ (t0$\sim$t31)};}{} + }{ + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + %%% Use mfma_transpose_load instructions %%% + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + %%%% Compute the tid for current vector: tid = f(\gp, \row) + %%%% threadsPerGroupM = 16 * bytesPerElem / 8 + %%%% kPerGroup = 16 / threadsPerGroupM + %%%% numGp = mfmaNonKDim / 16 + %%%% groupId = 32 / mfmaNonKDim + %%%% groupRow = groupId // 2 + %%%% groupCol = groupId % 2 + \pgfmathsetmacro{\threadsPerGroupM}{int(16*\bytesPerElem/8)} + \pgfmathsetmacro{\kPerGroup}{int(16/\threadsPerGroupM)} + \ifthenelse{\mfmaNonKDim=16}{ + % nonKDim = 16 + \pgfmathsetmacro{\colOff}{int(\gp/\threadsPerGroupM)*32+mod(\gp,\threadsPerGroupM)} + \pgfmathsetmacro{\kPerTwoGroups}{int(\kPerGroup*2)} + \pgfmathsetmacro{\rowOff}{int(\row/\kPerGroup/2)*16+mod(\row, \kPerTwoGroups)*\threadsPerGroupM} + \pgfmathsetmacro{\kGroupId}{int(\row/\kPerGroup)} + \ifthenelse{\kGroupId = 1}{\pgfmathsetmacro{\tid}{32}} + {\pgfmathsetmacro{\tid}{int(\rowOff+\colOff)}} + % draw highlighed vector in tile + \pgfmathsetmacro{\rowOffRaw}{int(\row/\kPerGroup/2)*16+mod(\row, \kPerGroup)*\threadsPerGroupM} + \pgfmathsetmacro{\tidRaw}{int(\rowOffRaw+\colOff)} + \drawHighlightedAccessInTile{\tidRaw}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + }{ + % nonKDim = 32 + \pgfmathsetmacro{\colOff}{int(\gp/\threadsPerGroupM)*16+mod(\gp,\threadsPerGroupM)} + \pgfmathsetmacro{\rowOff}{int(\row/\kPerGroup)*32+mod(\row, \kPerGroup)*\threadsPerGroupM} + \pgfmathsetmacro{\tid}{int(\rowOff+\colOff)} + % draw highlighed vector in tile + \pgfmathsetmacro{\rowOffRaw}{int(\row/\kPerGroup/2)*32+mod(\row, \kPerGroup)*\threadsPerGroupM} + \pgfmathsetmacro{\tidRaw}{int(\rowOffRaw+\colOff)} + \drawHighlightedAccessInTile{\tidRaw}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + } + %%% draw ds_read instruction name + \ifthenelse{\tid=0}{ + \node [scale=\scale, above right] at ($(TL)+(0, \gapT*\elemH)$) + {ds\_read\_b64\_tr\_b\elemInBits\ (t0$\sim$t\maxTid)};}{} + %%% highlight vector of the threads that will access LDS at the same cycle + \drawHighlightedAccess{\tid}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + } % End of MN-contig case + } % End of trans/non-trans case + }{} %% End draw ds_read + + %% Draw ds_write + \ifthenelse{\accessMode=2}{ + % compute thread id for current vec + % Here we assume the following global load pattern: + % - global/buffer_load_dwordx4, i.e. sizePerThread[1] = 128-bit + % - CTA coverage will always cover all elements along the K dim first, i.e. + % sizePerThread[1]*threadsPerWarp[1] == K or + % (sizePerThread[1]*threadsPerWarp[1] < K and threadsPerWarp[0] == 1) + \pgfmathsetmacro{\offBytes}{int(\row*\colSize*\bytesPerElem+\gp*\vecInBytes)} + \pgfmathsetmacro{\tidRaw}{int(\offBytes/16)} + \pgfmathsetmacro{\remTid}{int(mod(\offBytes,16))} + \drawCoalescedGRAccess{\tidRaw}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + \ifthenelse{\remTid>0}{\pgfmathsetmacro{\tid}{int(\tidRaw+32)}}{\pgfmathsetmacro{\tid}{\tidRaw}} + % draw ds_write instruction name + \ifthenelse{\tid=0}{ + \node [scale=\scale, above right] at ($(TL)+(0, \gapT*\elemH)$) + {ds\_write\_b\vecInBits\ (t0$\sim$t\maxTid)};}{} + % highlight vector of the threads that will access LDS at the same cycle + \drawHighlightedAccess{\tid}{\numThreadsSameCycle}{#4\row,#5\colStart:\colEnd} + }{} %% End draw ds_write + + } + } + \node [scale=0.6*\bankLabelScale*\scale, above right] at($(TL)+(\banks*4*\bsize, 0)$) {phase}; +} diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index 599f92c790e4..9973f05e8176 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -2,141 +2,247 @@ import sys import os import subprocess +from dataclasses import dataclass + + +def draw_dot_layout_cmd(M, N, K, dtype_a, dtype_b, mfma_inst_str, isMixed864, plot_scale, dotConfig): + mfmaNonKDim = dotConfig.mfmaNonKDim + warpsPerCTA = dotConfig.warpsPerCTA + trans = dotConfig.trans + kWidth = dotConfig.kWidth + kGroup = dotConfig.kGroup + scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1 + + outType = 'i32' if dtype_a == 'i8' else 'f32' + kWidth_a = kWidth_b = kWidth + kGroup_a = kGroup_b = kGroup + if isMixed864: + if isType8BitFloat(dtype_a): + kWidth_a = 16 + kGroup_a = 2 + kWidth_b = 32 + kGroup_b = 1 + else: + kWidth_a = 32 + kGroup_a = 1 + kWidth_b = 16 + kGroup_b = 2 + kWidth_left = kWidth_b if trans else kWidth_a + kGroup_left = kGroup_b if trans else kGroup_a + + elemSmall = 0.04 + elemLarge = 0.16 + elemPerThread = kWidth_a * kGroup_a + if elemPerThread == 16: + ratio = 0.8 + elif elemPerThread == 32: + ratio = 0.6 + else: + ratio = 1 + elemWidth = elemLarge * ratio + scaling = 1 if plot_scale else 0 -def draw_preamble_cmd(): - return '''\\documentclass[tikz, border=1mm, dvipsnames]{standalone} -\\usepackage{ifthen} -\\usepackage{tikz} -\\usetikzlibrary{arrows.meta,arrows} -\\usetikzlibrary{intersections} -\\usetikzlibrary{calc, quotes} -\\usetikzlibrary{patterns} -\\usepackage{xparse} - -\\ExplSyntaxOn -\\NewExpandableDocumentCommand{\\bitwiseXor}{mm} - { - \\recuenco_bitwise_xor:nn { #1 } { #2 } - } - -\\cs_new:Nn \\recuenco_bitwise_xor:nn - { - \\int_from_bin:e - { - \\__recuenco_bitwise_xor:ee { \\int_to_bin:n { #1 } } { \\int_to_bin:n { #2 } } - } - } -\\cs_generate_variant:Nn \\int_from_bin:n { e } - -\\cs_new:Nn \\__recuenco_bitwise_xor:nn - { - \\__recuenco_bitwise_xor_binary:ee - { - \\prg_replicate:nn - { - \\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #1 } - } - { 0 } - #1 - } - { - \\prg_replicate:nn - { - \\int_max:nn { \\tl_count:n { #1 } } { \\tl_count:n { #2 } } - \\tl_count:n { #2 } - } - { 0 } - #2 - } - } -\\cs_generate_variant:Nn \\__recuenco_bitwise_xor:nn { ee } - -\\cs_new:Nn \\__recuenco_bitwise_xor_binary:nn - { - \\__recuenco_bitwise_xor_binary:w #1;#2; - } -\\cs_generate_variant:Nn \\__recuenco_bitwise_xor_binary:nn { ee } - -\\cs_new:Npn \\__recuenco_bitwise_xor_binary:w #1#2;#3#4; - { - \\int_abs:n { #1-#3 } - \\tl_if_empty:nF { #2 } { \\__recuenco_bitwise_xor_binary:w #2;#4; } - } - -\\ExplSyntaxOff''' - - -def draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack): return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} - \\def\\elem{{0.04}} + \\def\\elem{{{elemSmall}}} + \\def\\elemW{{\\elem}} + \\def\\kWidthA{{{kWidth_a}}} + \\def\\kWidthB{{{kWidth_b}}} + \\def\\kGroupA{{{kGroup_a}}} + \\def\\kGroupB{{{kGroup_b}}} \\coordinate (C TL) at (0,0); - \\def\\opColorAL{{magenta}} - \\def\\opColorAR{{cyan}} - \\def\\opColorBL{{Maroon}} - \\def\\opColorBR{{BlueGreen}} - \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}{{{kpack}}} + \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}} \\coordinate (C TL) at ($(C TL)+({N}*\elem+32*\elem, 0)$); \\def\\mfmaTrans{{{trans}}} %% Draw zoomed in view of mfma - \\def\\elem{{.16}} + \\def\\scaleLabel{{{scaleLabel}}} + \\pgfmathsetmacro{{\\oldElem}}{{\\elem}} + \\def\\elem{{{elemLarge}}} + \\def\\elemW{{{elemWidth}}} \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} - \\coordinate (C TL) at ($(C TL)+(.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kpack}*\\elem, 0)$); - \\drawMFMAInstr{{{mfmaNonKDim}}}{{{kpack}}}{{\\mfmaTrans}} + \\coordinate (C TL) at ($(C TL)+({scaling}*0.3*\\gap+{scaling}*\\groups*4*\elemW+.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth_left}*{kGroup_left}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); + \\coordinate (mfma instr) at ($(C TL)+(-.5*\\gap-0.6*\\nonTrans*\\gap-0.4*\\mfmaTrans*\\gap, 1.5*\\gap+.5*\\mfmaTrans*\\gap)$); + \\node [scale=\scaleLabel, above left, align=left, draw=black, fill=white] at (mfma instr) {{{mfma_inst_str}}}; + \\drawMFMAInstr{{{mfmaNonKDim}}}{{\\mfmaTrans}}{{{dtype_a}}}{{{dtype_b}}}{{{outType}}}{{{scaling}}} \\end{{tikzpicture}} \\end{{document}}''' -def draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order): +def draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, blockedConfig): return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} \\def\\elem{{0.06}} \\coordinate (TL) at (0,0); - \\drawBlockedTensor{{{M}}}{{{K}}}{{{sizePerThread[0]}}}{{{sizePerThread[1]}}}{{{threadsPerWarp[0]}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{order[0]}}} + \\def\\dimColName{{{dim0Name}}} + \\def\\dimRowName{{{dim1Name}}} + \\drawBlockedTensor{{{dim0}}}{{{dim1}}}{{{blockedConfig.sizePerThread[0]}}}{{{blockedConfig.sizePerThread[1]}}}{{{blockedConfig.threadsPerWarp[0]}}}{{{blockedConfig.warpsPerCTA[0]}}}{{{blockedConfig.warpsPerCTA[1]}}}{{{blockedConfig.order[0]}}} \\end{{tikzpicture}} \\end{{document}}''' -def draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp): - if ldsLayout == 'swizzle': +def typeToBytes(dtype): + if dtype == 'bf16' or dtype == 'fp16': + return 2 + if dtype == 'bf8' or dtype == 'fp8' or dtype == 'i8': + return 1 + if dtype == 'f4': + return 0.5 + if dtype == 'fp6' or dtype == 'bf6': + return 0.75 + + +def maxKDimInBytes(dtype, mfmaNonKDim, kWidth): + groups = 64 / mfmaNonKDim + if (dtype == 'bf8' or dtype == 'fp8') and kWidth == 16: + groups *= 2 + return groups * kWidth * typeToBytes(dtype) + + +def calcPerPhase(banks, dtype, K): + bytesPerBank = 4 + return max(banks * bytesPerBank / (K * typeToBytes(dtype)), 1) + + +def draw_lds_access_cmd(dim0, dim1, dtype, mfmaNonKDim, ldsConfig): + if ldsConfig.ldsLayout == 'swizzle': hasSwizzle = 1 - elif ldsLayout == 'padding': + elif ldsConfig.ldsLayout == 'padding': hasSwizzle = 2 else: hasSwizzle = 0 - if ldsAccess == 'read': + if ldsConfig.ldsAccess == 'read': accessMode = 1 - elif ldsAccess == 'write': + elif ldsConfig.ldsAccess == 'write': accessMode = 2 else: accessMode = 0 + trans = 1 if ldsConfig.mnContig else 0 + useMfmaTransLD = 1 if ldsConfig.mfmaTransLD else 0 + banks = ldsConfig.banks + padInterval = ldsConfig.padInterval + padAmount = ldsConfig.padAmount + + if trans: + dim0Name = 'k' + dim1Name = 'n' + else: + dim0Name = 'm' + dim1Name = 'k' + dim0Size = dim0 + dim1Size = dim1 + ''' + Definitions of different vector size + + swizzleVec: Number of elements that are grouped together when swizzling is enabled. + Note that this is all about LDS layout without considering LDS read + or write patterns. And this is un-related to K- or MN-contig settings. + accessVec: When reading from or writing to LDS, accessVec is the number of contiguous + elements each thread read or write as a vector. + This is un-related to K- or MN-contig settings. + Note that accessVec <= swizzleVec. accessVec for read and write are not + required to be the same. + kWidth: Number of contiguous elements along the k dim that each thread holds + right before invoking mfma instruction(s). kWidth can be larger than + the required number of contiguous elements along the k dim for a single + mfma instruction. + Note that kWidth is un-related to swizzleVec or accessVec. kWidth should + be set according to datatype and mfmaNonKDim. + + We need to handle the following cases of LDS layout and access patterns: + + case 1: K-contig in both HBM and LDS (default) + In most cases, we can set swizzleVec = accessVec = kWidth according to the dtype. + However, for mfmaNonKDim = 16, banks = 64, and kWidth = 8B, 32 threads will + access LDS at the same cycle. In this case, we need to double swizzleVec = 16B. + + Swizzling: works as suggested above. + Padding: will have bank conflicts for ds_read_b128 due to non-linear thread ids + are accessing LDS at the same cycle + + case 2: MN-contig in both HBM and LDS without using mfma_transpose_ld instructions (-mnContig) + In this case, ds_read can only read one element at a time (i.e. accessVec is always 1). + Therefore, we can always choose swizzleVec = 16B. kWidth does not matter. accessVec is always 1. + Note that in this case, only swizzling is supported and can help resolve bank conflicts. + But the performance bottleneck is scalar ds_read rather than bank conflicts. + + case 3: MN-contig in both HBM and LDS using mfma_transpose_ld instructions (-mnContig -mfma_trans_load) + In this case, ds_read is done in a special pattern so that the ds_read_b64_tr_bx instructions + can be used. Each thread will read 8B data, which corresponds to kWidth = 8B/elemInBytes. + The swizzleVec needs to be set to mfmaNonKDim. + + Swizzling: currently, it leads to bank conflicts for nonKDim = 16 and + if the read pattern follows the spec. + For nonKDim = 32, swizzling does not have bank conflicts. + Padding: It can help resolve bank conflicts for both nonKDim = 16 and 32. + However, it leads to a lot of waste of LDS space. + + case 4: MN-contig in HBM and k-Contig in LDS (-inThreadTrans) + Not supported yet + ''' + + elemTypeInBytes = typeToBytes(dtype) + + bankLabelScale = 0.8 + bsize = 0.15 + + if trans == 0: + # case 1 + swizzleVec = ldsConfig.swizzleVec + accessVec = ldsConfig.accessVec + vec = ldsConfig.kWidth + elif useMfmaTransLD == 0: + # case 2 + swizzleVec = 16 / elemTypeInBytes + accessVec = 1 + vec = swizzleVec + else: + # case 3 + vec = 8 / elemTypeInBytes + swizzleVec = mfmaNonKDim + accessVec = ldsConfig.accessVec + + kWidth = ldsConfig.kWidth + vecInBytes = vec * elemTypeInBytes + return f'''\\begin{{document}} \\begin{{tikzpicture}} \\def\\scale{{1}} - \\def\\M{{{M}}} - \\def\\K{{{K}}} - \\def\\vec{{{kpack}}} + \\def\\M{{{dim0}}} + \\def\\K{{{dim1}}} + \\def\\mfmaKWidth{{{kWidth}}} + \\def\\vec{{{vec}}} + \\def\\swizzleVec{{{swizzleVec}}} + \\def\\accessVec{{{accessVec}}} + \\def\\vecInBytes{{{vecInBytes}}} + \\def\\bytesPerElem{{{elemTypeInBytes}}} \\def\\hasSwizzle{{{hasSwizzle}}} \\def\\accessMode{{{accessMode}}} - - \\def\\sizePerThreadK{{{sizePerThread[1]}}} - \\def\\sizePerThreadM{{{sizePerThread[0]}}} - \\def\\threadsPerWarpK{{{threadsPerWarp[1]}}} - + \\def\\mfmaNonKDim{{{mfmaNonKDim}}} + \\def\\dtype{{{dtype}}} + \\def\\trans{{{trans}}} + \\def\\useMfmaTransLD{{{useMfmaTransLD}}} + \\def\\padInterval{{{padInterval}}} + \\def\\padAmount{{{padAmount}}} + + \\def\\elemH{{0.18}} \\def\\elem{{0.18}} - \\coordinate (TL) at (0,0); - \\drawTensorLayoutGlobalMem - \\coordinate (TL) at ($(TL)+(0, -24*\\elem-10*\\elem)$); - \\drawLDSLayoutTritonSwizzling{{\\hasSwizzle}}{{\\accessMode}} + \\def\\bsize{{{bsize}}} + \\def\\bankLabelScale{{{bankLabelScale}}} + \\coordinate (tile TL) at (0,0); + \\coordinate (TL) at (tile TL); + \\drawTensorLayoutGlobalMem{{{dim0Name}}}{{{dim1Name}}}{{{dim0Size}}}{{{dim1Size}}} + \\coordinate (TL) at ($(TL)+(0, -\drawRow-8*\\elemH)$); + \\drawLDSLayoutAndAccess{{\\hasSwizzle}}{{\\accessMode}}{{{banks}}}{{{dim0Name}}}{{{dim1Name}}}{{{dim1Size}}} \\end{{tikzpicture}} \\end{{document}}''' @@ -153,6 +259,130 @@ def draw_wmma_instr_cmd(waveSize): \\end{{document}}''' +matrixFormatTable = {'fp8': 0, 'bf8': 1, 'fp6': 2, 'bf6': 3, 'f4': 4} + + +def matrixFormat(dtype_a, dtype_b): + ''' + return CBSZ and BLGP according to data types + b000: E4M3(FP8) + b001: E5M2(BF8) + b010: E2M3(FP6) + b011: E3M2(BF6) + b100: E2M1(FP4) + ''' + return matrixFormatTable[dtype_a], matrixFormatTable[dtype_b] + + +def isType4Or6Bit(dtype): + return dtype == 'fp6' or dtype == 'bf6' or dtype == 'f4' + + +def isType8BitFloat(dtype): + return dtype == 'fp8' or dtype == 'bf8' + + +def isType16Bit(dtype): + return dtype == 'bf16' or dtype == 'fp16' + + +def isMixedPrecType(dtype): + return isType8BitFloat(dtype) or isType4Or6Bit(dtype) + + +def isMixedPrecBtwF8AndF4OrF6(dtype_a, dtype_b): + return (isType8BitFloat(dtype_a) and isType4Or6Bit(dtype_b)) or (isType8BitFloat(dtype_b) + and isType4Or6Bit(dtype_a)) + + +def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans, scale): + ## Check input types + ## Mixed precision is only allowed within f8, f6 and f4 + assert (isMixedPrecType(dtype_a) and isMixedPrecType(dtype_b)) or ( + dtype_a == dtype_b), f"Cannot do mixed precision mfma with {dtype_a} and {dtype_b}" + ''' + Check mfma size according to data types + * refers to newly added instructions on MI350 + Both dtyes are f4 or fp6 or bf6 + *mfma_f32_16x16x128_f8f6f4: kWidth = 32, kGroup = 1 + *mfma_f32_32x32x64_f8f6f4: kWidth = 32, kGroup = 1 + One dtype is fp8 or bf8 + When the other operand is f4, fp6, or bf6 + *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2 + *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2 + When the other operand is fp8 or bf8 + *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2 + mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack=2 + mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1 + *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2 + mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack=2 + mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1 + Both dtypes are bf16 or bf16 + *mfma_f32_16x16x32_f16/bf16: kWidth = 8, kGroup = 1 + mfma_f32_16x16x16_f16/bf16: kWidth = 4, kGroup = 1 + *mfma_f32_32x32x16_f16/bf16: kWidth = 8, kGroup = 1 + mfma_f32_32x32x8_f16/bf16: kWidth = 4, kGroup = 1 + Both types are i8 + *mfma_i32_16x16x64_i8: kWidth = 16, kGroup = 1 + mfma_i32_16x16x32_i8: kWidth = 8, kGroup = 1 + *mfma_i32_32x32x32_i8: kWidth = 16, kGroup = 1 + mfma_i32_32x32x16_i8: kWidth = 8, kGroup = 1 + + Return mfma instruction name and kpack + ''' + kDim = 64 / mfmaNonKDim * kWidth * kGroup + ## Both dtyes are f4 or fp6 or bf6 + if isType4Or6Bit(dtype_a) and isType4Or6Bit(dtype_b): + assert kWidth == 32 and kGroup == 1, f"Only kWidth=32 and kGroup=1 is supported for {dtype_a} x {dtype_b}" + kpack = 1 + CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] + BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + scale_str = 'scale_' if scale else '' + return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4", kpack, CBSZ, BLGP, scale + + ## Both dtypes are fp8 or bf8 + if isType8BitFloat(dtype_a) and isType8BitFloat(dtype_b): + assert (kWidth == 8 and kGroup == 1) or ( + kWidth == 16), f"Not a valid mfma instruction for {dtype_a} x {dtype_b} with {kWidth=} and {kGroup=}" + kpack = 2 if (kWidth == 16 and kGroup == 1) else 1 + if kGroup == 2: + suffix = "f8f6f4" + CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] + BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + plot_scale = scale + scale_str = 'scale_' if scale else '' + else: + suffix = f"{dtype_b}_{dtype_a}" if trans else f"{dtype_a}_{dtype_b}" + CBSZ = -1 + BLGP = -1 + plot_scale = False + scale_str = '' + kDim = kDim / 2 if kpack == 2 else kDim + return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{suffix}", kpack, CBSZ, BLGP, plot_scale + + ## Both types are fp16 or bf16 + if isType16Bit(dtype_a) and isType16Bit(dtype_b): + assert ( + kWidth == 8 or kWidth == 4 + ) and kGroup == 1, f"Not a valid mfma instruction for {dtype_a} x {dtype_b} with {kWidth=} and {kGroup=}" + kpack = 1 + CBSZ = -1 + BLGP = -1 + return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP, False + + ## Both types are i8 + if dtype_a == 'i8' and dtype_b == 'i8': + assert ( + kWidth == 16 or kWidth == 8 + ) and kGroup == 1, f"Not a valid mfma instruction for {dtype_a} x {dtype_b} with {kWidth=} and {kGroup=}" + kpack = 1 + CBSZ = -1 + BLGP = -1 + return f"mfma_i32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}", kpack, CBSZ, BLGP, False + + assert False, "Mixed precision between fp8/bf8 and fp6/bf6/f4 not supported in this mode" + + def run_bash_command(commandstring): proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) return proc.stdout.splitlines() @@ -164,26 +394,50 @@ def parse_args(): allow_abbrev=False, ) ## tensor shapes - parser.add_argument("-shape", type=int, nargs=3, default=(32, 128, 64), help='Tensor shape in the form of M,N,K') + parser.add_argument("-tensorShape", type=int, nargs=2, default=(128, 64), + help='2D tensor shape in the form of dim0,dim1') + parser.add_argument("-dotShape", type=int, nargs=3, default=(32, 128, 64), help='Dot op shape in the form of M,N,K') parser.add_argument("-plot", type=str, default="blocked", choices=['blocked', 'dot', 'wmma', 'lds'], help='choose plot mode') - parser.add_argument("-nonKDim", type=int, default=32, choices=[16, 32], help='mfma instruction dim') + parser.add_argument("-dim0", type=str, default="M", help='tensor dim0 name') + parser.add_argument("-dim1", type=str, default="K", help='tensor dim1 name') ## blocked layout parameters parser.add_argument("-sizePerThread", type=int, nargs=2, default=(1, 4)) parser.add_argument("-threadsPerWarp", type=int, nargs=2, default=(16, 4)) parser.add_argument("-warpsPerCTA", type=int, nargs=2, default=(1, 4)) parser.add_argument("-order", type=int, nargs=2, default=(1, 0)) + ## dot layout parameters + parser.add_argument("-nonKDim", type=int, default=16, choices=[16, 32], help='mfma instruction dim') + parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16, 32], + help='number of contiguous elements per thread') + parser.add_argument("-kGroup", type=int, default=1, choices=[1, 2], + help='total number of elements / kWidth per mfma instruction') + parser.add_argument("-dtype_a", type=str, default='fp16', + choices=['fp16', 'bf16', 'fp8', 'bf8', 'fp6', 'bf6', 'f4', + 'i8'], help='element type of operand A') + parser.add_argument("-dtype_b", type=str, default='fp16', + choices=['fp16', 'bf16', 'fp8', 'bf8', 'fp6', 'bf6', 'f4', + 'i8'], help='element type of operand B') + parser.add_argument("-mfmaTrans", action='store_true', default=False, help='If set, then use mfma.trans layout') + parser.add_argument("-scale", action='store_true', default=False, + help='If set, plot the scale tensor for mfma_f8f6f4 instructions') ## LDS access parameters - parser.add_argument("-kWidth", type=int, default=4, choices=[4, 8, 16], help='number of elements per thread') + parser.add_argument("-banks", type=int, default=32, choices=[32, 64], help='choose the number of banks in LDS') parser.add_argument("-lds_layout", type=str, default="none", choices=['swizzle', 'padding', 'none'], help='choose the LDS data layout') parser.add_argument("-lds_access", type=str, default="none", choices=['read', 'write', 'none'], help='choose LDS access mode') + parser.add_argument("-mnContig", action='store_true', default=False, + help='If set, the tensor is K x N and n-contig') + parser.add_argument("-mfma_trans_load", action='store_true', default=False, + help='If set, use MFMA transpose load instructions') + parser.add_argument("-swizzleVec", type=int, default=4, choices=[4, 8, 16, 32], + help='number of contiguous elements in a vector to swizzle') + parser.add_argument("-padInterval", type=int, default=1, help='Add padding for every padInterval bytes') + parser.add_argument("-padAmount", type=int, default=0, help='Pad padAmount bytes for every padInterval bytes') ## wmma instruction layout parameter parser.add_argument("-wave_size", type=int, default=32, choices=[32, 64], help='choose the wmma instruction mode') - parser.add_argument("-o", type=str, default="myplot", help='output pdf file name (without surfix)') - parser.add_argument("-mfmaTrans", action='store_true', default=False, help='If set, then use mfma.trans layout') parser.add_argument("-keep", action='store_true', default=False, help='If set, keep the generated .tex file') args = parser.parse_args() @@ -191,22 +445,88 @@ def parse_args(): return args +@dataclass +class BlockedConfig: + sizePerThread: tuple + threadsPerWarp: tuple + warpsPerCTA: tuple + order: tuple + + +@dataclass +class DotConfig: + mfmaNonKDim: int + kWidth: int + kGroup: int + trans: int + warpsPerCTA: tuple + + +@dataclass +class LDSConfig: + banks: int + ldsLayout: str + ldsAccess: str + mnContig: bool + mfmaTransLD: bool + swizzleVec: int + accessVec: int + kWidth: int + padInterval: int + padAmount: int + + def __init__(self, banks, ldsLayout, ldsAccess, mnContig, mfmaTransLD, swizzleVec, accessVec, kWidth, padInterval, + padAmount): + self.banks = banks + self.ldsLayout = ldsLayout + self.ldsAccess = ldsAccess + self.mnContig = mnContig + self.mfmaTransLD = mfmaTransLD + self.swizzleVec = swizzleVec + self.accessVec = accessVec + self.kWidth = kWidth + self.padInterval = padInterval + self.padAmount = padAmount + if self.swizzleVec < self.kWidth: + self.swizzleVec = self.kWidth + + def print(self): + print( + f"{self.banks=} {self.ldsLayout=} {self.ldsAccess=} {self.mnContig=} {self.mfmaTransLD=} {self.swizzleVec=} {self.accessVec=} {self.kWidth=} {self.padInterval} {self.padAmount}" + ) + + def main(): args = parse_args() - shape = args.shape - M = shape[0] - N = shape[1] - K = shape[2] + dotShape = args.dotShape + M = dotShape[0] + N = dotShape[1] + K = dotShape[2] + tShape = args.tensorShape + dim0 = tShape[0] + dim1 = tShape[1] + dim0Name = args.dim0 + dim1Name = args.dim1 plot_mode = args.plot mfmaNonKDim = args.nonKDim - kpack = args.kWidth + kWidth = args.kWidth + kGroup = args.kGroup + dtype_a = args.dtype_a + dtype_b = args.dtype_b trans = 1 if args.mfmaTrans else 0 + scale = 1 if args.scale else 0 ofilename = args.o keepSrc = args.keep ldsLayout = args.lds_layout ldsAccess = args.lds_access + banks = args.banks + mnContig = args.mnContig + mfmaTransLD = args.mfma_trans_load + swizzleVec = args.swizzleVec + padInterval = args.padInterval + padAmount = args.padAmount waveSize = args.wave_size @@ -215,64 +535,89 @@ def main(): warpsPerCTA = args.warpsPerCTA order = args.order + blockedConfig = BlockedConfig(sizePerThread, threadsPerWarp, warpsPerCTA, order) + dotConfig = DotConfig(mfmaNonKDim, kWidth, kGroup, trans, warpsPerCTA) + ldsConfig = LDSConfig(banks, ldsLayout, ldsAccess, mnContig, mfmaTransLD, swizzleVec, kWidth, kWidth, padInterval, + padAmount) + CTAShape = [] if plot_mode == 'blocked': - print(f"Plotting tensor M={M},K={K} with blocked layout:") - print(f"sizePerThread={sizePerThread}", end=" ") - print(f"threadsPerWarp={threadsPerWarp}", end=" ") - print(f"warpsPerCTA={warpsPerCTA}", end=" ") - print(f"order={order}", end=" ") + print(f"Plotting tensor {dim0Name}={dim0},{dim1Name}={dim1} with blocked layout:") + print(f"{sizePerThread=}", end=" ") + print(f"{threadsPerWarp=}", end=" ") + print(f"{warpsPerCTA=}", end=" ") + print(f"{order=}", end=" ") CTAShape.append(sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0]) CTAShape.append(sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1]) + print(f"CTAShape={CTAShape}") + assert dim0 != 0 and CTAShape[0] <= dim0 and dim0 % CTAShape[0] == 0, "bad tensor dimension " + dim0Name + assert dim1 != 0 and CTAShape[1] <= dim1 and dim1 % CTAShape[1] == 0, "bad tensor dimension " + dim1Name if plot_mode == 'dot': - mfma_inst_str = "mfma_32x32" if mfmaNonKDim == 32 else "mfma_16x16" - mfma_trans_str = ".trans" if trans else "" - print(f"Plotting dot operation with shapes M={M},N={N},K={K}") - print("MFMA: " + mfma_inst_str + mfma_trans_str + f" kWidth = {kpack}", end=" ") - print(f"warpsPerCTA={warpsPerCTA}", end=" ") CTAShape.append(mfmaNonKDim * warpsPerCTA[0]) CTAShape.append(mfmaNonKDim * warpsPerCTA[1]) - - if plot_mode == 'blocked' or plot_mode == 'dot': - print(f"CTAShape={CTAShape}") + print(f"Plotting dot operation with shapes=M{M}-N{N}-K{K},{kWidth=},{kGroup=},{warpsPerCTA=},{CTAShape=}") assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" - - if plot_mode == 'blocked': - assert K != 0 and CTAShape[1] <= K and K % CTAShape[1] == 0, "bad tensor dimension K" - - if plot_mode == 'dot': assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" - assert K != 0 and K % (2 * kpack) == 0, "bad tensor dimension K" + if isMixedPrecBtwF8AndF4OrF6(dtype_a, dtype_b): + ## In the case of mixed precision between 8-bit and 4 or 6-bit, + ## ignore kWidth and kGroup since inA and inB have different kWidth and kGroup values + kDim = 128 + assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + kpack = 1 + CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] + BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] + scale_str = 'scale_' if scale else '' + mfma_inst_str = f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4" + isMixed864 = True + plot_scale = scale + else: + kDim = kWidth * kGroup * 64 / mfmaNonKDim + assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + mfma_inst_str, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, + dtype_b, trans, scale) + isMixed864 = False + flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}" + scale_info = " (scale is not supported hence ignored)" if (scale and not plot_scale) else '' + print(f"MFMA: {mfma_inst_str} x {kpack}{flag}{scale_info}", end="") + mfma_inst_str = mfma_inst_str.replace("_", "\\_") + mfma_inst_str = mfma_inst_str + flag + if kpack == 2: + mfma_inst_str = mfma_inst_str + " $\\times$ 2" + if ((dtype_a == 'fp16' or dtype_a == 'bf16') and kWidth == 8) or (dtype_a == 'i8' and kWidth == 16): + kDim = 64 / mfmaNonKDim * kWidth / 2 + outType = "i32" if dtype_a == 'i8' else "f32" + old_instr = f"mfma_{outType}_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtype_a}" + print(f" or {old_instr} x 2") + old_instr = old_instr.replace("_", "\\_") + mfma_inst_str = mfma_inst_str + " or\\\\" + old_instr + "$\\times$2" + else: + print("") if plot_mode == 'lds': - print(f"Plotting LDS access for tensor M={M},K={K} with vec={kpack}") - if ldsAccess == 'write': - print(f"sizePerThread={sizePerThread}, threadsPerWarp={threadsPerWarp}") + print(f"Plotting LDS access for tensor {dim0}x{dim1} with vec={kWidth}") with open("myplot.tex", 'w') as f_plot: - with open("tikzplot.tex") as file: - tikz_code = file.read() - - preamble_str = draw_preamble_cmd() - - draw_blockedLayout_str = draw_blocked_layout_cmd(M, K, sizePerThread, threadsPerWarp, warpsPerCTA, order) - - draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, mfmaNonKDim, warpsPerCTA, trans, kpack) - - draw_lds_str = draw_lds_access_cmd(M, K, kpack, ldsLayout, ldsAccess, sizePerThread, threadsPerWarp) - - draw_wmma_str = draw_wmma_instr_cmd(waveSize) + with open("preamble.tex") as file: + preamble = file.read() - f_plot.write(preamble_str + "\n") - f_plot.write(tikz_code) + f_plot.write(preamble) if plot_mode == 'blocked': + draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, blockedConfig) + f_plot.write("\input{blockedLayout}\n") f_plot.write(draw_blockedLayout_str) elif plot_mode == 'dot': + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, dtype_a, dtype_b, mfma_inst_str, isMixed864, plot_scale, + dotConfig) + f_plot.write("\input{dotLayout}\n") f_plot.write(draw_dotLayout_str) elif plot_mode == 'lds': + draw_lds_str = draw_lds_access_cmd(dim0, dim1, dtype_a, mfmaNonKDim, ldsConfig) + f_plot.write("\input{ldsLayout}\n") f_plot.write(draw_lds_str) elif plot_mode == 'wmma': + draw_wmma_str = draw_wmma_instr_cmd(waveSize) + f_plot.write("\input{wmmaLayout}\n") f_plot.write(draw_wmma_str) run_bash_command(f"pdflatex -jobname {ofilename} myplot.tex") diff --git a/python/perf-kernels/tools/plot-layout/preamble.tex b/python/perf-kernels/tools/plot-layout/preamble.tex new file mode 100644 index 000000000000..b016b1123391 --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/preamble.tex @@ -0,0 +1,31 @@ +\documentclass[tikz, border=1mm, dvipsnames, x11names]{standalone} +\usepackage{ifthen} +\usepackage{tikz} +\usetikzlibrary{arrows.meta,arrows} +\usetikzlibrary{intersections} +\usetikzlibrary{calc, quotes} +\usetikzlibrary{patterns} +\usepackage{xparse} +\usepackage{libertinus} +\definecolor{RoyalPurple}{HTML}{CC79A7} +\definecolor{CrimsonRed}{HTML}{D41159} +\definecolor{Gold}{HTML}{F1C40F} +\definecolor{DeepViolet}{HTML}{7E3F8F} +\newcommand{\Colors}{{ + "SkyBlue", + "orange", + "ForestGreen", + "RoyalPurple", + "CrimsonRed", + "teal", + "Gold", + "DeepViolet", + "cyan", + "purple", + "gray", + "Green", + "BlueGreen", + "violet", + "olive", + "darkgray", + }} diff --git a/python/perf-kernels/tools/plot-layout/tikzplot.tex b/python/perf-kernels/tools/plot-layout/tikzplot.tex deleted file mode 100644 index d8441b042f02..000000000000 --- a/python/perf-kernels/tools/plot-layout/tikzplot.tex +++ /dev/null @@ -1,880 +0,0 @@ -\newcommand{\drawBlockedWave}[5]{ - %% - %% Draw a wave coverage with blocked layout - %% - %% Wave TL: pre defined top-left coordinate of the wave - %% \elem: pre defined variable - %% - %% #1: sizePerThread[0] --> sizePerThreadM - %% #2: sizePerThread[1] --> sizePerThreadN - %% #3: threadsPerWarp[0] --> threadsPerWarpM - %% #4: threadsPerWarp[1] --> threadsPerWarpN - %% #5: fastest changing dim --> order - - \pgfmathsetmacro{\sizePerThreadM}{#1} - \pgfmathsetmacro{\sizePerThreadN}{#2} - \pgfmathsetmacro{\threadsPerWarpM}{#3} - \pgfmathsetmacro{\threadsPerWarpN}{#4} - \pgfmathsetmacro{\order}{#5} - - \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} - \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} - - \foreach \tid in {0,...,63}{ - \pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)} - \pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)} - \coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$); - \pgfmathsetmacro{\ratio}{\tidM*10} - - \ifthenelse{\tid = 0}{ - \draw [line width = 0.01mm, fill=red] (Thread TL) - rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); - }{ - \draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL) - rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); - } - } - \draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem); -} - -\newcommand{\drawBlockedCTA}[7]{ - %% - %% Draw a CTA coverage with blocked layout - %% - %% CTA TL: pre defined top-left coordinate of the CTA - %% \elem: pre defined variable - %% - %% #1: sizePerThread[0] --> sizePerThreadM - %% #2: sizePerThread[1] --> sizePerThreadN - %% #3: threadsPerWarp[0] --> threadsPerWarpM - %% #4: threadsPerWarp[1] --> threadsPerWarpN - %% #5: warpsPerCTA[0] --> warpsPerCTAM - %% #6: warpsPerCTA[1] --> warpsPerCTAN - %% #7: fastest changing dim --> order - - \pgfmathsetmacro{\sizePerThreadM}{#1} - \pgfmathsetmacro{\sizePerThreadN}{#2} - \pgfmathsetmacro{\threadsPerWarpM}{#3} - \pgfmathsetmacro{\threadsPerWarpN}{#4} - \pgfmathsetmacro{\warpsPerCTAM}{#5} - \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\order}{#7} - - \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} - \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} - \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM} - \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN} - - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1} - - \coordinate (Wave TL) at (CTA TL); - \drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order} - \foreach \waveId in {0,...,\maxWaveId}{ - \ifthenelse{\order=1} - { - \pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)} - \pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)} - \pgfmathsetmacro{\rot}{0} - }{ - \pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)} - \pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)} - \pgfmathsetmacro{\rot}{90} - } - - \coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$); - \draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem) - node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId}; - } - - \draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem); -} - -\newcommand{\drawBlockedTensor}[8]{ - %% - %% Draw a tensor with blocked layout of the following parameters - %% sizePerThread[2] - %% threadsPerWarp[2] - %% warpsPerCTA[2] - %% order[2] - %% - %% TL: pre defined top-left coordinate of the tensor - %% \elem: pre defined variable - %% - %% #1: tensorShape[0] --> M - %% #2: tensorShape[1] --> N - %% #3: sizePerThread[0] --> sizePerThreadM - %% #4: sizePerThread[1] --> sizePerThreadN - %% #5: threadsPerWarp[0] --> threadsPerWarpM - %% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0] - %% #6: warpsPerCTA[0] --> warpsPerCTAM - %% #7: warpsPerCTA[1] --> warpsPerCTAN - %% #8: fastest changing dim --> order - - \pgfmathsetmacro{\M}{#1} - \pgfmathsetmacro{\N}{#2} - \pgfmathsetmacro{\sizePerThreadM}{#3} - \pgfmathsetmacro{\sizePerThreadN}{#4} - \pgfmathsetmacro{\threadsPerWarpM}{#5} - \pgfmathsetmacro{\warpsPerCTAM}{#6} - \pgfmathsetmacro{\warpsPerCTAN}{#7} - \pgfmathsetmacro{\order}{#8} - - \pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM} - \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM} - \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN} - \pgfmathsetmacro{\CTARepM}{\M/\CTASizeM} - \pgfmathsetmacro{\CTARepN}{\N/\CTASizeN} - \pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1} - - \foreach \ctaId in {0,...,\maxCTAId}{ - \pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)} - \pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)} - \coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$); - \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order} - } - - \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {M=\M}; - \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {K=\N}; - - \def\zoomR{1.5} - \coordinate (zoomin BL) at ($(TL)+(0, .3)$); - - \foreach \hl in {0,...,\sizePerThreadM}{ - \draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0); - } - \foreach \vl in {0,...,\sizePerThreadN}{ - \draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR); - } - - \node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$}; - \node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN}; - - \draw [densely dotted] (TL) -- (zoomin BL); - \draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$); - \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem); -} - -\newcommand{\drawBlockMFMALayoutLarge}[3]{ - %% - %% Draw a single block of MFMA_32x32x8xf16 or MFMA_16x16x16xf16 - %% - %% block TL: pre-defined top-left coordinate of the block - %% \elem: pre defined variable - %% - %% #1: 1 for mfma.trans, 0 for normal mfma - %% #2: mfmaNonKDim - %% #3: verbose. 1 means draw tid in each vec; 0 means draw nothing - - \pgfmathsetmacro{\trans}{#1} - \pgfmathsetmacro{\nonTrans}{1-#1} - \pgfmathsetmacro{\nonKDim}{#2} - \pgfmathsetmacro{\maxTID}{\nonKDim-1} - \pgfmathsetmacro{\groups}{64/\nonKDim} - \pgfmathsetmacro{\maxGID}{\groups-1} - \pgfmathsetmacro{\maxIVec}{\nonKDim*\nonKDim/256-1} - \pgfmathsetmacro{\verbose}{#3} - \foreach \iVec in {0,...,\maxIVec} { - \coordinate (wave TL) at ($(block TL)+(\trans*\iVec*\groups*4*\elem, -\nonTrans*\iVec*\groups*4*\elem)$); - \foreach \tg in {0,...,\maxGID}{ - \pgfmathsetmacro{\colID}{\tg+4} - \pgfmathsetmacro{\col}{\Colors[\colID]} - \foreach \tid in {0,...,\maxTID} { - \pgfmathsetmacro{\ratio}{\tid*2.5*\groups+15} - \ifthenelse{\verbose=0}{ - \draw [line width=0.005mm, fill=\col!\ratio!white] - ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) - rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem); - }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} - \draw [line width=0.005mm, fill=\col!\ratio!white] - ($(wave TL)+(\nonTrans*\tid*\elem+\tg*\trans*4*\elem, -\trans*\tid*\elem-\tg*\nonTrans*4*\elem)$) - rectangle ++(\nonTrans*\elem+\trans*4*\elem, -\nonTrans*4*\elem-\trans*\elem) - node [pos=.5, scale=.35*\scale, rotate=90*\nonTrans] {t\drawTid}; - } - } - } - } - \draw [thick] (block TL) rectangle ++(\nonKDim*\elem, -\nonKDim*\elem); -} - - -\newcommand{\drawTensorMFMALayout}[6]{ - %% - %% Draw a tensor with mfma layout. - %% - %% C TL: pre defined top-left coordinates of the tensor - %% - %% #1: M - %% #2: N - %% #3: MFMA nonKDim - %% #4: warpsPerCTA[0] - %% #5: warpsPerCTA[1] - %% #6: 1 for mfma.trans, 0 for normal mfma - - \pgfmathsetmacro{\tensorShapeH}{#1} - \pgfmathsetmacro{\tensorShapeW}{#2} - \pgfmathsetmacro{\mfmaNonKDim}{#3} - \pgfmathsetmacro{\warpsPerCTAH}{#4} - \pgfmathsetmacro{\warpsPerCTAW}{#5} - \pgfmathsetmacro{\mfmaTrans}{#6} - - \coordinate (old TL) at (TL); - \coordinate (TL) at (C TL); - - - \pgfmathsetmacro{\CTARepH}{\tensorShapeH/\mfmaNonKDim/\warpsPerCTAH} - \pgfmathsetmacro{\CTARepW}{\tensorShapeW/\mfmaNonKDim/\warpsPerCTAW} - \pgfmathsetmacro{\maxCTAId}{\CTARepH*\CTARepW-1} - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAH*\warpsPerCTAW-1} - \pgfmathsetmacro{\CTASizeH}{\warpsPerCTAH*\mfmaNonKDim} - \pgfmathsetmacro{\CTASizeW}{\warpsPerCTAW*\mfmaNonKDim} - - - \foreach \ctaId in {0,...,\maxCTAId}{ - \pgfmathsetmacro{\ctaCoordH}{int(\ctaId/\CTARepW)} - \pgfmathsetmacro{\ctaCoordW}{mod(\ctaId,\CTARepW)} - \coordinate (CTA TL) at ($(TL)+(\ctaCoordW*\CTASizeW*\elem, -\ctaCoordH*\CTASizeH*\elem)$); - %% Draw a detailed view of wave0 in each CTA - \coordinate (block TL) at (CTA TL); - \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{0} - - \foreach \waveId in {0,...,\maxWaveId}{ - \pgfmathsetmacro{\waveCoordH}{int(\waveId/\warpsPerCTAW)} - \pgfmathsetmacro{\waveCoordW}{mod(\waveId,\warpsPerCTAW)} - \coordinate (block TL) at ($(CTA TL)+(\waveCoordW*\mfmaNonKDim*\elem, -\waveCoordH*\mfmaNonKDim*\elem)$); - %% Inside the loop, only draw a rectangle - \draw [ultra thin] (block TL) rectangle ++(\mfmaNonKDim*\elem, -\mfmaNonKDim*\elem) - node [scale=.7*\mfmaNonKDim/32*\scale, pos=.5, fill=white, inner sep=0] {wave\waveId}; - } - - %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\CTASizeW*\elem, -\CTASizeH*\elem); - } - - \coordinate (TL) at (old TL); -} - -\newcommand{\drawMFMAOperand}[4]{ - %% - %% Draw one mfma operand - %% - %% mfma op TL: pre defined coordinates of the top-left - %% \elem: pre defined variable - %% - %% #1: mfmNonKDim - %% #2: kpack - %% #3: 0 for opA and 1 for opB - %% #4: verbose. 1 means draw tid in each vec; 0 means draw nothing - - \pgfmathsetmacro{\nonKDim}{#1} - \pgfmathsetmacro{\maxGID}{64/\nonKDim-1} - \pgfmathsetmacro{\maxTID}{\nonKDim-1} - \pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\opIdxA}{#3} - \pgfmathsetmacro{\opIdxB}{1-\opIdxA} - \pgfmathsetmacro{\verbose}{#4} - - \foreach \col/\tg in {0,...,\maxGID}{ - \pgfmathsetmacro{\col}{\Colors[\tg]} - \foreach \tid in {0,...,\maxTID} { - % \pgfmathsetmacro{\ratio}{\tid*2.5+15} - \ifthenelse{\verbose=0}{ - \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA); - }{ - \pgfmathsetmacro{\drawTid}{int(\tid+\tg*\nonKDim)} - \draw [line width=0.005mm, fill=\col] - ($(mfma op TL)+(\tg*\kpack*\elem*\opIdxB+\tid*\elem*\opIdxA, -\tid*\elem*\opIdxB-\tg*\kpack*\elem*\opIdxA)$) - rectangle ++(\kpack*\elem*\opIdxB + \elem*\opIdxA, -\elem*\opIdxB-\kpack*\elem*\opIdxA) - node [pos=.5, scale=.35*\scale, rotate=90*\opIdxA] {t\drawTid}; - } - } - } -} - -\newcommand{\drawWaveOperand}[4]{ - %% - %% Draw the part of the tensor that is one operand of the wave - %% - %% Op TL: pre defined coordinates of the top-left of the operand - %% \elem: pre defined variable - %% - %% #1: K - %% #2: mfmNonKDim - %% #3: kpack - %% #4: 0 for opA and 1 for opB - - \pgfmathsetmacro{\K}{#1} - \pgfmathsetmacro{\nonKDim}{#2} - \pgfmathsetmacro{\groups}{64/\nonKDim} - \pgfmathsetmacro{\kpack}{#3} - \pgfmathsetmacro{\opIdx}{#4} - \pgfmathsetmacro{\opIdxOther}{1-\opIdx} - - \coordinate (TL) at (Op TL); - - \pgfmathsetmacro{\numKRep}{\K/\kpack/\groups} - \pgfmathsetmacro{\maxKRepId}{\numKRep-1} - - \foreach \repId in {0,...,\maxKRepId}{ - \coordinate (mfma op TL) at ($(TL)+(\repId*\groups*\kpack*\elem*\opIdxOther, -\repId*\groups*\kpack*\elem*\opIdx)$); - \drawMFMAOperand{\nonKDim}{\kpack}{\opIdx}{0} - \draw [thick] (mfma op TL) rectangle - ++(\groups*\kpack*\elem*\opIdxOther+\nonKDim*\opIdx*\elem, -\nonKDim*\opIdxOther*\elem-\groups*\kpack*\elem*\opIdx); - } -} - -\newcommand{\drawDotOperands}[7]{ - %% - %% Draw operand tensors of dot - %% - %% A TL and B TL: pre defined top-left coordinates of A and B tensor - %% \elem: pre defined variable - %% - %% #1: M - %% #2: N - %% #3: K - %% #4: MFMA nonKDim - %% #5: warpsPerCTA[0] - %% #6: warpsPerCTA[1] - %% #7: kpack - - \pgfmathsetmacro{\M}{#1} - \pgfmathsetmacro{\N}{#2} - \pgfmathsetmacro{\K}{#3} - \pgfmathsetmacro{\mfmaNonKDim}{#4} - \pgfmathsetmacro{\warpsPerCTAM}{#5} - \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\kpack}{#7} - - %% operand A - \pgfmathsetmacro{\CTARepM}{\M/\warpsPerCTAM/\mfmaNonKDim} - \pgfmathsetmacro{\maxCTAIdM}{\CTARepM-1} - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM-1} - \foreach \ctaId in {0,...,\maxCTAIdM}{ - \coordinate (CTA TL) at ($(A TL)+(0, -\ctaId*\warpsPerCTAM*\mfmaNonKDim*\elem)$); - \foreach \waveId in {0,...,\maxWaveId}{ - \coordinate (wave TL) at ($(CTA TL)+(0, -\waveId*\mfmaNonKDim*\elem)$); - \draw [ultra thin] (wave TL) rectangle ++(\K*\elem, -\mfmaNonKDim*\elem); - } - %% Only draw the detailed view of the first wave in CTA - \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{0} - - %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\K*\elem, -\warpsPerCTAM*\mfmaNonKDim*\elem); - } - \draw [ultra thin] (A TL) rectangle ++(\K*\elem, -\M*\elem); - - - %% operand B - \pgfmathsetmacro{\CTARepN}{\N/\warpsPerCTAN/\mfmaNonKDim} - \pgfmathsetmacro{\maxCTAIdN}{\CTARepN-1} - \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAN-1} - \foreach \ctaId in {0,...,\maxCTAIdN}{ - \coordinate (CTA TL) at ($(B TL)+(\ctaId*\warpsPerCTAN*\mfmaNonKDim*\elem, 0)$); - \foreach \waveId in {0,...,\maxWaveId}{ - \coordinate (wave TL) at ($(CTA TL)+(\waveId*\mfmaNonKDim*\elem ,0)$); - \draw [ultra thin] (wave TL) rectangle ++(\mfmaNonKDim*\elem, -\K*\elem); - } - %% Only draw the detailed view of the first wave in CTA - \coordinate (Op TL) at (CTA TL); - \drawWaveOperand{\K}{\mfmaNonKDim}{\kpack}{1} - - %% Draw the outline of each CTA rep - \draw [ultra thick] (CTA TL) rectangle ++(\warpsPerCTAN*\mfmaNonKDim*\elem, -\K*\elem); - } - \draw [ultra thin] (B TL) rectangle ++(\N*\elem, -\K*\elem); -} - - -\newcommand{\drawDot}[8]{ - %% - %% Draw C = dot A, B - %% - %% C TL: pre defined top-left coordinates of the result tensor - %% \elem: pre defined variable - %% - %% #1: M - %% #2: N - %% #3: K - %% #4: MFMA nonKDim - %% #5: warpsPerCTA[0] - %% #6: warpsPerCTA[1] - %% #7: 1 for mfma.trans, 0 for normal mfma - %% #8: kpack - - \pgfmathsetmacro{\M}{#1} - \pgfmathsetmacro{\N}{#2} - \pgfmathsetmacro{\K}{#3} - \pgfmathsetmacro{\mfmaNonKDim}{#4} - \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} - \pgfmathsetmacro{\warpsPerCTAM}{#5} - \pgfmathsetmacro{\warpsPerCTAN}{#6} - \pgfmathsetmacro{\mfmaTrans}{#7} - \pgfmathsetmacro{\kpack}{#8} - \pgfmathsetmacro{\kdim}{int(\groups*\kpack)} - - \pgfmathsetmacro{\gap}{\elem*20} - \coordinate (A TL) at ($(C TL)+(-\gap-\K*\elem, 0)$); - \coordinate (B TL) at ($(C TL)+(0, \gap+\K*\elem)$); - - \drawDotOperands{\M}{\N}{\K}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\kpack} - - \drawTensorMFMALayout{\M}{\N}{\mfmaNonKDim}{\warpsPerCTAM}{\warpsPerCTAN}{\mfmaTrans} - - %% Draw labels - \node [scale=\scale, above] at ($(A TL)+(.5*\K*\elem, 0)$) {K=\K}; - \node [scale=\scale, above, rotate=90] at ($(A TL)+(0, -.5*\M*\elem)$) {M=\M}; - - \node [scale=\scale, above, rotate=90] at ($(B TL)+(0, -.5*\K*\elem)$) {K=\K}; - \node [scale=\scale, above] at ($(B TL)+(.5*\N*\elem, 0)$) {N=\N}; - - \node [scale=\scale, above left] at (A TL) {A}; - \node [scale=\scale, above left] at (B TL) {B}; - \node [scale=\scale, above left] at (C TL) {C}; - - %% label nonKDim - \node [scale=.8*\scale, left] at ($(A TL)+(0, -.5*\mfmaNonKDim*\elem)$) {\mfmaNonKDim}; - \node [scale=.8*\scale, above] at ($(B TL)+(.5*\mfmaNonKDim*\elem, 0)$) {\mfmaNonKDim}; - %% label kpack - \node [scale=.8*\scale, above] at ($(A TL)+(0.5*\groups*\kpack*\elem, 0)$) {\kdim}; - \node [scale=.8*\scale, left] at ($(B TL)+(0, -0.5*\groups\kpack*\elem)$) {\kdim}; -} - -\newcommand{\Colors}{{ - "red", - "YellowGreen", - "blue", - "Maroon", - "orange", - "cyan", - "magenta", - "brown", - "teal", - "purple", - "gray", - "Green", - "BlueGreen", - "violet", - "olive", - "darkgray", - }} - -\newcommand{\drawTensorLayoutGlobalMem}{ - %% - %% Draw tensor layout in global memory without any swizzling - %% - %% TL: pre defined top-left coordinates of the tensor in global memory - %% \elem: per defined variable - %% \Colors: a pre defined array of 16 colors - %% - %% The following arguments are also expected to be pre defined - %% #1: M - %% #2: K - %% #3: vec: number of elements in a group - - \pgfmathsetmacro{\numVecK}{\K/\vec} - \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} - \pgfmathsetmacro{\drawM}{20} - - %% Draw the tensor, but only draw 32 rows - \draw (TL) rectangle ++(\K*\elem, -\drawM*\elem); - %% Draw detailed vec view of the tensor - \foreach \vecId in {0,...,\maxVecId}{ - - \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} - \pgfmathsetmacro{\vecCoordK}{mod(\vecId,\numVecK)} - \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); - - \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} - \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} - \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} - \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} - - \draw [ultra thin, fill=\vecColor!\ratio!white] (vec TL) rectangle ++(\vec*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - - } - %% M and K dim - \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem-8*\elem)$) {M=\M}; - \node [scale=.8*\scale, left] at ($(TL)+(0, -.5*16*\elem)$) {16}; - \node [scale=\scale, above] at ($(TL)+(.5*\K*\elem, 0)$) {K=\K}; - %% label for vecSize - \def\vecR{1.5} - \coordinate (vec TL) at ($(TL)+(-.25*\vec*\elem, 3*\elem*\vecR)$); - \pgfmathsetmacro{\maxVec}{\vec-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (TL) -- ($(vec TL)+(0, -\elem*\vecR)$); - \draw [densely dotted] ($(TL)+(\vec*\elem, 0)$) -- ($(vec TL)+(\vec*\elem*\vecR, -\elem*\vecR)$); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\vec*\elem*\vecR, 0)$) {vec=\vec}; -} - - - -\newcommand{\drawLDSLayoutTritonSwizzling}[2]{ - %% - %% Draw tensor layout in LDS with swizzling - %% - %% TL: pre defined top-left coordinates of the tensor in global memory - %% \elem: per defined variable - %% \Colors: a pre defined array of 16 colors - %% - %% The following three arguments are expected to be pre defined - %% #1: M - %% #2: K - %% #3: vec: number of elements in a group - %% - %% #1: hasSwizzle, 0 means no swizzling and no padding, - %% 1 means optimal swizzling - %% 2 means padding - %% #2: access mode, 0 means draw nothing, 1 means ds_read, 2 means ds_write - %% For ds_write access, the following variables are assumed to be pre defined - %% \sizePerThreadK - %% \sizePerThreadM - %% \threadsPerWarpK - - \pgfmathsetmacro{\hasSwizzle}{#1} - \pgfmathsetmacro{\accessMode}{#2} - \pgfmathsetmacro{\numVecK}{\K/\vec} - - %% Assuming fp16 data type - \pgfmathsetmacro{\LDSK}{64} - \pgfmathsetmacro{\numLDSVec}{\LDSK/\vec} - \pgfmathsetmacro{\swizzleK}{max(\LDSK, \K)} - \pgfmathsetmacro{\LDSM}{int(\M/\LDSK*\K)} - - \ifthenelse{\accessMode = 2}{ - %% \accessMode == 2, draw 8 rows - \pgfmathsetmacro{\maxVecId}{8*\numVecK-1} - \pgfmathsetmacro{\drawM}{8*\K/\LDSK+4} - }{ - %% \accessMode == 0 or 1, draw 16 rows - \pgfmathsetmacro{\maxVecId}{16*\numVecK-1} - \pgfmathsetmacro{\drawM}{16*\K/\LDSK+4} - } - - %% Parameters used for swizzling - \pgfmathsetmacro{\numVecSwizzleK}{\swizzleK/\vec} - %% perPhase = ceil(LDSK / K) - %% The number of the rows of the tensor that can share the same swizzling pattern - \pgfmathsetmacro{\perPhase}{ceil(\LDSK/\K)} - %% maxPhase: the total number of different swizzling patterns - \ifthenelse{\hasSwizzle=0}{ - %% When swizzling is disabled - \pgfmathsetmacro{\maxPhase}{1} - }{ - %% When vec is small enough, we want 16/perPhase different swizzling patterns - %% When vec is large, we can only have 64 / \vec different swizzling pattern at most - \pgfmathsetmacro{\maxPhase}{min(16/\perPhase,64/\vec)} - } - - %% Draw the LDS - \draw (TL) rectangle ++(\LDSK*\elem, -\drawM*\elem); - - %% Draw detailed vec view of LDS - \foreach \vecId in {0,...,\maxVecId}{ - \pgfmathsetmacro{\vecCoordM}{int(\vecId/\numVecK)} - \pgfmathsetmacro{\vecCoordK}{int(mod(\vecId,\numVecK))} - \pgfmathsetmacro{\rawPhase}{floor(\vecId/\numVecSwizzleK)} - %% vec color - \pgfmathsetmacro{\colorIdxK}{int(mod(\vecCoordK,16))} - \pgfmathsetmacro{\colorIdxM}{mod(\vecCoordM,16)} - \pgfmathsetmacro{\ratio}{100-floor(\vecCoordK/16)*40} - \pgfmathsetmacro{\vecColor}{\Colors[\colorIdxK]} - - %% old vec coordinates - \coordinate (vec TL) at ($(TL)+(\vecCoordK*\vec*\elem, -\vecCoordM*\elem)$); - - %% new vec coordinates in LDS by swizzling - %% The following two conditions correspond to the relation between \LDSK and \K - \ifthenelse{\LDSK < \K}{ - \pgfmathsetmacro{\vecLDSM}{\vecCoordM*\K/\LDSK+floor(\vecCoordK*\vec/\LDSK)} - \pgfmathsetmacro{\vecLDSK}{int(mod(\vecCoordK, \LDSK/\vec))} - }{ - \pgfmathsetmacro{\vecLDSM}{floor(\vecCoordM/\perPhase)} - \pgfmathsetmacro{\vecLDSK}{int(\vecCoordK+mod(\vecCoordM,\perPhase)*\numVecK)} - } - %% - \pgfmathsetmacro{\phase}{int(mod(\rawPhase, \maxPhase))} - %% Compute the swizzled col id - \pgfmathsetmacro{\vecLDSKSwizzled}{\bitwiseXor{\vecLDSK}{\phase}} - - %% new vec coordinates in LDS by padding - \pgfmathsetmacro{\numPads}{floor(\vecId/\numLDSVec)} - \pgfmathsetmacro{\bankId}{\vec/2*\vecId+\numPads} - \pgfmathsetmacro{\vecPadM}{int(\bankId/32)} - \pgfmathsetmacro{\vecPadK}{int(mod(\bankId,32))} - - \ifthenelse{\hasSwizzle = 2}{ - %% vec coordinates by padding - \coordinate (new vec TL) at ($(TL)+(\vecPadK*2*\elem, -\vecPadM*\elem)$); - \pgfmathsetmacro{\tailBankId}{int(\vecPadK+\vec/2-1)} - }{ - %% vec coordinates by swizzling - \coordinate (new vec TL) at ($(TL)+(\vecLDSKSwizzled*\vec*\elem, -\vecLDSM*\elem)$); - \pgfmathsetmacro{\tailBankId}{0} - } - - \ifthenelse{\hasSwizzle = 2 \AND \tailBankId > 31}{ - \pgfmathsetmacro{\nextBanks}{\tailBankId-31} - \pgfmathsetmacro{\leftBanks}{\vec/2 - \nextBanks} - \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\leftBanks*2*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - \draw [ultra thin, fill=\vecColor!\ratio!white] ($(TL)+(0, -\vecPadM*\elem-\elem)$) - rectangle ++(\nextBanks*2*\elem, -\elem) node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - }{ - \draw [ultra thin, fill=\vecColor!\ratio!white] (new vec TL) rectangle ++(\vec*\elem, -\elem) - node [pos=.5, scale=.6*\scale, white] {m\vecCoordM}; - } - - %% ds_read - %% Highlight the elements the first 16 threads access in the first cycle - %% This is used to visualize bank conflicts - \ifthenelse{\accessMode = 1}{ - \ifthenelse{\vecCoordK = 0}{ - \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); - \draw (new vec TL) -- ++(\elem, -\elem); - \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); - }{} - }{} - - %% Draw ds_write pattern - \ifthenelse{\accessMode = 2}{ - %% First compute the coverage of the first 16 threads - \pgfmathsetmacro{\covK}{min(16, \threadsPerWarpK)*\sizePerThreadK/\vec} - \pgfmathsetmacro{\covM}{ceil(16/\threadsPerWarpK)*\sizePerThreadM} - %% Check conditions for the first 16 threads - \pgfmathsetmacro{\vecInThread}{int(mod(\vecCoordK, \sizePerThreadK/\vec))} - \ifthenelse{\vecInThread=0}{ - \ifthenelse{\vecCoordK<\covK \AND \vecCoordM<\covM}{ - \draw [fill=white] (new vec TL) rectangle ++(\elem, -\elem); - \draw (new vec TL) -- ++(\elem, -\elem); - \draw ($(new vec TL)+(0, -\elem)$) -- ++(\elem, \elem); - }{} - }{} - }{} - - %% Label the phase of each line if swizzling is used - \ifthenelse{\hasSwizzle = 2}{}{ - \pgfmathsetmacro{\lastVecId}{int(64/\vec)-1} - \ifthenelse{\vecLDSKSwizzled = \lastVecId}{ - \draw [ultra thin] ($(new vec TL)+(\vec*\elem, -.5*\elem)$) -- ++(\elem, 0) - node [scale=.6*\scale, right] {\phase}; - }{} - } - } - - %% Draw boundary of 32 banks - %% Assume fp16 data type - \foreach \bank in {0,...,31}{ - \draw [ultra thin, gray] ($(TL)+(\bank*2*\elem, 0)$) -- ++(0, 2*\elem) - node [scale=.6*\scale, right, black] {\bank}; - } - \draw [ultra thin, gray] ($(TL)+(32*2*\elem, 0)$) -- ++(0, 2*\elem); - \node [scale=.6*\scale, left, black] at ($(TL)+(0, 2*\elem)$) {bank id}; - - \node [scale=\scale, above] at ($(TL)+(.5*\LDSK*\elem, 3*\elem)$) {LDS 32 banks}; - \node [scale=\scale, rotate=90, above] at ($(TL)+(0, -.5*\drawM*\elem)$) {LDSM=\LDSM}; - - %% label phase if swizzling is used - \ifthenelse{\hasSwizzle = 2}{}{ - \node [scale=.6*\scale, above right] at($(TL)+(32*2*\elem, 0)$) {phase}; - } -} - -\newcommand{\drawMFMAInstr}[3]{ - %% - %% Draw layout of mfma instructions with tid labeled - %% - %% C TL: pre defined top-left coordinates of the output matrix - %% \elem: pre defined variable - %% - %% #1: mfmaNonKDim - %% #2: kpack - %% #3: mfmaTrans - \pgfmathsetmacro{\mfmaNonKDim}{#1} - \pgfmathsetmacro{\groups}{64/\mfmaNonKDim} - \pgfmathsetmacro{\kpack}{#2} - \pgfmathsetmacro{\mfmaTrans}{#3} - \pgfmathsetmacro{\nonTrans}{1-#3} - - \pgfmathsetmacro{\gap}{\elem*5} - \coordinate (mfma opA TL) at ($(C TL)+(-.5*\gap-1.2*\nonTrans*\gap-\groups*\kpack*\elem, 0)$); - \coordinate (mfma op TL) at (mfma opA TL); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{0}{1} - \coordinate (mfma op TL) at ($(C TL)+(0, 1.5*\gap+.5*\mfmaTrans*\gap+\groups*\kpack*\elem)$); - \drawMFMAOperand{\mfmaNonKDim}{\kpack}{1}{1} - - \coordinate (block TL) at (C TL); - \drawBlockMFMALayoutLarge{\mfmaTrans}{\mfmaNonKDim}{1} - - %% Draw labels - \def\vecR{1.5} - \coordinate (vec TL) at ($(mfma opA TL)+(-.25*\kpack*\elem, 3*\elem*\vecR)$); - \pgfmathsetmacro{\maxVec}{\kpack-1} - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (mfma opA TL) -- ($(vec TL)+(0, -\elem*\vecR)$); - \draw [densely dotted] ($(mfma opA TL)+(\kpack*\elem, 0)$) -- ($(vec TL)+(\kpack*\elem*\vecR, -\elem*\vecR)$); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*\kpack*\elem*\vecR, 0)$) {vec=\kpack}; - - \coordinate (vec TL) at ($(mfma op TL)+(-3*\elem*\vecR, .25*\kpack*\elem)$); - \foreach \vecId in {0,...,\maxVec}{ - \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (mfma op TL) -- ($(vec TL)+(\elem*\vecR,0)$); - \draw [densely dotted] ($(mfma op TL)+(0, -\kpack*\elem)$) -- ($(vec TL)+(\elem*\vecR, -\kpack*\elem*\vecR)$); - \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*\kpack*\elem*\vecR)$) {vec=\kpack}; - - \node [scale=\scale, below] at ($(block TL)+(.5*\mfmaNonKDim*\elem,-\mfmaNonKDim*\elem)$) {outC}; - \ifthenelse{\mfmaTrans=0}{ - \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opA}; - \node [scale=\scale, above] at (mfma op TL) {opB}; - \coordinate (vec TL) at ($(block TL)+(-3*\elem-\elem*\vecR, .25*4*\elem)$); - \foreach \vecId in {0,1,2,3}{ - \draw ($(vec TL)+(0, -\vecId*\elem*\vecR)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (block TL) -- ++(-3*\elem, .25*4*\elem); - \draw [densely dotted] ($(block TL)+(0, -4*\elem)$) -- ++(-3*\elem, -.25*4*\elem); - \node [scale=.8*\scale, above, rotate=90] at ($(vec TL)+(0, -.5*4*\elem*\vecR)$) {vec=4}; - \node [scale=.8*\scale, above, align=center] at ($(block TL)+(.5*\mfmaNonKDim*\elem, 0)$) {mfmaLayout\\trans=False}; - }{ - \node [scale=\scale, below] at ($(mfma opA TL)+(\kpack*\elem, -\mfmaNonKDim*\elem)$) {opB}; - \node [scale=\scale, above] at (mfma op TL) {opA}; - \coordinate (vec TL) at ($(block TL)+(-.25*4*\elem, 3*\elem+\elem*\vecR)$); - \foreach \vecId in {0,1,2,3}{ - \draw ($(vec TL)+(\vecId*\elem*\vecR, 0)$) rectangle ++(\elem*\vecR, -\elem*\vecR); - } - \draw [densely dotted] (block TL) -- ++(-.25*4*\elem, 3*\elem); - \draw [densely dotted] ($(block TL)+(4*\elem, 0)$) -- ++(.25*4*\elem, 3*\elem); - \node [scale=.8*\scale, above] at ($(vec TL)+(.5*4*\elem*\vecR, 0)$) {vec=4}; - \node [scale=.8*\scale, above, align=center] at ($(block TL)+(16*\elem, 0)$) {mfmaLayout\\trans=True}; - } -} - -\newcommand{\drawWMMAOperand}[3]{ - %% - %% Draw the layout of one operand of WMMA instruction - %% - %% #1: opIdx. 0 for opA, 1 for opB - %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing - %% #3: mode. 0 for w32, 1 for w64 - %% - %% wmma op TL: pre defined top-left coordinates of the operand matrix - - \pgfmathsetmacro{\isOpB}{#1} - \pgfmathsetmacro{\isOpA}{1-\isOpB} - \pgfmathsetmacro{\verbose}{#2} - \pgfmathsetmacro{\isWLarge}{#3} - - \foreach \row in {0,...,15}{ - \pgfmathsetmacro{\ratio}{\row*5+15} - \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$); - \ifthenelse{\isWLarge=1}{ - \pgfmathsetmacro{\tidone}{int(\row+16)} - \pgfmathsetmacro{\tidtwo}{int(\row+32)} - \pgfmathsetmacro{\tidthree}{int(\row+48)} - \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) - rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) - node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone, t\tidtwo, t\tidthree}; - }{ - \pgfmathsetmacro{\tidone}{int(\row+16)} - \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) - rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) - node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone}; - } - } -} - -\newcommand{\drawWMMAResult}[2]{ - %% - %% Draw layout of WMMA result tensor - %% - %% #1: verbose. 1 means draw tid in each vec; 0 means draw nothing - %% #2: mode. 0 for w32, 1 for w64 - - \pgfmathsetmacro{\verbose}{#1} - \pgfmathsetmacro{\isWLarge}{#2} - - \pgfmathsetmacro{\numElem}{256} - \pgfmathsetmacro{\maxElemId}{\numElem-1} - - \foreach \elemId in {0,...,\maxElemId}{ - %% figure out the rowID - \pgfmathsetmacro{\rowId}{floor(\elemId/16)} - %% figure out the colID - \pgfmathsetmacro{\colId}{mod(\elemId,16)} - %% figure out the tid and color - \ifthenelse{\isWLarge=1}{ - \pgfmathsetmacro{\tid}{int(mod(\elemId,64))} - \pgfmathsetmacro{\laneId}{mod(\elemId,64)} - }{ - \pgfmathsetmacro{\tid}{int(mod(\elemId,32))} - \pgfmathsetmacro{\laneId}{mod(\elemId,32)} - } - %% figure out the color - \pgfmathsetmacro{\colorId}{floor(\laneId/16)} - \pgfmathsetmacro{\vecColor}{\Colors[\colorId]} - %% Coordinate - \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$); - \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem) - node [scale=.4*\scale, pos=.5] {t\tid}; - } - - -} - -\newcommand{\drawWMMAInstr}[2]{ - %% - %% Draw wmma instruction layouts 16x16x16 - %% - %% #1: mode. 0 for w32, 1 for w64 - %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing - %% - %% C TL: pre defined top-left coordinates of output matrix - %% \elem: pre defined element size - - - \pgfmathsetmacro{\isWLarge}{#1} - \pgfmathsetmacro{\verbose}{#2} - - \pgfmathsetmacro{\gap}{\elem*2} - \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$); - \coordinate (wmma opA TL) at (wmma op TL); - \drawWMMAOperand{0}{\verbose}{\isWLarge} - \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$); - \drawWMMAOperand{1}{\verbose}{\isWLarge} - - \drawWMMAResult{1}{\isWLarge} - - %% labels - \pgfmathsetmacro{\gap}{\elem} - \node [above left, scale=\scale] at (wmma opA TL) {A}; - \node [above left, scale=\scale] at (wmma op TL) {B}; - \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C}; - - %% A k dim - \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16}; - \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$); - \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$); - - %% B K dim - \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16}; - \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$); - \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$); - - %% C M dim - \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16}; - \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$); - \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$); - - %% C N dim - \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16}; - \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$); - \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$); -} diff --git a/python/perf-kernels/tools/plot-layout/wmmaLayout.tex b/python/perf-kernels/tools/plot-layout/wmmaLayout.tex new file mode 100644 index 000000000000..25d459a1d0dd --- /dev/null +++ b/python/perf-kernels/tools/plot-layout/wmmaLayout.tex @@ -0,0 +1,121 @@ +\newcommand{\drawWMMAOperand}[3]{ + %% + %% Draw the layout of one operand of WMMA instruction + %% + %% #1: opIdx. 0 for opA, 1 for opB + %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #3: mode. 0 for w32, 1 for w64 + %% + %% wmma op TL: pre defined top-left coordinates of the operand matrix + + \pgfmathsetmacro{\isOpB}{#1} + \pgfmathsetmacro{\isOpA}{1-\isOpB} + \pgfmathsetmacro{\verbose}{#2} + \pgfmathsetmacro{\isWLarge}{#3} + + \foreach \row in {0,...,15}{ + \pgfmathsetmacro{\ratio}{\row*5+15} + \coordinate (vec TL) at ($(wmma op TL)+(\row*\isOpB*\elem, -\row*\elem*\isOpA)$); + \ifthenelse{\isWLarge=1}{ + \pgfmathsetmacro{\tidone}{int(\row+16)} + \pgfmathsetmacro{\tidtwo}{int(\row+32)} + \pgfmathsetmacro{\tidthree}{int(\row+48)} + \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) + rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) + node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone, t\tidtwo, t\tidthree}; + }{ + \pgfmathsetmacro{\tidone}{int(\row+16)} + \draw [line width=0.005mm, fill=brown!\ratio!white] (vec TL) + rectangle ++(16*\elem*\isOpA+\elem*\isOpB, -\elem*\isOpA-16*\elem*\isOpB) + node [scale=0.4*\scale, pos=.5, rotate=90*\isOpB] {t\row, t\tidone}; + } + } +} + +\newcommand{\drawWMMAResult}[2]{ + %% + %% Draw layout of WMMA result tensor + %% + %% #1: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% #2: mode. 0 for w32, 1 for w64 + + \pgfmathsetmacro{\verbose}{#1} + \pgfmathsetmacro{\isWLarge}{#2} + + \pgfmathsetmacro{\numElem}{256} + \pgfmathsetmacro{\maxElemId}{\numElem-1} + + \foreach \elemId in {0,...,\maxElemId}{ + %% figure out the rowID + \pgfmathsetmacro{\rowId}{floor(\elemId/16)} + %% figure out the colID + \pgfmathsetmacro{\colId}{mod(\elemId,16)} + %% figure out the tid and color + \ifthenelse{\isWLarge=1}{ + \pgfmathsetmacro{\tid}{int(mod(\elemId,64))} + \pgfmathsetmacro{\laneId}{mod(\elemId,64)} + }{ + \pgfmathsetmacro{\tid}{int(mod(\elemId,32))} + \pgfmathsetmacro{\laneId}{mod(\elemId,32)} + } + %% figure out the color + \pgfmathsetmacro{\colorId}{floor(\laneId/16)} + \pgfmathsetmacro{\vecColor}{\Colors[\colorId]} + %% Coordinate + \coordinate (vec TL) at ($(C TL)+(\colId*\elem, -\rowId*\elem)$); + \draw [line width=0.005mm, fill=\vecColor!60!white] (vec TL) rectangle ++(\elem, -\elem) + node [scale=.4*\scale, pos=.5] {t\tid}; + } + + +} + +\newcommand{\drawWMMAInstr}[2]{ + %% + %% Draw wmma instruction layouts 16x16x16 + %% + %% #1: mode. 0 for w32, 1 for w64 + %% #2: verbose. 1 means draw tid in each vec; 0 means draw nothing + %% + %% C TL: pre defined top-left coordinates of output matrix + %% \elem: pre defined element size + + + \pgfmathsetmacro{\isWLarge}{#1} + \pgfmathsetmacro{\verbose}{#2} + + \pgfmathsetmacro{\gap}{\elem*2} + \coordinate (wmma op TL) at ($(C TL)+(-\gap-16*\elem, 0)$); + \coordinate (wmma opA TL) at (wmma op TL); + \drawWMMAOperand{0}{\verbose}{\isWLarge} + \coordinate (wmma op TL) at ($(C TL)+(0, \gap+16*\elem)$); + \drawWMMAOperand{1}{\verbose}{\isWLarge} + + \drawWMMAResult{1}{\isWLarge} + + %% labels + \pgfmathsetmacro{\gap}{\elem} + \node [above left, scale=\scale] at (wmma opA TL) {A}; + \node [above left, scale=\scale] at (wmma op TL) {B}; + \node [above right, scale=\scale] at ($(C TL)+(16*\elem, 0)$) {C}; + + %% A k dim + \node [scale=.8*\scale] (k dim A) at ($(wmma opA TL)+(8*\elem,\gap)$) {16}; + \draw [->, >=stealth] (k dim A.west) -- ($(wmma opA TL)+(0, \gap)$); + \draw [->, >=stealth] (k dim A.east) -- ($(wmma opA TL)+(16*\elem, \gap)$); + + %% B K dim + \node [scale=.8*\scale, rotate=90] (k dim B) at ($(wmma op TL)+(-\gap, -8*\elem)$) {16}; + \draw [->, >=stealth] (k dim B.east) -- ($(wmma op TL)+(-\gap, 0)$); + \draw [->, >=stealth] (k dim B.west) -- ($(wmma op TL)+(-\gap, -16*\elem)$); + + %% C M dim + \node [scale=.8*\scale] (m dim) at ($(C TL)+(8*\elem,-16*\elem-\gap)$) {16}; + \draw [->, >=stealth] (m dim.west) -- ($(C TL)+(0, -16*\elem-\gap)$); + \draw [->, >=stealth] (m dim.east) -- ($(C TL)+(16*\elem, -16*\elem-\gap)$); + + %% C N dim + \node [scale=.8*\scale, rotate=-90] (n dim) at ($(C TL)+(16*\elem+\gap, -8*\elem)$) {16}; + \draw [->, >=stealth] (n dim.west) -- ($(C TL)+(16*\elem+\gap, 0)$); + \draw [->, >=stealth] (n dim.east) -- ($(C TL)+(16*\elem+\gap, -16*\elem)$); +}