
[Proton][Dialect] Add Initial Frontend and Target Backend Infrastructure For Proton Dialect #5506

Merged (30 commits), Jan 6, 2025
Changes from 12 commits
@@ -120,7 +120,12 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

namespace proton {
void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);
} // namespace proton
} // namespace triton
} // namespace mlir

2 changes: 2 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -20,6 +20,7 @@ add_triton_library(TritonGPUToLLVM
SPMDOpToLLVM.cpp
DecomposeUnsupportedConversions.cpp
PrintOpToLLVM.cpp
RecordOpToLLVM.cpp

DEPENDS
TritonGPUConversionPassIncGen
@@ -33,6 +34,7 @@ add_triton_library(TritonGPUToLLVM
MLIRGPUTransforms
TritonAnalysis
TritonIR
ProtonIR
TritonGPUIR
TritonGPUTransforms
TritonNvidiaGPUTransforms
40 changes: 40 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp
@@ -0,0 +1,40 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"

namespace {

struct RecordOpConversion
: public ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp> {
explicit RecordOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit)
: mlir::ConvertOpToLLVMPattern<mlir::triton::proton::RecordOp>(
typeConverter, benefit),
targetInfo(targetInfo) {}

LogicalResult
matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

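// Placeholder lowering: the record op is simply erased for now; no profiling
// code is emitted yet.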
rewriter.eraseOp(op);
return success();
}

protected:
const TargetInfoBase &targetInfo;
};

} // namespace

void mlir::triton::proton::populateRecordOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
patterns.add<RecordOpConversion>(typeConverter, targetInfo, benefit);
}
1 change: 1 addition & 0 deletions lib/Conversion/TritonToTritonGPU/CMakeLists.txt
@@ -12,4 +12,5 @@ add_triton_library(TritonToTritonGPU
TritonIR
TritonGPUIR
TritonGPUTransforms
ProtonIR
)
11 changes: 10 additions & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -16,6 +16,8 @@
#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"

namespace {

using namespace mlir;
@@ -555,7 +557,13 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
TritonFuncOpPattern>(typeConverter, context);
}

// Proton patterns
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,

Contributor:
This can be taken out if we don't consider the frontend changes for now

Jokeren (Contributor), Dec 30, 2024:
Maybe we don't need any conversion here anyway because proton's inputs are scalars but not tensors. Could you please double check? @CRobeck

CRobeck (Author), Dec 30, 2024:
They are scalars, but we could envision passing in tensors in the future, for example a Triton tensor object for object-specific tracing. I don't think the conversion hurts anything here, but I acknowledge it could potentially be misleading.

You're right though that I can remove this entire function (populateProtonPatterns) and the test still passes.

context);
}
//
// SCF patterns
//
@@ -770,6 +778,7 @@ class ConvertTritonToTritonGPU
populateArithPatternsAndLegality(typeConverter, patterns, target);
populateMathPatternsAndLegality(typeConverter, patterns, target);
populateTritonPatterns(typeConverter, patterns, numCTAs);
populateProtonPatterns(typeConverter, patterns);
// TODO: can we use
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
9 changes: 8 additions & 1 deletion python/src/ir.cc
@@ -31,6 +31,8 @@
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/SourceMgr.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"

namespace {

namespace py = pybind11;
@@ -235,7 +237,8 @@ void init_triton_ir(py::module &&m) {
registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
math::MathDialect, arith::ArithDialect, scf::SCFDialect,
::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
LLVM::LLVMDialect, mlir::ub::UBDialect>();
::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect,
mlir::ub::UBDialect>();
mlir::LLVM::registerInlinerInterface(registry);
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
@@ -1603,6 +1606,10 @@ void init_triton_ir(py::module &&m) {
llvm::StringRef(prefix));
self.create<PrintOp>(prefixAttr, hex, values, isSigned);
})
.def("create_proton_record",

Contributor:
Let's take the Python API changes out of this PR. The API isn't settled yet and is probably not very important in the short term. Users, like @pawelszczerbuk, can modify the GPU IR to instrument record ops.

CRobeck (Author), Dec 28, 2024:
This is mostly just to have a template in place for future Proton frontend development to interact with the upper-level Triton Python kernel code. We can always modify the interface later, but how the Proton Python frontend fits together with the Triton IR builder wasn't super intuitive at first, so I think having this here is helpful for the moment even if we ultimately change the API. It also gives us some more testing avenues. What do you think?

Contributor:
The concern I have is that we don't have any real functionality associated with these record ops yet. Usually we add a frontend op only once at least some functionality is implemented; take the recent tl.gather op, for example. Although its initial commit doesn't address every problem and concern, frontend users can actually use that op in their kernels rather than just place placeholders. So IMHO it makes more sense to add the op, and mark it as experimental, after the lowering has been implemented.

CRobeck (Author), Dec 29, 2024:
What about putting things in a dev module, with warnings that they are under development and not intended for use outside the development team:

https://github.com/triton-lang/triton/pull/5506/files#diff-054c0ddf64263bf99d77a237a61deeb79c0cd3e4289a03eba948ce833f1cdce0R17

https://github.com/triton-lang/triton/pull/5506/files#diff-bc216e7a68b5a4545f1ff69bfbf5d0fa64adb6b1acaca6ea0f366f47e93942ebR25

That would give us an easy way to develop and test, where turning them "on" is just a matter of moving them from the dev module to the experimental namespace?

Contributor:
It's much better now!

Let's just move it under proton.language without using the dev class.

Contributor:
Also I think create_proton_record can be moved to triton_proton.cc

CRobeck (Author), Dec 30, 2024:
Updated to be under proton.language.

> Also I think create_proton_record can be moved to triton_proton.cc

I think this would cause a lot of code disruption. TritonOpBuilder is a unique_ptr wrapper around the MLIR OpBuilder, so no matter what we'd have to find a way to pass that object to triton_proton.cc, either by moving the definition of TritonOpBuilder out of ir.cc into a header file or through a helper function defined in ir.cc and called from triton_proton.cc, which seems like equal or greater complexity than the existing code.

Maybe we can just invoke the generic MLIR OpBuilder from triton_proton.cc? But then I think we'd be mixing and matching builder objects in the code generator?

Contributor:
I get it now. Let's move it to the last line of all ops and leave a comment:

// proton ops
.def("create_proton_record",
...

It's a problem for all third party backends that they cannot define custom ops. We will have to think about a better solution.

CRobeck (Author), Dec 31, 2024:
Updated.

> It's a problem for all third party backends that they cannot define custom ops. We will have to think about a better solution.

Right, there has to be an upper-level Triton op that the backends can specialize, but I could not find a way to have a wholly custom third-party backend op. Given that TritonOpBuilder is defined in ir.cc, the only solution I see is to pull that class out into a standalone module that can be passed to each backend, so that they all operate on the same MLIR builder instance.
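
A minimal sketch of that idea, as a hypothetical illustration only (the shared header name, the helper function addProtonOpBuilderMethods, and the wiring from ir.cc are all assumptions, not code from this PR):

// Hypothetical: assumes TritonOpBuilder has been extracted from ir.cc into a
// shared header that third-party modules can include.
#include "TritonOpBuilder.h" // assumed shared header; does not exist today
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;

// ir.cc would pass its already-registered builder class here, so every backend
// adds its op bindings against the same MLIR builder instance.
void addProtonOpBuilderMethods(py::class_<TritonOpBuilder> &builderCls) {
  builderCls.def("create_proton_record",
                 [](TritonOpBuilder &self, bool isStart, int32_t regionId) {
                   self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
                 });
}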

[](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
})
.def("create_assert",
[](TritonOpBuilder &self, Value &condition,
const std::string &message) -> void {
4 changes: 4 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -228,6 +228,10 @@ struct ConvertTritonAMDGPUToLLVM
patterns);
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, commonBenefit);

mlir::triton::proton::populateRecordOpToLLVMPattern(
typeConverter, patterns, targetInfo, commonBenefit);

mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
@@ -149,6 +149,8 @@ struct ConvertTritonGPUToLLVM
targetInfo, benefit);
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
targetInfo, benefit);
mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns,
1 change: 1 addition & 0 deletions third_party/proton/proton/__init__.py
@@ -7,5 +7,6 @@
deactivate,
finalize,
profile,
record,
DEFAULT_PROFILE_NAME,
)
6 changes: 6 additions & 0 deletions third_party/proton/proton/language.py
@@ -0,0 +1,6 @@
from triton._C.libtriton import ir
from triton.language import core as tl


def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor:
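# Emit a proton.record op: isStart marks the start (True) or end (False) of the profiling region identified by regionId.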
return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void)
8 changes: 8 additions & 0 deletions third_party/proton/proton/profile.py
@@ -3,13 +3,21 @@
import os

from triton._C.libproton import proton as libproton
from triton.language import core as tl
from triton.language.core import builtin
from .hook import register_triton_hook, unregister_triton_hook
from .flags import set_profiling_off, set_profiling_on, is_command_line
from typing import Optional
from . import language

DEFAULT_PROFILE_NAME = "proton"


@builtin
def record(isStart: bool, regionId: int, _builder=None):
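# Inserts a proton.record op into the kernel being compiled (see proton.language.proton_record).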
return language.proton_record(isStart, regionId, _builder)


def _select_backend() -> str:
backend = triton.runtime.driver.active.get_current_target().backend
if backend == "cuda":
41 changes: 41 additions & 0 deletions third_party/proton/test/test_proton_record.py
@@ -0,0 +1,41 @@
import torch
import pytest
import pathlib

import triton
import triton.language as tl
import triton.profiler as proton


def test_proton_record(tmp_path: pathlib.Path):

@triton.jit
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
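# Bracket the second load with proton region 0: start marker before, stop marker after.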
proton.record(True, 0)
y = tl.load(y_ptr + offsets, mask=mask)
proton.record(False, 0)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)

torch.manual_seed(0)
size = 2**12
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output = torch.empty_like(x)
n_elements = output.numel()
grid = (1, 1, 1)
pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
ttir = pgm.asm['ttir']
assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir
assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir