
Commit

Adding an option to force target nv compute capability for triton (#5564)

Adding an option to force NV compute capability for triton. This can be
done either with the global env flag `TRITON_OVERRIDE_NV_CAPABILITY`, or
per-kernel with the `override_nv_compute_capability=` launch argument.
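
A minimal usage sketch, based on the test added in this commit (the kernel body, sizes, and the `scale_kernel` name are illustrative, not part of the change):

```python
import os

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr):
    offs = tl.arange(0, 128)
    tl.store(out_ptr + offs, tl.load(x_ptr + offs) * 1.5 + 1.0)


x = torch.randn(128, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)

# Per-kernel override: compile this launch as if targeting compute capability 8.0.
scale_kernel[(1, )](x, out, override_nv_compute_capability=80)

# Global override via the environment variable (applies to subsequent compilations).
os.environ["TRITON_OVERRIDE_NV_CAPABILITY"] = "80"
scale_kernel[(1, )](x, out)
os.environ.pop("TRITON_OVERRIDE_NV_CAPABILITY")
```
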
pawelszczerbuk authored Jan 9, 2025
1 parent f436c9e commit 99bc7b9
Showing 3 changed files with 43 additions and 3 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -33,6 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
+    "TRITON_OVERRIDE_NV_CAPABILITY",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
     // clang-format on
31 changes: 31 additions & 0 deletions python/test/unit/language/test_core.py
@@ -6036,6 +6036,37 @@ def mul_add(data):
     assert found_fma == enable_fp_fusion
 
 
+# -----------------------
+# test override_nv_compute_capability
+# -----------------------
+
+
+@pytest.mark.parametrize("nv_compute_capability", [70, 80, 90])
+@pytest.mark.parametrize("env_var_override", [False, True])
+def test_override_nv_compute_capability(nv_compute_capability, env_var_override, device):
+    if not is_cuda():
+        pytest.skip('test_override_nv_compute_capability only for CUDA')
+
+    @triton.jit
+    def simple(data, out):
+        in_ptrs = data + tl.arange(0, 128)
+        out_ptrs = out + tl.arange(0, 128)
+        tl.store(out_ptrs, tl.load(in_ptrs) * 1.5 + 1.0)
+
+    data = torch.randn((128, ), device=device, dtype=torch.float32)
+    out = torch.empty_like(data)
+
+    if env_var_override:
+        os.environ["TRITON_OVERRIDE_NV_CAPABILITY"] = str(nv_compute_capability)
+        h = simple[(1, )](data, out)
+        os.environ.pop("TRITON_OVERRIDE_NV_CAPABILITY")
+    else:
+        h = simple[(1, )](data, out, override_nv_compute_capability=nv_compute_capability)
+    torch.testing.assert_close(data * 1.5 + 1.0, out)
+    ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
+    assert int(ttgir_cc.group(1)) == nv_compute_capability
+
+
 # -----------------------
 # test propagate_nan
 # -----------------------
14 changes: 11 additions & 3 deletions third_party/nvidia/backend/compiler.py
@@ -122,6 +122,7 @@ class CUDAOptions:
     debug: bool = False
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
+    override_nv_compute_capability: int = None
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -147,7 +148,11 @@ def supports_target(target: GPUTarget):
 
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
-        self.capability = target.arch
+        # Capability can be overridden to limit the feature set to a specific version
+        cap_override = os.getenv("TRITON_OVERRIDE_NV_CAPABILITY")
+        self.capability = int(cap_override) if cap_override is not None else target.arch
+        # HW capability is used to determine the binary format
+        self.hw_capability = target.arch
         assert isinstance(self.capability, int)
         self.binary_ext = "cubin"
 
@@ -165,6 +170,9 @@ def parse_options(self, opts) -> Any:
 
         if "enable_fp_fusion" not in args:
            args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
+
+        if "override_nv_compute_capability" in args and args["override_nv_compute_capability"] is not None:
+            self.capability = args["override_nv_compute_capability"]
         args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
         return CUDAOptions(**args)
 
@@ -396,8 +404,8 @@ def add_stages(self, stages, options):
         stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
         stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
-        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
-        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
+        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.hw_capability)
+        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.hw_capability)
 
     @functools.lru_cache()
     def hash(self):
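
The key design point in compiler.py is the split between `self.capability` and `self.hw_capability`: the (possibly overridden) `capability` limits the feature set used when lowering the ttgir and llir stages, while `hw_capability` always stays at the real `target.arch` so the ptx and cubin stages still emit a binary for the hardware actually present. A minimal sketch of that resolution logic, using a hypothetical standalone helper (`resolve_capabilities` is not a Triton API):

```python
import os


def resolve_capabilities(target_arch: int, override_kwarg: int | None = None) -> tuple[int, int]:
    """Hypothetical helper mirroring the backend logic in this commit."""
    # Feature-set capability: the env var applies globally; a per-kernel
    # launch argument, if given, takes precedence when options are parsed.
    env_override = os.getenv("TRITON_OVERRIDE_NV_CAPABILITY")
    capability = int(env_override) if env_override is not None else target_arch
    if override_kwarg is not None:
        capability = override_kwarg
    # Hardware capability: always the real arch, used for PTX/cubin emission.
    hw_capability = target_arch
    return capability, hw_capability


# e.g. on an sm_90 GPU with override_nv_compute_capability=80:
assert resolve_capabilities(90, 80) == (80, 90)
```
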
