diff --git a/python/src/ir.cc b/python/src/ir.cc index cce7c87e8d87..a2a2c7263c69 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -24,6 +24,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/LocationSnapshot.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" @@ -491,6 +492,16 @@ void init_triton_ir(py::module &&m) { [](ModuleOp &self, FuncOp &funcOp) -> void { self.push_back(funcOp); }) + .def("get_entry_func_name", + [](ModuleOp &self) -> std::string { + for (auto &op : self.getOps()) { + if (auto func = dyn_cast(op)) { + if (LLVM::isKernel(func)) + return func.getName().str(); + } + } + return ""; + }) .def("has_function", [](ModuleOp &self, std::string &funcName) -> bool { if (self.lookupSymbol(funcName)) @@ -501,6 +512,43 @@ void init_triton_ir(py::module &&m) { [](ModuleOp &self, std::string &funcName) -> FuncOp { return self.lookupSymbol(funcName); }) + /* + * def ty_to_cpp(ty) is the consumer of this function. + * If the type is a ptr it expects ty[0] == '*', else the type itself. + */ + + .def("get_function_signature", + [](ModuleOp &self, FuncOp &func) -> std::vector { + std::vector strVec; + + auto type = func.getFunctionType(); + unsigned numArgs = type.getNumInputs(); + for (unsigned i = 0; i != numArgs; ++i) { + std::string tempType; + llvm::raw_string_ostream os(tempType); + + auto ty = type.getInput(i); + if (auto attributes = func.getCallableArgAttrs()) { + Attribute attr = attributes[i]; + // Check for tt.nv_tma_desc = 1 + if (auto dAttr = dyn_cast(attr)) { + if (dAttr.contains("tt.nv_tma_desc")) { + strVec.push_back("nvTmaDesc"); + continue; + } + } + } + if (auto ptrType = dyn_cast(ty)) { + auto pType = ptrType.getPointeeType(); + os << "*"; + pType.print(os); + } else { + ty.print(os); + } + strVec.push_back(tempType); + } + return strVec; + }) .def("get_int_attr", [](ModuleOp &self, std::string name) -> py::object { auto ret = self->getAttrOfType(name); diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py new file mode 100644 index 000000000000..a886ebb457f4 --- /dev/null +++ b/python/test/unit/tools/test_irsource.py @@ -0,0 +1,93 @@ +import tempfile +import triton +from triton.compiler import IRSource +from triton._C.libtriton import ir + +target = triton.runtime.driver.active.get_current_target() + + +def test_mlir_attribute_parsing() -> None: + ''' + Tests that MLIR attributes are parsed correctly from input ttir/ttgir. + + Checks for the following: + 1. Name and type signature are parsed correctly + 2. _get_num_warps_from_ir_str() works + 3. tt.nv_tma_desc attribute is parsed correctly + ''' + + sample_ttgir = r""" +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}, + %arg4: i32 {tt.divisibility = 16 : i32}, + %arg5: i32 {tt.divisibility = 16 : i32}, + %arg6: i32 {tt.divisibility = 16 : i32}, + %arg7: i32 {tt.divisibility = 16 : i32}, + %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32}, + %desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} { + tt.return + } +} +""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(sample_ttgir) + f.flush() + context = ir.context() + src = IRSource(f.name, context) + + # check name and type signature + # should match ty_to_cpp(...) + assert src.signature == \ + {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \ + 4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"} + assert src.name == "@matmul_kernel" + + # check num warps + assert src.parse_options()['num_warps'] == 8 + + sample_ttgir_vector_add = r""" + #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}, + %arg3: i32 {tt.divisibility = 16 : i32}) + attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> + %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr, #blocked> + tt.return + } + } + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(sample_ttgir_vector_add) + f.flush() + context = ir.context() + src = IRSource(f.name, context) + + # now test compilation + triton.compile(f.name, target=target) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index bbe8c047c6d1..a05efd7e0807 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,7 @@ -from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict +from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict from .errors import CompilationError -__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] +__all__ = [ + "compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", + "LazyDict" +] diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 8e460c977f1b..f336296f596e 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -25,19 +25,13 @@ # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 -mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { - "ttir": mlir_prototype_pattern, - "ttgir": mlir_prototype_pattern, "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { - "ttir": mlir_arg_type_pattern, - "ttgir": mlir_arg_type_pattern, "ptx": ptx_arg_type_pattern, } @@ -55,16 +49,6 @@ def convert_type_repr(x): return x -def _get_num_warps_from_ir_str(src: str): - ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' - # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if - # e.g. someone has an instruction (not module) attribute named "num-warps". - num_warps_matches = re.findall(ttgir_num_warps_pattern, src) - assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - num_warps = int(num_warps_matches[0]) - return num_warps - - class ASTSource: def __init__(self, fn, signature, constants=None, attrs=None) -> None: @@ -107,28 +91,41 @@ def parse_options(self): class IRSource: - def __init__(self, path): + def __init__(self, path, context): self.path = path path = Path(path) self.ext = path.suffix[1:] self.src = path.read_text() - match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) - self.name = match.group(1) - signature = match.group(2) - types = re.findall(arg_type_pattern[self.ext], signature) - self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + ir.load_dialects(context) + + # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. + # TODO - replace with a proper parser + if self.ext == "ptx": + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + else: + self.module = ir.parse_mlir_module(self.path, context) + fn_name = self.module.get_entry_func_name() + self.name = "@" + fn_name + funcOp = self.module.get_function(fn_name) + func_ty = self.module.get_function_signature(funcOp) + self.signature = {k: ty for k, ty in enumerate(func_ty)} def hash(self): return hashlib.sha256(self.src.encode("utf-8")).hexdigest() def make_ir(self, options, codegen_fns, module_map, context): - module = ir.parse_mlir_module(self.path, context) - module.context = context - return module + self.module.context = context + return self.module def parse_options(self): if self.ext == "ttgir": - return {'num_warps': _get_num_warps_from_ir_str(self.src)} + num_warps = self.module.get_int_attr("triton_gpu.num-warps") + assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute" + return {'num_warps': num_warps} return dict() @@ -225,7 +222,9 @@ def compile(src, target=None, options=None): # create backend if ir_source: assert isinstance(src, str), "source must be either AST or a filepath" - src = IRSource(src) + context = ir.context() + src = IRSource(src, context) + extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager @@ -266,9 +265,15 @@ def compile(src, target=None, options=None): # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. if ir_source: first_stage += 1 - context = ir.context() - ir.load_dialects(context) - backend.load_dialects(context) + + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + else: + # For IRSource, we have already grabbed the context + called ir.load_dialects + # just need to load the dialects for the backend. + backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() try: