[ROCm] ROCm compatible configs for triton kernels (pytorch#107584)
This PR brings in a few inductor changes required for ROCm

~**1 - Introduction of a toggle for enforced channel last convolution fallbacks**~
This addition is split off into its own PR after some cleanup by @pragupta: pytorch#107812

**2 - Addition of ROCm specific block sizes**
We are now able to support the MAX_AUTOTUNE mode on ROCm, so we are proposing conditions that let us fine-tune our own block sizes. Currently, Triton on ROCm does not benefit from pipelining, so we set all configs to `num_stages=1` and have removed some upstream tunings on ROCm to avoid running out of shared memory resources.

In the future we will provide more optimised tunings for ROCm, but for now this should mitigate any issues.
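
To illustrate the mechanism: each candidate config is tagged with a `cond` flag, the flags are evaluated once at import time, and on ROCm the `num_stages` field is forced to 1. The snippet below is a minimal, self-contained sketch of that pattern (the reduced config list is illustrative, not the actual inductor code):

```python
import torch

# Candidate (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) tuples; a config is
# kept on the current platform only when its "cond" entry evaluates to True.
kernel_configs = [
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},  # CUDA-only
]

platform_configs = tuple(c["config"] for c in kernel_configs if c["cond"])

# On ROCm, Triton pipelining currently provides no benefit, so force num_stages to 1.
if torch.version.hip:
    platform_configs = tuple(
        (m, n, k, 1, w) for (m, n, k, _stages, w) in platform_configs
    )

print(platform_configs)
```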

~**3 - Addition of device_type to triton's compile_meta**~
~Proposing this addition to `triton_heuristics.py`: Triton on ROCm requires `device_type` to be set to `hip` (ROCm/triton#284), so we suggested bringing this change in here to pass the correct device type down to Triton.~
This change has been split off and will arrive in the wheel update PR pytorch#107600, leaving this PR to focus on the ROCm-specific block sizes.
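
For context only, the idea behind that follow-up is to record the target device type in the metadata handed to Triton's compiler; the field name below matches the description above, but the exact plumbing lives in pytorch#107600 and may differ:

```python
import torch

# Hypothetical illustration: report "hip" as the device type on ROCm builds,
# where torch.version.hip is set, and "cuda" otherwise.
compile_meta = {"device_type": "hip" if torch.version.hip is not None else "cuda"}
```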

Pull Request resolved: pytorch#107584
Approved by: https://github.com/jithunnair-amd, https://github.com/jansel, https://github.com/eellison
jataylo authored and pytorchmergebot committed Aug 26, 2023
1 parent 15e5bd5 commit a18ee0c
Showing 3 changed files with 173 additions and 76 deletions.
39 changes: 28 additions & 11 deletions torch/_inductor/kernel/conv.py
@@ -2,7 +2,7 @@

import functools
import logging
from typing import cast, List, Tuple, TypedDict

import torch
from .. import config, ir
@@ -44,18 +44,35 @@ def conv_grid(n, c, h, w, meta):
)


# List of dictionaries to store the kernel configs. Configs whose "cond" evaluates
# to True will be utilised on the target platform
kernel_configs = [
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
{"config": (64, 256, 16, 2, 4), "cond": True},
{"config": (256, 64, 16, 2, 4), "cond": True},
{"config": (1024, 16, 16, 1, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 256, 32, 2, 8), "cond": True},
{"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create a filtered list of configs based on cond evaluation
platform_configs = tuple(
cast(Tuple[int, int, int, int, int], config["config"])
for config in kernel_configs
if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
platform_configs = tuple(
(config[0], config[1], config[2], 1, config[4]) for config in platform_configs
)

conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)
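
For readers unfamiliar with the pattern, `functools.partial` here only pre-binds the platform-specific config tuple; size-based filtering happens later inside `filtered_configs`. The sketch below uses a toy `filtered_configs` with a made-up signature purely to show the binding; it is not the real helper from `mm_common.py`:

```python
import functools

def filtered_configs(m, n, k, *, configs):
    # Toy stand-in: keep only configs whose block sizes fit the problem dimensions.
    return [c for c in configs if c[0] <= m and c[1] <= n and c[2] <= k]

conv_configs_example = functools.partial(
    filtered_configs, configs=((64, 256, 16, 2, 4), (64, 64, 32, 2, 4))
)
print(conv_configs_example(128, 128, 64))  # -> [(64, 64, 32, 2, 4)]
```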

LOOP_BODY = """
96 changes: 63 additions & 33 deletions torch/_inductor/kernel/mm_common.py
@@ -1,6 +1,6 @@
import functools
import logging
from typing import cast, List, Tuple

import sympy

@@ -44,44 +44,74 @@ def filtered_configs(
)


# List of dictionaries to store the kernel configs. Configs whose "cond" evaluates
# to True will be utilised on the target platform
mm_kernel_configs = [
# "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
{"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
{"config": (64, 64, 16, 2, 4), "cond": True},
{"config": (32, 32, 16, 1, 2), "cond": True},
]

int8_mm_kernel_configs = [
{"config": (64, 64, 32, 2, 4), "cond": True},
{"config": (64, 128, 32, 3, 4), "cond": True},
{"config": (128, 64, 32, 3, 4), "cond": True},
{"config": (64, 128, 32, 4, 8), "cond": True},
{"config": (128, 64, 32, 4, 8), "cond": True},
{"config": (64, 32, 32, 5, 8), "cond": True},
{"config": (32, 64, 32, 5, 8), "cond": True},
{"config": (128, 128, 32, 2, 8), "cond": True},
{"config": (64, 64, 64, 3, 8), "cond": True},
# {"config": (32, 32, 128, 2, 4), "cond": True},
# {"config": (64, 64, 16, 2, 4), "cond": True},
# {"config": (32, 32, 16, 1, 2), "cond": True},
{"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
{"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]

# Create a filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
cast(Tuple[int, int, int, int, int], config["config"])
for config in mm_kernel_configs
if config["cond"]
)
int8_platform_configs = tuple(
cast(Tuple[int, int, int, int, int], config["config"])
for config in int8_mm_kernel_configs
if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
mm_platform_configs = tuple(
(config[0], config[1], config[2], 1, config[4])
for config in mm_platform_configs
)
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4])
        for config in int8_platform_configs
    )

mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)


114 changes: 82 additions & 32 deletions torch/_inductor/kernel/mm_plus_mm.py
@@ -103,40 +103,90 @@
def mm_configs():
import triton

    # List of dictionaries to store the kernel configs. Configs whose "cond" evaluates
    # to True will be utilised on the target platform
    mm_triton_configs = [
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 3,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 16,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
"num_stages": 4,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
"num_stages": 1,
"num_warps": 8,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
"num_stages": 1,
"num_warps": 8,
"cond": torch.version.hip is None,
},
{
"config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
"num_stages": 2,
"num_warps": 4,
"cond": True,
},
{
"config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
"num_stages": 1,
"num_warps": 2,
"cond": True,
},
]

    # Keep only the configs whose cond evaluates to True
# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
filtered_configs = [
triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
for c in mm_triton_configs
if c["cond"]
]
else:
filtered_configs = [
triton.Config(
c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
)
for c in mm_triton_configs
if c["cond"]
]

return filtered_configs


def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
"""
