[frontend] Remove Complex Regex for MLIR Parsing #4924
Merged
Commits (22, all by SamGinzburg):

fd07096  wip
1d684f0  replace regex in irsource w/native bindings
244d2a0  leave the ptx stuff in there for now
fb1ebb3  test passing just need to format
9df89c9  formatting fixes
5b8a827  Merge branch 'main' into PR-mlirparsing
2b2f18c  move irsource test to its own folder
30dc6a7  Merge remote-tracking branch 'refs/remotes/origin/PR-mlirparsing' int…
1be2894  Merge remote-tracking branch 'upstream/main' into PR-mlirparsing
4e84db6  check tma correctly, move tests, add second test
0ea0101  formatting
0c10375  address nit
62766fc  add assert
a404563  remember to register backend
94beaca  Merge remote-tracking branch 'upstream/main' into PR-mlirparsing
3ebb020  move context
79c9eab  nits
4457c64  comment nit
d61769c  use contains instead, check any val for tma desc
c9fdd6c  rename get first func + check isKernel / remove dead code
873741f  formatting
e41ba84  move test_irsource to tools dir
@@ -0,0 +1,93 @@ (new file: test_irsource.py)
+import tempfile
+import triton
+from triton.compiler import IRSource
+from triton._C.libtriton import ir
+
+target = triton.runtime.driver.active.get_current_target()
+
+
+def test_mlir_attribute_parsing() -> None:
+    '''
+    Tests that MLIR attributes are parsed correctly from input ttir/ttgir.
+
+    Checks for the following:
+    1. Name and type signature are parsed correctly
+    2. _get_num_warps_from_ir_str() works
+    3. tt.nv_tma_desc attribute is parsed correctly
+    '''
+
+    sample_ttgir = r"""
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                                %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                                %arg3: i32 {tt.divisibility = 16 : i32},
+                                %arg4: i32 {tt.divisibility = 16 : i32},
+                                %arg5: i32 {tt.divisibility = 16 : i32},
+                                %arg6: i32 {tt.divisibility = 16 : i32},
+                                %arg7: i32 {tt.divisibility = 16 : i32},
+                                %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
+                                %desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
+    tt.return
+  }
+}
+"""
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(sample_ttgir)
+        f.flush()
+        context = ir.context()
+        src = IRSource(f.name, context)
+
+        # check name and type signature
+        # should match ty_to_cpp(...)
+        assert src.signature == \
+            {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
+             4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
+        assert src.name == "@matmul_kernel"
+
+        # check num warps
+        assert src.parse_options()['num_warps'] == 8
+
+    sample_ttgir_vector_add = r"""
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
+                             %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
+                             %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
+                             %arg3: i32 {tt.divisibility = 16 : i32})
+                             attributes {noinline = false} {
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
+    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
+    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
+    %7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
+    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
+    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
+    %10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
+    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
+    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
+    %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
+    %14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
+    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
+    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
+    tt.return
+  }
+}
+"""
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
+        f.write(sample_ttgir_vector_add)
+        f.flush()
+        context = ir.context()
+        src = IRSource(f.name, context)
+
+        # now test compilation
+        triton.compile(f.name, target=target)
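A detail worth noting in the test above: the IR reaches IRSource as a file path, so the NamedTemporaryFile must be flushed before anything re-opens it. A stdlib-only sketch of the same pattern (`parse_kernel_name` is a hypothetical stand-in for IRSource, not a Triton API):

```python
import tempfile

def parse_kernel_name(path):
    # Stand-in for IRSource: scan the file for the tt.func line and return its name.
    with open(path) as f:
        for line in f:
            if "tt.func" in line:
                return "@" + line.split("@")[1].split("(")[0]
    return None

sample = "module {\n  tt.func public @matmul_kernel() {\n    tt.return\n  }\n}\n"
with tempfile.NamedTemporaryFile(mode="w", suffix=".ttgir") as f:
    f.write(sample)
    f.flush()  # flush before the parser opens the path, or it may read a short/empty file
    name = parse_kernel_name(f.name)
```

Without the `flush()`, the write may still sit in Python's userspace buffer when the second reader opens the path.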
@@ -1,4 +1,7 @@
-from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
+from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
 from .errors import CompilationError

-__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
+__all__ = [
+    "compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
+    "LazyDict"
+]
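The `__all__` update matters because that list gates what a star-import re-exports. A minimal sketch of the mechanism using a throwaway module (`fake_compiler` is made up for illustration, not Triton code):

```python
import sys
import types

# Build a throwaway module to show how __all__ gates `from ... import *`.
mod = types.ModuleType("fake_compiler")
exec(
    "def compile(): pass\n"
    "class IRSource: pass\n"
    "class Hidden: pass\n"
    "__all__ = ['compile', 'IRSource']\n",
    mod.__dict__,
)
sys.modules["fake_compiler"] = mod

ns = {}
exec("from fake_compiler import *", ns)
assert "compile" in ns and "IRSource" in ns
assert "Hidden" not in ns  # defined in the module, but not listed in __all__
```

Forgetting to add a new public name (here, `IRSource`) to `__all__` would silently hide it from star-imports even though a direct `from triton.compiler import IRSource` still works.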
@@ -24,19 +24,13 @@
-# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
-#   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
-# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
-mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
 ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
 prototype_pattern = {
-    "ttir": mlir_prototype_pattern,
-    "ttgir": mlir_prototype_pattern,
     "ptx": ptx_prototype_pattern,
 }

-mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
 ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
 arg_type_pattern = {
-    "ttir": mlir_arg_type_pattern,
-    "ttgir": mlir_arg_type_pattern,
     "ptx": ptx_arg_type_pattern,
 }

Review thread on the removed mlir_prototype_pattern:
"🙌🙌🙌"
"the comment above is stale now. Can you guys remove it in whichever's your next PR (no need to create a PR just for this)"
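As a reminder of what the deleted pattern had to do, the sketch below runs the old `mlir_prototype_pattern` (copied verbatim from the removed code) over a single-line `tt.func` prototype; the sample kernel line is made up for illustration. This kind of hand-rolled parsing is what the PR replaces with native MLIR bindings:

```python
import re

# The removed pattern, copied verbatim from the old compiler.py.
mlir_prototype_pattern = (r"^\s*tt\.func\s+(?:public\s+)?(@\w+)"
                          r"(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))"
                          r"\s*(attributes \{[\S\s]+\})?\s+\{\s*$")

# A made-up single-line prototype for illustration.
line = ("tt.func public @add_kernel(%arg0: !tt.ptr<i32> "
        "{tt.divisibility = 16 : i32}) attributes {noinline = false} {")
match = re.search(mlir_prototype_pattern, line, re.MULTILINE)
assert match is not None
assert match.group(1) == "@add_kernel"      # group 1: kernel name
assert match.group(2).startswith("(%arg0")  # group 2: raw argument list
```

The nested optional groups and the greedy `[\S\s]+` are what made this regex hard to maintain against evolving IR syntax; a real parser sidesteps the whole problem.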
@@ -54,16 +48,6 @@ def convert_type_repr(x):
     return x


-def _get_num_warps_from_ir_str(src: str):
-    ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
-    # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
-    # e.g. someone has an instruction (not module) attribute named "num-warps".
-    num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
-    assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
-    num_warps = int(num_warps_matches[0])
-    return num_warps
-
-
 class ASTSource:

     def __init__(self, fn, signature, constants=None, attrs=None) -> None:
@@ -106,28 +90,41 @@ def parse_options(self):

 class IRSource:

-    def __init__(self, path):
+    def __init__(self, path, context):
         self.path = path
         path = Path(path)
         self.ext = path.suffix[1:]
         self.src = path.read_text()
-        match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
-        self.name = match.group(1)
-        signature = match.group(2)
-        types = re.findall(arg_type_pattern[self.ext], signature)
-        self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
+        ir.load_dialects(context)
+
+        # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
+        # TODO - replace with a proper parser
+        if self.ext == "ptx":
+            match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
+            self.name = match.group(1)
+            signature = match.group(2)
+            types = re.findall(arg_type_pattern[self.ext], signature)
+            self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
+        else:
+            self.module = ir.parse_mlir_module(self.path, context)
+            fn_name = self.module.get_entry_func_name()
+            self.name = "@" + fn_name
+            funcOp = self.module.get_function(fn_name)
+            func_ty = self.module.get_function_signature(funcOp)
+            self.signature = {k: ty for k, ty in enumerate(func_ty)}

     def hash(self):
         return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

     def make_ir(self, options, codegen_fns, module_map, context):
-        module = ir.parse_mlir_module(self.path, context)
-        module.context = context
-        return module
+        self.module.context = context
+        return self.module

     def parse_options(self):
         if self.ext == "ttgir":
-            return {'num_warps': _get_num_warps_from_ir_str(self.src)}
+            num_warps = self.module.get_int_attr("triton_gpu.num-warps")
+            assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
+            return {'num_warps': num_warps}
         return dict()
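One effect of this change is that the MLIR module is parsed exactly once, in `__init__`, and `make_ir` rebinds the context on the cached module instead of re-parsing the file on every call. A toy sketch of that parse-once pattern, with hypothetical stand-ins (`FakeModule`, `parse_mlir_module`, and `CachedIRSource` here are not Triton's real bindings):

```python
# Hypothetical stand-ins, not Triton's real bindings.
parse_calls = 0

class FakeModule:
    def __init__(self, text):
        self.text = text
        self.context = None

def parse_mlir_module(text):
    # Count invocations so the parse-once property is observable.
    global parse_calls
    parse_calls += 1
    return FakeModule(text)

class CachedIRSource:
    """Parse once in __init__; make_ir rebinds the context and returns the cached module."""

    def __init__(self, text, context):
        self.module = parse_mlir_module(text)
        self.module.context = context

    def make_ir(self, context):
        self.module.context = context  # rebind, as the PR's make_ir does
        return self.module

src = CachedIRSource("module {}", context="ctx0")
m1 = src.make_ir("ctx1")
m2 = src.make_ir("ctx2")
assert m1 is m2           # one module object shared across calls
assert parse_calls == 1   # the text was parsed exactly once
```

Caching the module is also what lets `parse_options` read `triton_gpu.num-warps` from the parsed IR rather than from the raw source text.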
@@ -223,7 +220,9 @@ def compile(src, target=None, options=None):
     # create backend
     if ir_source:
         assert isinstance(src, str), "source must be either AST or a filepath"
-        src = IRSource(src)
+        context = ir.context()
+        src = IRSource(src, context)

     extra_options = src.parse_options()
     options = backend.parse_options(dict(options or dict(), **extra_options))
     # create cache manager
@@ -264,9 +263,15 @@ def compile(src, target=None, options=None):
     # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
     if ir_source:
         first_stage += 1
-    context = ir.context()
-    ir.load_dialects(context)
-    backend.load_dialects(context)
+
+    if not isinstance(src, IRSource):
+        context = ir.context()
+        ir.load_dialects(context)
+        backend.load_dialects(context)
+    else:
+        # For IRSource, we have already grabbed the context + called ir.load_dialects
+        # just need to load the dialects for the backend.
+        backend.load_dialects(context)
     codegen_fns = backend.get_codegen_implementation()
     module_map = backend.get_module_map()
     try:
Review comment: nit: maybe that should be in python/test/unit/tools/

Reply: just moved test_irsource.py to the tools dir