Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[frontend] Remove Complex Regex for MLIR Parsing #4924

Merged
merged 22 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,16 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("get_first_func_name",
[](ModuleOp &self) -> std::string {
std::string str;
SamGinzburg marked this conversation as resolved.
Show resolved Hide resolved
for (auto &op : self.getOps()) {
if (auto func = dyn_cast<FuncOp>(op)) {
return func.getName().str();
}
}
return str;
})
.def("has_function",
[](ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
Expand All @@ -500,6 +510,54 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, std::string &funcName) -> FuncOp {
return self.lookupSymbol<FuncOp>(funcName);
})
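/*
 * The Python helper ty_to_cpp(ty) consumes the strings returned by this
 * function. For a pointer type it expects the string to start with '*'
 * (ty[0] == '*') followed by the pointee type; for any other type it expects
 * the printed type itself.
 */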

.def("get_function_signature",
[](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
std::vector<std::string> strVec;

auto findTMA = [](ArrayRef<NamedAttribute> dictVals) {
for (auto attr : dictVals) {
if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
SamGinzburg marked this conversation as resolved.
Show resolved Hide resolved
if (intAttr.getInt() == 1)
return true;
}
}
return false;
};

auto type = func.getFunctionType();
unsigned numArgs = type.getNumInputs();
for (unsigned i = 0; i != numArgs; ++i) {
std::string tempType;
llvm::raw_string_ostream os(tempType);

auto ty = type.getInput(i);
if (auto attributes = func.getCallableArgAttrs()) {
Attribute attr = attributes[i];
// Check for tt.nv_tma_desc = 1
if (auto dAttr = dyn_cast<DictionaryAttr>(attr)) {
ArrayRef<NamedAttribute> dictVals = dAttr.getValue();
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved
if (findTMA(dictVals)) {
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved
strVec.push_back("nvTmaDesc");
continue;
}
}
}
if (auto ptrType = dyn_cast<PointerType>(ty)) {
auto pType = ptrType.getPointeeType();
os << "*";
pType.print(os);
} else {
ty.print(os);
}
strVec.push_back(tempType);
}
return strVec;
})
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
Expand Down
57 changes: 57 additions & 0 deletions python/test/unit/runtime/test_irsource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import tempfile
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe that should be in python/test/unit/tools/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just moved test_irsource.py to the tools dir

import triton
from triton.compiler import IRSource
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()

def test_mlir_attribute_parsing() -> None:
'''
Tests that MLIR attributes are parsed correctly from input ttir/ttgir.

Checks for the following:
1. Name and type signature are parsed correctly
2. num_warps is read correctly from the module attributes via parse_options()
3. tt.nv_tma_desc attribute is parsed correctly
'''

sample_ttgir = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32},
%desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir)
f.flush()
context = ir.context()
ir.load_dialects(context)
src = IRSource(f.name, context)

# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "i32", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"

# check num warps
assert src.parse_options()['num_warps'] == 8

# now test compilation
triton.compile(f.name, target=target)
4 changes: 3 additions & 1 deletion python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import multiprocessing
import shutil
import tempfile

import triton
import triton.language as tl
from triton.backends.compiler import AttrsDescriptor
from triton.compiler import ASTSource
from triton.compiler import ASTSource, IRSource
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()

Expand Down
7 changes: 5 additions & 2 deletions python/triton/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
from .errors import CompilationError

__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
__all__ = [
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
"LazyDict"
]
59 changes: 30 additions & 29 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,13 @@
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙌🙌🙌

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment above is stale now. Could you remove it in whichever PR you open next? (No need to create a PR just for this.)

ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}

mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}

Expand All @@ -54,16 +48,6 @@ def convert_type_repr(x):
return x


def _get_num_warps_from_ir_str(src: str):
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
# e.g. someone has an instruction (not module) attribute named "num-warps".
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
num_warps = int(num_warps_matches[0])
return num_warps


class ASTSource:

def __init__(self, fn, signature, constants=None, attrs=None) -> None:
Expand Down Expand Up @@ -106,28 +90,38 @@ def parse_options(self):

class IRSource:

def __init__(self, path):
def __init__(self, path, context):
SamGinzburg marked this conversation as resolved.
Show resolved Hide resolved
self.path = path
path = Path(path)
self.ext = path.suffix[1:]
self.src = path.read_text()
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}

# We don't have an easy-to-use PTX parser that we can use, so keep that regex for now.
# TODO - replace with a proper parser
if self.ext == "ptx":
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
else:
self.module = ir.parse_mlir_module(self.path, context)
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved
fn_name = self.module.get_first_func_name()
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved
self.name = "@" + fn_name
funcOp = self.module.get_function(fn_name)
func_ty = self.module.get_function_signature(funcOp)
self.signature = {k: ty for k, ty in enumerate(func_ty)}

def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
self.module.context = context
return self.module

def parse_options(self):
if self.ext == "ttgir":
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
return {'num_warps': self.module.get_int_attr("triton_gpu.num-warps")}
SamGinzburg marked this conversation as resolved.
Show resolved Hide resolved
return dict()


Expand Down Expand Up @@ -223,7 +217,11 @@ def compile(src, target=None, options=None):
# create backend
if ir_source:
assert isinstance(src, str), "source must be either AST or a filepath"
src = IRSource(src)
# Do an early init, since we use the MLIR parser which needs the context
context = ir.context()
ir.load_dialects(context)
src = IRSource(src, context)

extra_options = src.parse_options()
options = backend.parse_options(dict(options or dict(), **extra_options))
# create cache manager
Expand Down Expand Up @@ -264,8 +262,11 @@ def compile(src, target=None, options=None):
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
if ir_source:
first_stage += 1
context = ir.context()
ir.load_dialects(context)

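# For IR sources the context and dialects were already initialized early (before parsing); only create them here otherwise.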
if not ir_source:
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
Expand Down
Loading