Skip to content

Commit

Permalink
[frontend] Remove Complex Regex for MLIR Parsing (triton-lang#4924)
Browse files Browse the repository at this point in the history
There were a number of complex regexes used for parsing MLIR in the
python frontend. For maintainability reasons, it is likely better to
just expose the MLIR bindings to python and use those instead.

The PTX regex is left in place because we don't have an easy way to
parse PTX (for now).
  • Loading branch information
SamGinzburg authored and Luosuu committed Nov 13, 2024
1 parent e5b0748 commit 1c814d0
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 32 deletions.
48 changes: 48 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/LocationSnapshot.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
Expand Down Expand Up @@ -491,6 +492,16 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("get_entry_func_name",
[](ModuleOp &self) -> std::string {
for (auto &op : self.getOps()) {
if (auto func = dyn_cast<FuncOp>(op)) {
if (LLVM::isKernel(func))
return func.getName().str();
}
}
return "";
})
.def("has_function",
[](ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
Expand All @@ -501,6 +512,43 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, std::string &funcName) -> FuncOp {
return self.lookupSymbol<FuncOp>(funcName);
})
/*
* def ty_to_cpp(ty) is the consumer of this function.
* If the type is a ptr it expects ty[0] == '*', else the type itself.
*/

.def("get_function_signature",
[](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
std::vector<std::string> strVec;

auto type = func.getFunctionType();
unsigned numArgs = type.getNumInputs();
for (unsigned i = 0; i != numArgs; ++i) {
std::string tempType;
llvm::raw_string_ostream os(tempType);

auto ty = type.getInput(i);
if (auto attributes = func.getCallableArgAttrs()) {
Attribute attr = attributes[i];
// Check for tt.nv_tma_desc = 1
if (auto dAttr = dyn_cast<DictionaryAttr>(attr)) {
if (dAttr.contains("tt.nv_tma_desc")) {
strVec.push_back("nvTmaDesc");
continue;
}
}
}
if (auto ptrType = dyn_cast<PointerType>(ty)) {
auto pType = ptrType.getPointeeType();
os << "*";
pType.print(os);
} else {
ty.print(os);
}
strVec.push_back(tempType);
}
return strVec;
})
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
Expand Down
93 changes: 93 additions & 0 deletions python/test/unit/tools/test_irsource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import tempfile
import triton
from triton.compiler import IRSource
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()


def test_mlir_attribute_parsing() -> None:
'''
Tests that MLIR attributes are parsed correctly from input ttir/ttgir.
Checks for the following:
1. Name and type signature are parsed correctly
2. _get_num_warps_from_ir_str() works
3. tt.nv_tma_desc attribute is parsed correctly
'''

sample_ttgir = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
%desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir)
f.flush()
context = ir.context()
src = IRSource(f.name, context)

# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"

# check num warps
assert src.parse_options()['num_warps'] == 8

sample_ttgir_vector_add = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32})
attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
%10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
%13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
%14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir_vector_add)
f.flush()
context = ir.context()
src = IRSource(f.name, context)

# now test compilation
triton.compile(f.name, target=target)
7 changes: 5 additions & 2 deletions python/triton/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
from .errors import CompilationError

__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
__all__ = [
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
"LazyDict"
]
65 changes: 35 additions & 30 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,13 @@
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}

mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}

Expand All @@ -55,16 +49,6 @@ def convert_type_repr(x):
return x


def _get_num_warps_from_ir_str(src: str):
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
# e.g. someone has an instruction (not module) attribute named "num-warps".
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
num_warps = int(num_warps_matches[0])
return num_warps


class ASTSource:

def __init__(self, fn, signature, constants=None, attrs=None) -> None:
Expand Down Expand Up @@ -107,28 +91,41 @@ def parse_options(self):

class IRSource:

def __init__(self, path):
def __init__(self, path, context):
self.path = path
path = Path(path)
self.ext = path.suffix[1:]
self.src = path.read_text()
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
ir.load_dialects(context)

# We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
# TODO - replace with a proper parser
if self.ext == "ptx":
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
else:
self.module = ir.parse_mlir_module(self.path, context)
fn_name = self.module.get_entry_func_name()
self.name = "@" + fn_name
funcOp = self.module.get_function(fn_name)
func_ty = self.module.get_function_signature(funcOp)
self.signature = {k: ty for k, ty in enumerate(func_ty)}

def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
self.module.context = context
return self.module

def parse_options(self):
if self.ext == "ttgir":
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
num_warps = self.module.get_int_attr("triton_gpu.num-warps")
assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
return {'num_warps': num_warps}
return dict()


Expand Down Expand Up @@ -225,7 +222,9 @@ def compile(src, target=None, options=None):
# create backend
if ir_source:
assert isinstance(src, str), "source must be either AST or a filepath"
src = IRSource(src)
context = ir.context()
src = IRSource(src, context)

extra_options = src.parse_options()
options = backend.parse_options(dict(options or dict(), **extra_options))
# create cache manager
Expand Down Expand Up @@ -266,9 +265,15 @@ def compile(src, target=None, options=None):
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
if ir_source:
first_stage += 1
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)

if not isinstance(src, IRSource):
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)
else:
# For IRSource, we have already grabbed the context + called ir.load_dialects
# just need to load the dialects for the backend.
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
try:
Expand Down

0 comments on commit 1c814d0

Please sign in to comment.