[Proton][Dialect] Add Initial Frontend and Target Backend Infrastructure For Proton Dialect #5506
Diff in the Python IR bindings (the file the review thread below refers to as `ir.cc`):

```diff
@@ -31,6 +31,8 @@
 #include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/Support/SourceMgr.h"

+#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
+
 namespace {

 namespace py = pybind11;
@@ -235,7 +237,8 @@ void init_triton_ir(py::module &&m) {
   registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
                   math::MathDialect, arith::ArithDialect, scf::SCFDialect,
                   ::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
-                  LLVM::LLVMDialect, mlir::ub::UBDialect>();
+                  ::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect,
+                  mlir::ub::UBDialect>();
   mlir::LLVM::registerInlinerInterface(registry);
   registerBuiltinDialectTranslation(registry);
   registerLLVMDialectTranslation(registry);
@@ -1603,6 +1606,10 @@ void init_triton_ir(py::module &&m) {
                       llvm::StringRef(prefix));
             self.create<PrintOp>(prefixAttr, hex, values, isSigned);
           })
+      .def("create_proton_record",
+           [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
+             self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
+           })
       .def("create_assert",
            [](TritonOpBuilder &self, Value &condition,
               const std::string &message) -> void {
```

Review conversation on the `create_proton_record` binding:

Reviewer: Let's take the Python API changes out of this PR. It's not determined yet and probably not very important in the short term. Users, like @pawelszczerbuk, can modify the GPU IR to instrument record ops.

Author (CRobeck): This is mostly just to have a template in place for future Proton frontend development to interact with the upper-level Triton Python kernel code. We can always modify the interface later, but how the Proton Python frontend fits together with the Triton IR builder wasn't super intuitive at first, so having this here is helpful for the moment, even if we ultimately change the API in the future. It also gives us some more testing avenues. What do you think?

Reviewer: The concern I have is that we don't have any real functionality associated with these record ops anyway. Usually we add a frontend op when there is at least some functionality implemented. For example, the most recent …

Author: What about putting things in a … ? Then that gives us an easy method to develop and test, in a way that turning them "on" is just a matter of moving them from the …

Reviewer: It's much better now! Let's just move it under … Also I think …

Author: Updated to be under proton.language.

Author: I think this would cause a lot of code disruption. TritonOpBuilder is a unique_ptr wrapper around the MLIR OpBuilder, so no matter what we'd have to find a way to pass that object to triton_proton.cc, either by moving the definition of TritonOpBuilder out of ir.cc into a header file, or possibly through a helper function defined in ir.cc and called from triton_proton.cc, which seems like an equal or greater amount of complexity than the existing code. Maybe we could just invoke the generic MLIR OpBuilder from triton_proton.cc? But then I think we'd be mixing and matching builder objects in the code generator.

Reviewer: I get it now. Let's move it to the last line of all ops and leave a comment: … It's a problem for all third-party backends that they cannot define custom ops. We will have to think about a better solution.

Author: Updated.

Author: Right, there has to be an upper-level Triton op that the backends can specialize, but I could not find a way to have a wholly custom third-party backend op. Given that TritonOpBuilder is defined in ir.cc, the only solution I see is to pull that class out into a standalone module that can be passed to each backend, so that they all operate on the same MLIR builder instance.
Diff in the AMD backend build configuration:

```diff
@@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM
   LINK_LIBS PUBLIC
     TritonGPUToLLVM
     TritonAMDGPUIR
+    TritonProtonToLLVM
 )
```
Diff in the NVIDIA backend build configuration:

```diff
@@ -25,4 +25,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM

   LINK_LIBS PUBLIC
     TritonGPUToLLVM
+    TritonProtonToLLVM
 )
```
New header declaring the lowering-pattern entry point (`@@ -0,0 +1,20 @@`):

```cpp
#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"

using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace triton {
namespace proton {
void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns,
                                   const TargetInfoBase &targetInfo,
                                   PatternBenefit benefit);
} // namespace proton
} // namespace triton
} // namespace mlir

#endif
```
Diff registering the new conversion subdirectory:

```diff
@@ -1 +1,2 @@
 add_subdirectory(Dialect)
+add_subdirectory(TritonProtonToLLVM)
```
New build file for the conversion library (`@@ -0,0 +1,20 @@`):

```cmake
add_triton_library(TritonProtonToLLVM
    RecordOpToLLVM.cpp

    DEPENDS
    TritonGPUConversionPassIncGen

    LINK_LIBS PUBLIC
    MLIRIR
    MLIRPass
    MLIRGPUDialect
    MLIRGPUToNVVMTransforms
    MLIRGPUToROCDLTransforms
    MLIRGPUTransforms
    TritonAnalysis
    TritonIR
    ProtonIR
    TritonGPUIR
    TritonGPUTransforms
    TritonNvidiaGPUTransforms
)
```
New conversion pattern, RecordOpToLLVM.cpp (`@@ -0,0 +1,41 @@`):

```cpp
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"

namespace {

struct RecordOpConversion
    : public ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp> {
  explicit RecordOpConversion(LLVMTypeConverter &typeConverter,
                              const TargetInfoBase &targetInfo,
                              PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp>(
            typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }

protected:
  const TargetInfoBase &targetInfo;
};

} // namespace

void mlir::triton::proton::populateRecordOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  patterns.add<RecordOpConversion>(typeConverter, targetInfo, benefit);
}
```
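The conversion above carries no lowering logic yet: `matchAndRewrite` simply erases the op, so record markers vanish during LLVM lowering. A pure-Python miniature of that "match and erase" behavior (hypothetical helper operating on ops-as-strings, not the MLIR API):

```python
# Hypothetical miniature of the rewrite: a module is a flat list of op strings,
# and the "pattern" for proton.record deletes the op, mirroring
# rewriter.eraseOp(op) in RecordOpConversion.
def erase_record_ops(ops: list) -> list:
    return [op for op in ops if not op.startswith("proton.record")]


module = [
    "tt.load %x",
    "proton.record() {isStart = true, regionId = 0 : i32}",
    "tt.load %y",
    "proton.record() {isStart = false, regionId = 0 : i32}",
    "arith.addf %x, %y",
]
# Compute ops survive; record markers are dropped, as in the real lowering.
print(erase_record_ops(module))
```

In the real pass the pattern runs under the dialect-conversion driver with a benefit and type converter; this sketch only shows the net effect on the op stream.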
Diff in the profiler package exports:

```diff
@@ -7,5 +7,6 @@
     deactivate,
     finalize,
     profile,
+    dev,
     DEFAULT_PROFILE_NAME,
 )
```
New frontend wrapper (`@@ -0,0 +1,6 @@`):

```python
from triton._C.libtriton import ir
from triton.language import core as tl


def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void)
```
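A record op marks either the start (`isStart = true`) or the end (`isStart = false`) of a profiled region identified by `regionId`. The PR does not state a nesting rule, but the pairing semantics the API implies can be sketched with a hypothetical checker that verifies start/end markers bracket properly per region:

```python
# Hypothetical checker (not part of the PR): verifies that record events are
# properly nested, i.e. every end marker closes the most recently opened region.
def records_balanced(events: list) -> bool:
    open_regions = []  # stack of currently open regionIds
    for is_start, region_id in events:
        if is_start:
            open_regions.append(region_id)
        else:
            if not open_regions or open_regions[-1] != region_id:
                return False  # end without matching start
            open_regions.pop()
    return not open_regions  # every opened region must be closed


# Mirrors the test kernel's record(True, 0) ... record(False, 0) pair.
print(records_balanced([(True, 0), (False, 0)]))  # prints True
```

An unmatched start, e.g. `[(True, 0)]`, or a mismatched end, e.g. `[(True, 0), (False, 1)]`, would fail this check.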
New test (`@@ -0,0 +1,41 @@`):

```python
import torch
import pytest
import pathlib

import triton
import triton.language as tl
import triton.profiler as proton


def test_proton_record(tmp_path: pathlib.Path):

    @triton.jit
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        proton.dev.record(True, 0)
        y = tl.load(y_ptr + offsets, mask=mask)
        proton.dev.record(False, 0)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)

    torch.manual_seed(0)
    size = 2**12
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = (1, 1, 1)
    pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    ttir = pgm.asm['ttir']
    assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir
    assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir
```
Review conversation:

Reviewer: This can be taken out if we don't consider the frontend changes for now.

Reviewer: Maybe we don't need any conversion here anyway, because Proton's inputs are scalars, not tensors. Could you please double check? @CRobeck

Author (CRobeck): They are scalars, but I could envision a case where we pass in tensors in the future, for example passing in a Triton tensor object for object-specific tracing. I don't think the conversion hurts anything here, but I acknowledge it could potentially be misleading. You're right, though, that I can remove this entire function (populateProtonPatterns) and the test still passes.