Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gfx950 layouts #692

Merged
merged 18 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 89 additions & 42 deletions python/perf-kernels/tools/plot-layout/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,113 +5,160 @@ Here is the help info from the script.

```bash
>$ python3 plot_layout.py -h
usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP]
[-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep]
usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD]
[-threadsPerWarp THREADSPERWARP THREADSPERWARP] [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}]
[-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-mfmaTrans] [-scale] [-banks {32,64}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}]
[-mnContig] [-mfma_trans_load] [-swizzleVec {4,8,16,32}] [-padInterval PADINTERVAL] [-padAmount PADAMOUNT] [-wave_size {32,64}] [-o O] [-keep]

options:
-h, --help show this help message and exit
-shape SHAPE SHAPE SHAPE
Tensor shape in the form of M,N,K
-tensorShape TENSORSHAPE TENSORSHAPE
2D tensor shape in the form of dim0,dim1
-dotShape DOTSHAPE DOTSHAPE DOTSHAPE
Dot op shape in the form of M,N,K
-plot {blocked,dot,wmma,lds}
choose plot mode
-nonKDim {16,32} mfma instruction dim
-dim0 DIM0 tensor dim0 name
-dim1 DIM1 tensor dim1 name
-sizePerThread SIZEPERTHREAD SIZEPERTHREAD
-threadsPerWarp THREADSPERWARP THREADSPERWARP
-warpsPerCTA WARPSPERCTA WARPSPERCTA
-order ORDER ORDER
-kWidth {4,8,16} number of elements per thread
-nonKDim {16,32} mfma instruction dim
-kWidth {4,8,16,32} number of contiguous elements per thread
-kGroup {1,2} total number of elements / kWidth per mfma instruction
-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}
element type of operand A
-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}
element type of operand B
-mfmaTrans If set, then use mfma.trans layout
-scale If set, plot the scale tensor for mfma_f8f6f4 instructions
-banks {32,64} choose the number of banks in LDS
-lds_layout {swizzle,padding,none}
choose the LDS data layout
-lds_access {read,write,none}
choose LDS access mode
-mnContig If set, the tensor is K x N and n-contig
-mfma_trans_load If set, use MFMA transpose load instructions
-swizzleVec {4,8,16,32}
number of contiguous elements in a vector to swizzle
-padInterval PADINTERVAL
Add padding for every padInterval bytes
-padAmount PADAMOUNT Pad padAmount bytes for every padInterval bytes
-wave_size {32,64} choose the wmma instruction mode
-o O output pdf file name (without surfix)
-mfmaTrans If set, then use mfma.trans layout
-keep If set, keep the generated .tex file
```

## Installation
This script does not require torch or triton to be installed. The only package
it depends on is latex. On Ubuntu, do
```bash
sudo apt install texlive-full
sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended texlive-fonts-extra

```

## Draw blocked layout (`-plot blocked`)

Examples:
```bash
python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1
python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2
python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1
python3 plot_layout.py -plot blocked -tensorShape 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1
python3 plot_layout.py -plot blocked -tensorShape 16 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2
python3 plot_layout.py -plot blocked -tensorShape 32 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1
```

Blocked layouts are used during global load. It is used to describe the layout of the tensor
for pointers and results.
We can provide tensor shape (`-shape M N K`) and blocked layout parameters (
We can provide tensor shape (`-tensorShape dim0 dim1`) and blocked layout parameters (
`-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`).
We can also provide the order of the tensor as `-order x y` to control which dim
is the fastest changing dimension.

Notes
- All of the gemm dims (M, N, and K) are needed when providing the shape. But only
M and K will be used to plot the layout of the tensor.
- The script does not support the case when threads are loading elements that are
out of the boundary of the tensor dimensions. This means
- For M: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= M
- For K: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= K
- For dim0: sizePerThread[0] * threadsPerWarps[0] * warpsPerCTA[0] <= dim0
- For dim1: sizePerThread[1] * threadsPerWarps[1] * warpsPerCTA[1] <= dim1


## Draw mfma operand and result layouts (`-plot dot`)

Examples:
```bash
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8
python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16
## i8 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a i8 -dtype_b i8
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a i8 -dtype_b i8
## fp16/bf16 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 4 -dtype_a fp16 -dtype_b fp16
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp16 -dtype_b fp16
## fp8/bf8 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp8 -dtype_b bf8
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a fp8 -dtype_b bf8
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp8 -dtype_b bf8
## f4 and fp6/bf6 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6
## fp8/bf8 and fp6/bf6/f4 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8
## mixed precision with scaling
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 -scale
```

One can add `-nonKDim [16,32]` and `-mfmaTrans` to all of the above examples.

This mode draws two graphs:
1. The layout of the whole tile for tile A, B, and C
1. The layout of the dot operation, i.e. tile C = tile A x tile B
2. The layout of a single mfma block, operands and results of one or more mfma
instructions that share the same accumulating VGPRs.
This view has thread distributions among tensor elements.

Knobs
- `-kWidth`: the number of elements that will be loaded into one thread at once
- `-nonKDim`: 16 ot 32, which is used to control the mfma instruction size
- `-kWidth [4,8,16,32]`: the number of elements that will be loaded into one thread at once
- `-kGroup [1,2]`: total number of elements / kWidth for on mfma instruction.
This is 1 for all mfma instructions except for mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4
with fp8 input types (CBSZ=0 or 1 and/or BLGP=0 or 1)
- `-nonKDim [16,32]`: mfma instruction size. The default is set to 16.
- `-mfmaTrans`: if set, the transposed mfma layout will be plotted.
- `-dtype_a` and `-dtype_b`: element types of operand A and B. The default value is fp16.
- `-scale`: plot scale tensors for A and B. This is only supported with f4/f6 and f8 with `kGroup=2`.
If `-scale` is set but not supported, it's ignored.

Notes
- The layout shows the mapping from the threads/wave to the elements in the
original tensor. It does not care if the elements are arranged in LDS, like
swizzling to avoid bank conflicts.
- The script does not allow settings for data type or k dim of the mfma instruction.
This can be controled by the `-kWidth` flag.
- For example, if we want `mfma_32x32x8xf16`, we can set `-nonKDim 32` and `-kWidth 4`.
- If we want `mfma_32x32x16xf8`, we can set `-nonKDim 32` and `-kWidth 8`.

original tensor. It does not matter if LDS is used.
- The script does not allow settings for k dim of the mfma instruction.
This can be controled by the `-kWidth` and `-kGroup`.

## Draw LDS access (`-plot lds`)

Examples:
```bash
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 8
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 32 -dtype_a f4
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 128 -kWidth 16 -dtype_a bf8 -banks 64
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access write -tensorShape 128 128 -kWidth 16 -dtype_a f4 -banks 32
python3 plot_layout.py -plot lds -lds_layout none -lds_access read -tensorShape 128 32 -kWidth 4 -dtype_a fp16 -banks 64 -mnContig
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 32 -kWidth 16 -dtype_a fp8 -banks 64 -mnContig -mfma_trans_load
python3 plot_layout.py -plot lds -lds_layout padding -lds_access none -tensorShape 128 32 -kWidth 8 -dtype_a fp16 -banks 32 -padInterval 128 -padAmount 16
```

Knobs
- `kWidth` here means the vector size when accessing LDS
- `kWidth`: the vector size (in unit of elements) when accessing LDS
- `banks`: the number of banks in LDS. (64 for gfx950, 32 for pre-gfx950)
- `dtype_a`: element data type
- Three options for `-lds_layout`:
- `none`: no swizzling, no padding
- `padding`: padding at every 128B
- `swizzling`: apply the swizzling pattern, which is derived from tensor shape and kWidth.
- `swizzle`: apply the swizzling pattern, which is derived from tensor shape and kWidth.
- `padding`: pad `padAmount` bytes for every `padInterval` bytes of data
- `padAmount`: default is 0
- `padInterval`: default is 1
- Three options for `-lds_access`:
- `none`: do not plot access pattern
- `read`: plot accessed elements during ds_read
- `write`: plot accessed elements during ds_write. Note that this needs some infomation from
global load. Therefore, we need to provide `-sizePerThread` and `-threadsPerWarp`.

Notes
- This mode is rarely used. If you have any questions, please contact Lixun Zhang directly.
- `read`: plot accessed elements at the first cycle of ds_read
- `write`: plot accessed elements during ds_write. For global load access, we assume
a fully coalesced dwordx4 access pattern along the K dim.
- `mnContig`: If set, the tile is stored in mn-contig layout. In this layout, elements along
the M/N dim are contiguous in both global memory and LDS.
- `mfma_trans_load`: This flag only works when `mnContig` is set. When set, `ds_read_b64_tr_bx`
instructions are used to read from LDS. Note that current triton LDS layout mechanism will
lead to bank conflicts.
157 changes: 157 additions & 0 deletions python/perf-kernels/tools/plot-layout/blockedLayout.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
\newcommand{\drawBlockedWave}[5]{
%%
%% Draw a wave coverage with blocked layout
%%
%% Wave TL: pre defined top-left coordinate of the wave
%% \elem: pre defined variable
%%
%% #1: sizePerThread[0] --> sizePerThreadM
%% #2: sizePerThread[1] --> sizePerThreadN
%% #3: threadsPerWarp[0] --> threadsPerWarpM
%% #4: threadsPerWarp[1] --> threadsPerWarpN
%% #5: fastest changing dim --> order

\pgfmathsetmacro{\sizePerThreadM}{#1}
\pgfmathsetmacro{\sizePerThreadN}{#2}
\pgfmathsetmacro{\threadsPerWarpM}{#3}
\pgfmathsetmacro{\threadsPerWarpN}{#4}
\pgfmathsetmacro{\order}{#5}

\pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
\pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}

\foreach \tid in {0,...,63}{
\pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)}
\pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)}
\coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$);
\pgfmathsetmacro{\ratio}{\tidM*10}

\ifthenelse{\tid = 0}{
\draw [line width = 0.01mm, fill=red] (Thread TL)
rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}{
\draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL)
rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}
}
\draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem);
}

\newcommand{\drawBlockedCTA}[7]{
%%
%% Draw a CTA coverage with blocked layout
%%
%% CTA TL: pre defined top-left coordinate of the CTA
%% \elem: pre defined variable
%%
%% #1: sizePerThread[0] --> sizePerThreadM
%% #2: sizePerThread[1] --> sizePerThreadN
%% #3: threadsPerWarp[0] --> threadsPerWarpM
%% #4: threadsPerWarp[1] --> threadsPerWarpN
%% #5: warpsPerCTA[0] --> warpsPerCTAM
%% #6: warpsPerCTA[1] --> warpsPerCTAN
%% #7: fastest changing dim --> order

\pgfmathsetmacro{\sizePerThreadM}{#1}
\pgfmathsetmacro{\sizePerThreadN}{#2}
\pgfmathsetmacro{\threadsPerWarpM}{#3}
\pgfmathsetmacro{\threadsPerWarpN}{#4}
\pgfmathsetmacro{\warpsPerCTAM}{#5}
\pgfmathsetmacro{\warpsPerCTAN}{#6}
\pgfmathsetmacro{\order}{#7}

\pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
\pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
\pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
\pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}

\pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1}

\coordinate (Wave TL) at (CTA TL);
\drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order}
\foreach \waveId in {0,...,\maxWaveId}{
\ifthenelse{\order=1}
{
\pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)}
\pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)}
\pgfmathsetmacro{\rot}{0}
}{
\pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)}
\pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)}
\pgfmathsetmacro{\rot}{90}
}

\coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$);
\draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem)
node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId};
}

\draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem);
}

\newcommand{\drawBlockedTensor}[8]{
%%
%% Draw a tensor with blocked layout of the following parameters
%% sizePerThread[2]
%% threadsPerWarp[2]
%% warpsPerCTA[2]
%% order[2]
%%
%% TL: pre defined top-left coordinate of the tensor
%% \elem: pre defined variable
%% \dimColName: dim0Name
%% \dimRowName: dim1Name
%%
%% #1: tensorShape[0] --> M
%% #2: tensorShape[1] --> N
%% #3: sizePerThread[0] --> sizePerThreadM
%% #4: sizePerThread[1] --> sizePerThreadN
%% #5: threadsPerWarp[0] --> threadsPerWarpM
%% Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0]
%% #6: warpsPerCTA[0] --> warpsPerCTAM
%% #7: warpsPerCTA[1] --> warpsPerCTAN
%% #8: fastest changing dim --> order

\pgfmathsetmacro{\M}{#1}
\pgfmathsetmacro{\N}{#2}
\pgfmathsetmacro{\sizePerThreadM}{#3}
\pgfmathsetmacro{\sizePerThreadN}{#4}
\pgfmathsetmacro{\threadsPerWarpM}{#5}
\pgfmathsetmacro{\warpsPerCTAM}{#6}
\pgfmathsetmacro{\warpsPerCTAN}{#7}
\pgfmathsetmacro{\order}{#8}

\pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM}
\pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
\pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
\pgfmathsetmacro{\CTARepM}{\M/\CTASizeM}
\pgfmathsetmacro{\CTARepN}{\N/\CTASizeN}
\pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1}

\foreach \ctaId in {0,...,\maxCTAId}{
\pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)}
\pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)}
\coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$);
\drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order}
}

\node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {\dimColName=\M};
\node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {\dimRowName=\N};

\def\zoomR{1.5}
\coordinate (zoomin BL) at ($(TL)+(0, .3)$);

\foreach \hl in {0,...,\sizePerThreadM}{
\draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0);
}
\foreach \vl in {0,...,\sizePerThreadN}{
\draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR);
}

\node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$};
\node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN};

\draw [densely dotted] (TL) -- (zoomin BL);
\draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$);
\draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}
Loading
Loading