diff --git a/python/src/llvm.cc b/python/src/llvm.cc index f9b98a2540a2..182f79d78332 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -139,8 +139,6 @@ std::string translateLLVMIRToASM(llvm::Module &module, { llvm::raw_string_ostream stream(result); llvm::buffer_ostream pstream(stream); - for (llvm::Function &f : module.functions()) - f.addFnAttr(llvm::Attribute::AlwaysInline); llvm::legacy::PassManager pass; // emit auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 514ac171a3bf..e62373d6fb25 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2563,8 +2563,6 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("add_overflow_check", [False, True]) def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path): - if add_overflow_check is True and is_hip(): - pytest.skip("overflow check disabled on HIP while fixing issues") overflow_check = """ %17 = arith.extsi %arg2 : i32 to i64 diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index c8c43a051277..c222be2cd64d 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -342,6 +342,9 @@ def make_llir(src, metadata, options): metadata["shared"] = src.get_int_attr("triton_gpu.shared") amd.cleanup_bitcode_metadata(llvm_mod) + # Disable inlining of print-related functions, + # because inlining these functions could slow down compilation significantly + amd.disable_print_inline(llvm_mod) return str(llvm_mod) @staticmethod diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index a9bd3e9b7fb7..3c335099104d 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -161,6 +161,24 @@ 
void init_triton_amd(py::module &&m) { module->eraseNamedMetadata(openclVersion); }); + m.def("disable_print_inline", [](llvm::Module *module) { + // Name prefixes of the functions that must not be inlined. + std::array prefixes = {"__ockl_fprintf", "__ockl_printf"}; + + for (llvm::Function &f : module->functions()) { + if (!f.hasName()) + continue; + llvm::StringRef name = f.getName(); + + auto isNamePrefixed = [&name](const char *prefix) { + return name.starts_with(prefix); + }; + + if (llvm::any_of(prefixes, isNamePrefixed)) + f.addFnAttr(llvm::Attribute::NoInline); + } + }); + m.def( "assemble_amdgcn", [](const std::string &assembly, const std::string &arch,