From c27f8945d6d186da48da62cf814e1532d6b17818 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Wed, 25 Dec 2024 19:01:51 +0000 Subject: [PATCH 01/29] temp --- .../PatternTritonGPUOpToLLVM.h | 5 ++ lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 2 + .../TritonGPUToLLVM/RecordOpToLLVM.cpp | 40 +++++++++++ .../TritonToTritonGPU/CMakeLists.txt | 1 + .../TritonToTritonGPUPass.cpp | 5 +- python/src/ir.cc | 16 ++++- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 4 ++ .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 + third_party/proton/CMakeLists.txt | 12 +++- third_party/proton/csrc/Proton.cpp | 3 + third_party/proton/dialect/CMakeLists.txt | 2 +- .../lib/Dialect/Proton/IR/CMakeLists.txt | 2 + third_party/proton/dialect/triton_proton.cc | 51 ++++++++++++- third_party/proton/proton/__init__.py | 1 + third_party/proton/proton/profile.py | 43 ++++++++++- third_party/proton/proton/semantic.py | 26 +++++++ third_party/proton/test/test_proton_record.py | 71 +++++++++++++++++++ 17 files changed, 278 insertions(+), 8 deletions(-) create mode 100644 lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp create mode 100644 third_party/proton/proton/semantic.py create mode 100644 third_party/proton/test/test_proton_record.py diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index c5c78e6d5b20..79335273e520 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -121,6 +121,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index d6cc4387f79e..96b60bf01ac7 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -20,6 +20,7 @@ add_triton_library(TritonGPUToLLVM SPMDOpToLLVM.cpp DecomposeUnsupportedConversions.cpp PrintOpToLLVM.cpp + RecordOpToLLVM.cpp DEPENDS TritonGPUConversionPassIncGen @@ -33,6 +34,7 @@ add_triton_library(TritonGPUToLLVM MLIRGPUTransforms TritonAnalysis TritonIR + ProtonIR TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms diff --git a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp new file mode 100644 index 000000000000..40ee63c4a6a0 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp @@ -0,0 +1,40 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" + + +namespace { + +struct RecordOpConversion : public ConvertOpToLLVMPattern { + explicit RecordOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(mlir::triton::proton::RecordOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.eraseOp(op); + return success(); + } + + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateRecordOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 1b629ba1639f..1438ec75c47d 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -12,4 +12,5 @@ add_triton_library(TritonToTritonGPU TritonIR TritonGPUIR TritonGPUTransforms + ProtonIR ) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 67ab63beb736..e45f6918e599 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -16,6 +16,8 @@ #include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" + namespace { using namespace mlir; @@ -544,7 +546,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, - GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/python/src/ir.cc b/python/src/ir.cc index 53ba39ae1026..9fbcb49ab89a 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -31,6 +31,8 @@ #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/SourceMgr.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" + namespace { namespace py = pybind11; @@ -234,7 +236,7 @@ void init_triton_ir(py::module &&m) { DialectRegistry registry; registry.insert(); mlir::LLVM::registerInlinerInterface(registry); registerBuiltinDialectTranslation(registry); @@ -1603,6 +1605,18 @@ void init_triton_ir(py::module &&m) { llvm::StringRef(prefix)); self.create(prefixAttr, hex, values, isSigned); }) + .def("create_record", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); + }) + .def("create_proton_record", + [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void{ + self.create(isStart, regionId); + }) .def("create_assert", [](TritonOpBuilder &self, Value &condition, const std::string &message) -> void { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 0e29b0c00d2b..4564176059f4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -228,6 +228,10 @@ struct ConvertTritonAMDGPUToLLVM patterns); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); + + mlir::triton::populateRecordOpToLLVMPattern(typeConverter, patterns, + targetInfo, commonBenefit); + mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 089e4aaebb2b..444017fe8453 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -149,6 +149,8 @@ struct ConvertTritonGPUToLLVM targetInfo, benefit); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); + mlir::triton::populateRecordOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns, diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index e0fafb43a929..de5fe978a730 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -2,8 +2,11 @@ project(Proton CXX) set(PROTON_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc) set(PROTON_EXTERN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/extern) +set(PROTON_DIALECT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/dialect) file(GLOB_RECURSE PROTON_SRC ${PROTON_SRC_DIR}/lib/*.cpp) -add_library(proton SHARED ${PROTON_SRC} ${PROTON_SRC_DIR}/${PROJECT_NAME}.cpp) +add_library(proton SHARED ${PROTON_SRC} ${PROTON_SRC_DIR}/${PROJECT_NAME}.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dialect/triton_proton.cc + ) if(NOT CUPTI_INCLUDE_DIR) message(FATAL_ERROR "CUPTI include directory not defined") @@ -17,6 +20,7 @@ endif() include_directories(${JSON_INCLUDE_DIR}) include_directories(${PROTON_SRC_DIR}/include) +include_directories(${PROTON_DIALECT_DIR}/include) include_directories(${PROTON_EXTERN_DIR}) find_package(Python3 REQUIRED Interpreter Development.Module) @@ -38,5 +42,9 @@ include_directories(${CUPTI_INCLUDE_DIR}) include_directories(SYSTEM ${ROCTRACER_INCLUDE_DIR}) target_compile_definitions(proton PRIVATE __HIP_PLATFORM_AMD__) -target_link_libraries(proton PRIVATE Python3::Module pybind11::headers) +target_link_libraries(proton PRIVATE Python3::Module pybind11::headers LLVMSupport MLIRIR ProtonIR TritonGPUIR TritonIR MLIRGPUDialect + MLIRTransforms + MLIRTransformUtils + TritonGPUTransforms + ) target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS}) diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index 7c1e07bf3d9b..c2a19d950005 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -88,7 +88,10 @@ void initProton(pybind11::module &&m) { pybind11::bind_map>(m, "MetricMap"); } +void init_triton_proton(pybind11::module &&m); + PYBIND11_MODULE(libproton, m) { m.doc() = "Python bindings to the Proton API"; initProton(std::move(m.def_submodule("proton"))); + init_triton_proton(std::move(m.def_submodule("ttproton"))); } diff --git a/third_party/proton/dialect/CMakeLists.txt b/third_party/proton/dialect/CMakeLists.txt index cfa5938873d9..ecf525b2137d 100644 --- a/third_party/proton/dialect/CMakeLists.txt +++ b/third_party/proton/dialect/CMakeLists.txt @@ -4,5 +4,5 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc) - target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers) + target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers LLVMSupport MLIRIR) endif() diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt index 5eea5cb3cf9e..2cfe9e8e31e1 100644 --- a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -7,6 +7,8 @@ add_triton_library(ProtonIR ProtonAttrDefsIncGen LINK_LIBS PUBLIC + LLVMSupport + MLIRIR MLIRLLVMDialect TritonIR TritonGPUIR diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc index 8046539794e1..f717ab073338 100644 --- a/third_party/proton/dialect/triton_proton.cc +++ b/third_party/proton/dialect/triton_proton.cc @@ -1,15 +1,59 @@ #include "Dialect/Proton/IR/Dialect.h" #include "mlir/Pass/PassManager.h" #include "passes.h" + #include -#include -#include namespace py = pybind11; +namespace { + +using namespace mlir; +using namespace triton; + +class ProtonOpBuilder { +public: + ProtonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + MLIRContext *getContext() { return builder->getContext(); } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; +}; + +} // anonymous namespace + + + void init_triton_proton(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + using namespace ::mlir::triton::proton; auto passes = m.def_submodule("passes"); + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + + .def("create_record", + [](ProtonOpBuilder &self, bool isStart, int32_t regionId) -> void{ + self.create(isStart, regionId); + }); + // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; @@ -17,4 +61,7 @@ void init_triton_proton(py::module &&m) { context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); + + } + diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index ded8b01142af..e3766965fd18 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -7,5 +7,6 @@ deactivate, finalize, profile, + record, DEFAULT_PROFILE_NAME, ) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 575c85b0cac8..d71015a84540 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -3,11 +3,52 @@ import os from triton._C.libproton import proton as libproton +from triton.language import core as tl +import triton.language +#from triton._C.libtriton import ir from .hook import register_triton_hook, unregister_triton_hook +from functools import wraps from .flags import set_profiling_off, set_profiling_on, is_command_line -from typing import Optional +from typing import Optional, TypeVar +from . import semantic +import builtins + +T = TypeVar('T') DEFAULT_PROFILE_NAME = "proton" +TRITON_BUILTIN = "__triton_builtin__" + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + +@builtin +def record(prefix, *args, hex=False, _builder=None): + import string + prefix = tl._constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(triton.language.semantic.to_tensor(arg, _builder)) + return semantic.proton_record(True, 0, _builder) +# return semantic.record(prefix, new_args, hex, _builder) def _select_backend() -> str: diff --git a/third_party/proton/proton/semantic.py b/third_party/proton/proton/semantic.py new file mode 100644 index 000000000000..c0a2368f1ed8 --- /dev/null +++ b/third_party/proton/proton/semantic.py @@ -0,0 +1,26 @@ +from __future__ import annotations # remove after python 3.11 +import warnings + +from typing import List, Optional, Sequence, Tuple, TypeVar +import numbers + +from triton._C.libtriton import ir +from triton._C.libproton import ttproton +#import triton._C.libproton as proton +from triton.language import core as tl + +def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) + +def record(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + new_args = [arg.handle for arg in args] + is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] + return tl.tensor(builder.create_record(prefix, hex, new_args, is_signed), tl.void) diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py new file mode 100644 index 000000000000..089ea9959696 --- /dev/null +++ b/third_party/proton/test/test_proton_record.py @@ -0,0 +1,71 @@ +import torch + +import triton +import triton.language as tl +import triton.profiler as proton + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + proton.record("x: ", x, hex=True) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# %% +# Let's also declare a helper function to (1) allocate the `z` tensor +# and (2) enqueue the above kernel with appropriate grid/block sizes: + + +def add(x: torch.Tensor, y: torch.Tensor): + # We need to preallocate the output. + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + # The SPMD launch grid denotes the number of kernel instances that run in parallel. + # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks: + BLOCK = 4096 + grid = (int(n_elements / BLOCK), 1, 1) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + ttir = pgm.asm['ttir'] +# print(ttir) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. + return output + +torch.manual_seed(0) +size = 2**12 +x = torch.rand(size, device='cuda') +y = torch.rand(size, device='cuda') + +output_triton = add(x, y) +print(output_triton) From 7496112ff9fb3cbe49173cdcc6a7cc61943b2a3f Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 02:14:08 +0000 Subject: [PATCH 02/29] update --- python/src/ir.cc | 16 ++-- third_party/proton/dialect/triton_proton.cc | 80 +++++++++---------- third_party/proton/proton/profile.py | 17 +--- third_party/proton/proton/semantic.py | 26 +++--- third_party/proton/test/test_proton_record.py | 4 +- 5 files changed, 65 insertions(+), 78 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index 9fbcb49ab89a..4936b9ecef76 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1605,14 +1605,14 @@ void init_triton_ir(py::module &&m) { llvm::StringRef(prefix)); self.create(prefixAttr, hex, values, isSigned); }) - .def("create_record", - [](TritonOpBuilder &self, const std::string &prefix, bool hex, - const std::vector &values, - const std::vector &isSigned) -> void { - auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), - llvm::StringRef(prefix)); - self.create(prefixAttr, hex, values, isSigned); - }) +// .def("create_record", +// [](TritonOpBuilder &self, const std::string &prefix, bool hex, +// const std::vector &values, +// const std::vector &isSigned) -> void { +// auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), +// llvm::StringRef(prefix)); +// self.create(prefixAttr, hex, values, isSigned); +// }) .def("create_proton_record", [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void{ self.create(isStart, regionId); diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc index f717ab073338..0007aaf80455 100644 --- a/third_party/proton/dialect/triton_proton.cc +++ b/third_party/proton/dialect/triton_proton.cc @@ -6,54 +6,54 @@ namespace py = pybind11; -namespace { - -using namespace mlir; -using namespace triton; - -class ProtonOpBuilder { -public: - ProtonOpBuilder(MLIRContext *context) { - builder = std::make_unique(context); - lastLoc = std::make_unique(builder->getUnknownLoc()); - } - - OpBuilder &getBuilder() { return *builder; } - MLIRContext *getContext() { return builder->getContext(); } - - Location getLastLoc() { - assert(lastLoc); - return *lastLoc; - } - template OpTy create(Args &&...args) { - auto loc = getLastLoc(); - return builder->create(loc, std::forward(args)...); - } - -private: - std::unique_ptr builder; - std::unique_ptr lastLoc; -}; - -} // anonymous namespace +//namespace { +// +//using namespace mlir; +//using namespace triton; +// +//class ProtonOpBuilder { +//public: +// ProtonOpBuilder(MLIRContext *context) { +// builder = std::make_unique(context); +// lastLoc = std::make_unique(builder->getUnknownLoc()); +// } +// +// OpBuilder &getBuilder() { return *builder; } +// MLIRContext *getContext() { return builder->getContext(); } +// +// Location getLastLoc() { +// assert(lastLoc); +// return *lastLoc; +// } +// template OpTy create(Args &&...args) { +// auto loc = getLastLoc(); +// return builder->create(loc, std::forward(args)...); +// } +// +//private: +// std::unique_ptr builder; +// std::unique_ptr lastLoc; +//}; +// +//} // anonymous namespace void init_triton_proton(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; - using namespace ::mlir::triton::proton; +// using namespace ::mlir::triton::proton; auto passes = m.def_submodule("passes"); - py::class_(m, "builder", py::module_local(), - py::dynamic_attr()) - .def(py::init()) - - .def("create_record", - [](ProtonOpBuilder &self, bool isStart, int32_t regionId) -> void{ - self.create(isStart, regionId); - }); - +// py::class_(m, "builder", py::module_local(), +// py::dynamic_attr()) +// .def(py::init()) +// +// .def("create_record", +// [](ProtonOpBuilder &self, bool isStart, int32_t regionId) -> void{ +// self.create(isStart, regionId); +// }); +// // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index d71015a84540..a8b6137758fa 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -34,21 +34,8 @@ def wrapper(*args, **kwargs): return wrapper @builtin -def record(prefix, *args, hex=False, _builder=None): - import string - prefix = tl._constexpr_to_value(prefix) - assert isinstance(prefix, str), f"{prefix} is not string" - b_ascii = True - for ch in prefix: - if ch not in string.printable: - b_ascii = False - break - assert b_ascii, f"{prefix} is not an ascii string" - new_args = [] - for arg in args: - new_args.append(triton.language.semantic.to_tensor(arg, _builder)) - return semantic.proton_record(True, 0, _builder) -# return semantic.record(prefix, new_args, hex, _builder) +def record(isStart: bool, regionId: int, _builder=None): + return semantic.proton_record(isStart, regionId, _builder) def _select_backend() -> str: diff --git a/third_party/proton/proton/semantic.py b/third_party/proton/proton/semantic.py index c0a2368f1ed8..46a79339d818 100644 --- a/third_party/proton/proton/semantic.py +++ b/third_party/proton/proton/semantic.py @@ -5,22 +5,22 @@ import numbers from triton._C.libtriton import ir -from triton._C.libproton import ttproton +#from triton._C.libproton import ttproton #import triton._C.libproton as proton from triton.language import core as tl def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) -def record(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: - # It makes sense visually for prefix to end in ": "; make it so. Also, - # non-empty prefixes should start with " ". - if not prefix.endswith(" ") and args: - prefix += " " - if not prefix.endswith(": ") and args: - prefix = prefix[:-1] + ": " - if len(prefix) > 2 and not prefix.startswith(" "): - prefix = " " + prefix - new_args = [arg.handle for arg in args] - is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] - return tl.tensor(builder.create_record(prefix, hex, new_args, is_signed), tl.void) +#def record(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: +# # It makes sense visually for prefix to end in ": "; make it so. Also, +# # non-empty prefixes should start with " ". +# if not prefix.endswith(" ") and args: +# prefix += " " +# if not prefix.endswith(": ") and args: +# prefix = prefix[:-1] + ": " +# if len(prefix) > 2 and not prefix.startswith(" "): +# prefix = " " + prefix +# new_args = [arg.handle for arg in args] +# is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] +# return tl.tensor(builder.create_record(prefix, hex, new_args, is_signed), tl.void) diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index 089ea9959696..bb6aab0b3210 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -29,7 +29,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. # Load x and y from DRAM, masking out any extra elements in case the input is not a # multiple of the block size. x = tl.load(x_ptr + offsets, mask=mask) - proton.record("x: ", x, hex=True) + proton.record(True, 0) y = tl.load(y_ptr + offsets, mask=mask) output = x + y # Write x + y back to DRAM. @@ -57,7 +57,7 @@ def add(x: torch.Tensor, y: torch.Tensor): # - Don't forget to pass meta-parameters as keywords arguments. pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) ttir = pgm.asm['ttir'] -# print(ttir) + print(ttir) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output From e47e4c9439a4eda2341fee3da2c1eb42ff86e2f4 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 02:36:05 +0000 Subject: [PATCH 03/29] update --- third_party/proton/csrc/Proton.cpp | 3 -- third_party/proton/dialect/triton_proton.cc | 42 --------------------- third_party/proton/proton/language.py | 9 +++++ third_party/proton/proton/profile.py | 5 +-- third_party/proton/proton/semantic.py | 26 ------------- 5 files changed, 11 insertions(+), 74 deletions(-) create mode 100644 third_party/proton/proton/language.py delete mode 100644 third_party/proton/proton/semantic.py diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index c2a19d950005..7c1e07bf3d9b 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -88,10 +88,7 @@ void initProton(pybind11::module &&m) { pybind11::bind_map>(m, "MetricMap"); } -void init_triton_proton(pybind11::module &&m); - PYBIND11_MODULE(libproton, m) { m.doc() = "Python bindings to the Proton API"; initProton(std::move(m.def_submodule("proton"))); - init_triton_proton(std::move(m.def_submodule("ttproton"))); } diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc index 0007aaf80455..5ff9fa0f00c2 100644 --- a/third_party/proton/dialect/triton_proton.cc +++ b/third_party/proton/dialect/triton_proton.cc @@ -6,54 +6,12 @@ namespace py = pybind11; -//namespace { -// -//using namespace mlir; -//using namespace triton; -// -//class ProtonOpBuilder { -//public: -// ProtonOpBuilder(MLIRContext *context) { -// builder = std::make_unique(context); -// lastLoc = std::make_unique(builder->getUnknownLoc()); -// } -// -// OpBuilder &getBuilder() { return *builder; } -// MLIRContext *getContext() { return builder->getContext(); } -// -// Location getLastLoc() { -// assert(lastLoc); -// return *lastLoc; -// } -// template OpTy create(Args &&...args) { -// auto loc = getLastLoc(); -// return builder->create(loc, std::forward(args)...); -// } -// -//private: -// std::unique_ptr builder; -// std::unique_ptr lastLoc; -//}; -// -//} // anonymous namespace - - void init_triton_proton(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; -// using namespace ::mlir::triton::proton; auto passes = m.def_submodule("passes"); -// py::class_(m, "builder", py::module_local(), -// py::dynamic_attr()) -// .def(py::init()) -// -// .def("create_record", -// [](ProtonOpBuilder &self, bool isStart, int32_t regionId) -> void{ -// self.create(isStart, regionId); -// }); -// // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py new file mode 100644 index 000000000000..f92a0b1a1d03 --- /dev/null +++ b/third_party/proton/proton/language.py @@ -0,0 +1,9 @@ +from __future__ import annotations # remove after python 3.11 +import warnings + +from triton._C.libtriton import ir +from triton.language import core as tl + +def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) + diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index a8b6137758fa..80ad1e48d409 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -5,12 +5,11 @@ from triton._C.libproton import proton as libproton from triton.language import core as tl import triton.language -#from triton._C.libtriton import ir from .hook import register_triton_hook, unregister_triton_hook from functools import wraps from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional, TypeVar -from . import semantic +from . import language import builtins T = TypeVar('T') @@ -35,7 +34,7 @@ def wrapper(*args, **kwargs): @builtin def record(isStart: bool, regionId: int, _builder=None): - return semantic.proton_record(isStart, regionId, _builder) + return language.proton_record(isStart, regionId, _builder) def _select_backend() -> str: diff --git a/third_party/proton/proton/semantic.py b/third_party/proton/proton/semantic.py deleted file mode 100644 index 46a79339d818..000000000000 --- a/third_party/proton/proton/semantic.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations # remove after python 3.11 -import warnings - -from typing import List, Optional, Sequence, Tuple, TypeVar -import numbers - -from triton._C.libtriton import ir -#from triton._C.libproton import ttproton -#import triton._C.libproton as proton -from triton.language import core as tl - -def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) - -#def record(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: -# # It makes sense visually for prefix to end in ": "; make it so. Also, -# # non-empty prefixes should start with " ". -# if not prefix.endswith(" ") and args: -# prefix += " " -# if not prefix.endswith(": ") and args: -# prefix = prefix[:-1] + ": " -# if len(prefix) > 2 and not prefix.startswith(" "): -# prefix = " " + prefix -# new_args = [arg.handle for arg in args] -# is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] -# return tl.tensor(builder.create_record(prefix, hex, new_args, is_signed), tl.void) From 435fbe65c31aa1b76289a37598a07d3568365325 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 02:59:34 +0000 Subject: [PATCH 04/29] update --- python/src/ir.cc | 8 -------- third_party/proton/CMakeLists.txt | 14 +++----------- third_party/proton/dialect/CMakeLists.txt | 2 +- .../dialect/lib/Dialect/Proton/IR/CMakeLists.txt | 2 -- third_party/proton/dialect/triton_proton.cc | 9 ++------- third_party/proton/test/test_proton_record.py | 1 + 6 files changed, 7 insertions(+), 29 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index 4936b9ecef76..99143c019204 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1605,14 +1605,6 @@ void init_triton_ir(py::module &&m) { llvm::StringRef(prefix)); self.create(prefixAttr, hex, values, isSigned); }) -// .def("create_record", -// [](TritonOpBuilder &self, const std::string &prefix, bool hex, -// const std::vector &values, -// const std::vector &isSigned) -> void { -// auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), -// llvm::StringRef(prefix)); -// self.create(prefixAttr, hex, values, isSigned); -// }) .def("create_proton_record", [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void{ self.create(isStart, regionId); diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index de5fe978a730..d36fd594f5c7 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -2,12 +2,8 @@ project(Proton CXX) set(PROTON_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc) set(PROTON_EXTERN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/extern) -set(PROTON_DIALECT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/dialect) file(GLOB_RECURSE PROTON_SRC ${PROTON_SRC_DIR}/lib/*.cpp) -add_library(proton SHARED ${PROTON_SRC} ${PROTON_SRC_DIR}/${PROJECT_NAME}.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/dialect/triton_proton.cc - ) - +add_library(proton SHARED ${PROTON_SRC} ${PROTON_SRC_DIR}/${PROJECT_NAME}.cpp) if(NOT CUPTI_INCLUDE_DIR) message(FATAL_ERROR "CUPTI include directory not defined") endif() @@ -20,7 +16,6 @@ endif() include_directories(${JSON_INCLUDE_DIR}) include_directories(${PROTON_SRC_DIR}/include) -include_directories(${PROTON_DIALECT_DIR}/include) include_directories(${PROTON_EXTERN_DIR}) find_package(Python3 REQUIRED Interpreter Development.Module) @@ -42,9 +37,6 @@ include_directories(${CUPTI_INCLUDE_DIR}) include_directories(SYSTEM ${ROCTRACER_INCLUDE_DIR}) target_compile_definitions(proton PRIVATE __HIP_PLATFORM_AMD__) -target_link_libraries(proton PRIVATE Python3::Module pybind11::headers LLVMSupport MLIRIR ProtonIR TritonGPUIR TritonIR MLIRGPUDialect - MLIRTransforms - MLIRTransformUtils - TritonGPUTransforms - ) +target_link_libraries(proton PRIVATE Python3::Module pybind11::headers) + target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS}) diff --git a/third_party/proton/dialect/CMakeLists.txt b/third_party/proton/dialect/CMakeLists.txt index ecf525b2137d..cfa5938873d9 100644 --- a/third_party/proton/dialect/CMakeLists.txt +++ b/third_party/proton/dialect/CMakeLists.txt @@ -4,5 +4,5 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc) - target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers LLVMSupport MLIRIR) + target_link_libraries(TritonProton PRIVATE ProtonIR Python3::Module pybind11::headers) endif() diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt index 2cfe9e8e31e1..5eea5cb3cf9e 100644 --- a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -7,8 +7,6 @@ add_triton_library(ProtonIR ProtonAttrDefsIncGen LINK_LIBS PUBLIC - LLVMSupport - MLIRIR MLIRLLVMDialect TritonIR TritonGPUIR diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc index 5ff9fa0f00c2..8046539794e1 100644 --- a/third_party/proton/dialect/triton_proton.cc +++ b/third_party/proton/dialect/triton_proton.cc @@ -1,15 +1,13 @@ #include "Dialect/Proton/IR/Dialect.h" #include "mlir/Pass/PassManager.h" #include "passes.h" - #include +#include +#include namespace py = pybind11; - void init_triton_proton(py::module &&m) { - using ret = py::return_value_policy; - using namespace pybind11::literals; auto passes = m.def_submodule("passes"); // load dialects @@ -19,7 +17,4 @@ void init_triton_proton(py::module &&m) { context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); - - } - diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index bb6aab0b3210..f3f76e9582e1 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -58,6 +58,7 @@ def add(x: torch.Tensor, y: torch.Tensor): pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) ttir = pgm.asm['ttir'] print(ttir) + assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output From a33f5138b31b57d638a1e9e16b6bdf2a11e6ac9b Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 03:15:53 +0000 Subject: [PATCH 05/29] update --- .../Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h | 4 ++-- lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp | 2 +- third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- .../nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- third_party/proton/CMakeLists.txt | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 79335273e520..f9d50ed4e065 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -120,12 +120,12 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); - +namespace proton { void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); - +} // namespace proton } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp index 40ee63c4a6a0..2243d129910b 100644 --- a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp @@ -33,7 +33,7 @@ struct RecordOpConversion : public ConvertOpToLLVMPattern(typeConverter, targetInfo, benefit); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 4564176059f4..66d1d8b8e98f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -229,7 +229,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); - mlir::triton::populateRecordOpToLLVMPattern(typeConverter, patterns, + mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 444017fe8453..901ba65a6966 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -149,7 +149,7 @@ struct ConvertTritonGPUToLLVM targetInfo, benefit); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); - mlir::triton::populateRecordOpToLLVMPattern(typeConverter, patterns, + mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index d36fd594f5c7..e0fafb43a929 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -4,6 +4,7 @@ set(PROTON_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/csrc) set(PROTON_EXTERN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/extern) file(GLOB_RECURSE PROTON_SRC ${PROTON_SRC_DIR}/lib/*.cpp) add_library(proton SHARED ${PROTON_SRC} ${PROTON_SRC_DIR}/${PROJECT_NAME}.cpp) + if(NOT CUPTI_INCLUDE_DIR) message(FATAL_ERROR "CUPTI include directory not defined") endif() @@ -38,5 +39,4 @@ include_directories(SYSTEM ${ROCTRACER_INCLUDE_DIR}) target_compile_definitions(proton PRIVATE __HIP_PLATFORM_AMD__) target_link_libraries(proton PRIVATE Python3::Module pybind11::headers) - target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS}) From f4095cce3eb594bf7554c93eca5168e35b4e8440 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 03:20:49 +0000 Subject: [PATCH 06/29] update --- third_party/proton/proton/profile.py | 39 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 80ad1e48d409..faa68558b5c3 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -4,33 +4,34 @@ from triton._C.libproton import proton as libproton from triton.language import core as tl +from triton.language.core import builtin import triton.language from .hook import register_triton_hook, unregister_triton_hook from functools import wraps from .flags import set_profiling_off, set_profiling_on, is_command_line -from typing import Optional, TypeVar +from typing import Optional from . import language -import builtins +#import builtins -T = TypeVar('T') +#T = TypeVar('T') DEFAULT_PROFILE_NAME = "proton" -TRITON_BUILTIN = "__triton_builtin__" - -def builtin(fn: T) -> T: - """Mark a function as a builtin.""" - assert callable(fn) - - @wraps(fn) - def wrapper(*args, **kwargs): - if "_builder" not in kwargs or kwargs["_builder"] is None: - raise ValueError("Did you forget to add @triton.jit ? " - "(`_builder` argument must be provided outside of JIT functions.)") - return fn(*args, **kwargs) - - setattr(wrapper, TRITON_BUILTIN, True) - - return wrapper +#TRITON_BUILTIN = "__triton_builtin__" +# +#def builtin(fn: T) -> T: +# """Mark a function as a builtin.""" +# assert callable(fn) +# +# @wraps(fn) +# def wrapper(*args, **kwargs): +# if "_builder" not in kwargs or kwargs["_builder"] is None: +# raise ValueError("Did you forget to add @triton.jit ? " +# "(`_builder` argument must be provided outside of JIT functions.)") +# return fn(*args, **kwargs) +# +# setattr(wrapper, TRITON_BUILTIN, True) +# +# return wrapper @builtin def record(isStart: bool, regionId: int, _builder=None): From 09b7cf922152a38dcfeb6637b67fd7fd6f81896d Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 03:21:23 +0000 Subject: [PATCH 07/29] update --- third_party/proton/proton/profile.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index faa68558b5c3..889211a32baf 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -11,27 +11,8 @@ from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional from . import language -#import builtins - -#T = TypeVar('T') DEFAULT_PROFILE_NAME = "proton" -#TRITON_BUILTIN = "__triton_builtin__" -# -#def builtin(fn: T) -> T: -# """Mark a function as a builtin.""" -# assert callable(fn) -# -# @wraps(fn) -# def wrapper(*args, **kwargs): -# if "_builder" not in kwargs or kwargs["_builder"] is None: -# raise ValueError("Did you forget to add @triton.jit ? " -# "(`_builder` argument must be provided outside of JIT functions.)") -# return fn(*args, **kwargs) -# -# setattr(wrapper, TRITON_BUILTIN, True) -# -# return wrapper @builtin def record(isStart: bool, regionId: int, _builder=None): From 877e2eaf4d0d03219da93e6c4bd193a481652d10 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 03:26:00 +0000 Subject: [PATCH 08/29] update --- third_party/proton/proton/profile.py | 1 - third_party/proton/test/test_proton_record.py | 41 +++---------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 889211a32baf..d289c83c6748 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -7,7 +7,6 @@ from triton.language.core import builtin import triton.language from .hook import register_triton_hook, unregister_triton_hook -from functools import wraps from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional from . import language diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index f3f76e9582e1..919981365a92 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -8,65 +8,36 @@ @triton.jit -def add_kernel(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. +def add_kernel(x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, ): - # There are multiple 'programs' processing different data. We identify which program - # we are here: - pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. - # This program will process inputs that are offset from the initial data. - # For instance, if you had a vector of length 256 and block_size of 64, the programs - # would each access the elements [0:64, 64:128, 128:192, 192:256]. - # Note that offsets is a list of pointers: + pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Create a mask to guard memory operations against out-of-bounds accesses. mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case the input is not a - # multiple of the block size. x = tl.load(x_ptr + offsets, mask=mask) proton.record(True, 0) y = tl.load(y_ptr + offsets, mask=mask) output = x + y - # Write x + y back to DRAM. tl.store(output_ptr + offsets, output, mask=mask) -# %% -# Let's also declare a helper function to (1) allocate the `z` tensor -# and (2) enqueue the above kernel with appropriate grid/block sizes: - def add(x: torch.Tensor, y: torch.Tensor): - # We need to preallocate the output. output = torch.empty_like(x) assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE n_elements = output.numel() - # The SPMD launch grid denotes the number of kernel instances that run in parallel. - # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. - # In this case, we use a 1D grid where the size is the number of blocks: BLOCK = 4096 grid = (int(n_elements / BLOCK), 1, 1) - # NOTE: - # - Each torch.tensor object is implicitly converted into a pointer to its first element. - # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. - # - Don't forget to pass meta-parameters as keywords arguments. pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) ttir = pgm.asm['ttir'] - print(ttir) assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir - # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still - # running asynchronously at this point. return output torch.manual_seed(0) size = 2**12 x = torch.rand(size, device='cuda') y = torch.rand(size, device='cuda') - -output_triton = add(x, y) -print(output_triton) From ae0ff69fe56ce1f16b976299711e411b6a0cc20e Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 03:52:34 +0000 Subject: [PATCH 09/29] update --- .../PatternTritonGPUOpToLLVM.h | 6 +- .../TritonGPUToLLVM/RecordOpToLLVM.cpp | 12 ++-- .../TritonToTritonGPUPass.cpp | 4 +- python/src/ir.cc | 13 ++-- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 4 +- .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- third_party/proton/proton/language.py | 2 +- third_party/proton/proton/profile.py | 1 + third_party/proton/test/test_proton_record.py | 72 ++++++++++--------- 9 files changed, 61 insertions(+), 55 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index f9d50ed4e065..f1fc8420170b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -122,9 +122,9 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit); namespace proton { void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - const TargetInfoBase &targetInfo, - PatternBenefit benefit); + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); } // namespace proton } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp index 2243d129910b..04a38316e6c5 100644 --- a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp @@ -8,14 +8,15 @@ #include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" - namespace { -struct RecordOpConversion : public ConvertOpToLLVMPattern { +struct RecordOpConversion + : public ConvertOpToLLVMPattern { explicit RecordOpConversion(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - PatternBenefit benefit) - : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), targetInfo(targetInfo) {} LogicalResult @@ -26,7 +27,6 @@ struct RecordOpConversion : public ConvertOpToLLVMPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, - GenericOpPattern, GenericOpPattern, - GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/python/src/ir.cc b/python/src/ir.cc index 99143c019204..1ab97e4023fd 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -236,8 +236,9 @@ void init_triton_ir(py::module &&m) { DialectRegistry registry; registry.insert(); + ::mlir::gpu::GPUDialect, cf::ControlFlowDialect, + ::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect, + mlir::ub::UBDialect>(); mlir::LLVM::registerInlinerInterface(registry); registerBuiltinDialectTranslation(registry); registerLLVMDialectTranslation(registry); @@ -1605,10 +1606,10 @@ void init_triton_ir(py::module &&m) { llvm::StringRef(prefix)); self.create(prefixAttr, hex, values, isSigned); }) - .def("create_proton_record", - [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void{ - self.create(isStart, regionId); - }) + .def("create_proton_record", + [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void { + self.create(isStart, regionId); + }) .def("create_assert", [](TritonOpBuilder &self, Value &condition, const std::string &message) -> void { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 66d1d8b8e98f..7e2d6d6aab80 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -229,8 +229,8 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); - mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns, - targetInfo, commonBenefit); + mlir::triton::proton::populateRecordOpToLLVMPattern( + typeConverter, patterns, targetInfo, commonBenefit); mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 901ba65a6966..0e54c5858234 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -150,7 +150,7 @@ struct ConvertTritonGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); + targetInfo, benefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns, diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py index f92a0b1a1d03..6092399afe36 100644 --- a/third_party/proton/proton/language.py +++ b/third_party/proton/proton/language.py @@ -4,6 +4,6 @@ from triton._C.libtriton import ir from triton.language import core as tl + def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) - diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index d289c83c6748..f4c3f5b56c2a 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -13,6 +13,7 @@ DEFAULT_PROFILE_NAME = "proton" + @builtin def record(isStart: bool, regionId: int, _builder=None): return language.proton_record(isStart, regionId, _builder) diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index 919981365a92..7daf48b09517 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -1,4 +1,6 @@ import torch +import pytest +import pathlib import triton import triton.language as tl @@ -7,37 +9,39 @@ DEVICE = triton.runtime.driver.active.get_active_torch_device() -@triton.jit -def add_kernel(x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - proton.record(True, 0) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - - -def add(x: torch.Tensor, y: torch.Tensor): - output = torch.empty_like(x) - assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE - n_elements = output.numel() - BLOCK = 4096 - grid = (int(n_elements / BLOCK), 1, 1) - pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) - ttir = pgm.asm['ttir'] - assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir - return output - -torch.manual_seed(0) -size = 2**12 -x = torch.rand(size, device='cuda') -y = torch.rand(size, device='cuda') +def test_proton_record(tmp_path: pathlib.Path): + + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + proton.record(True, 0) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + BLOCK = 4096 + grid = (int(n_elements / BLOCK), 1, 1) + pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + ttir = pgm.asm['ttir'] + assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir + return output + + torch.manual_seed(0) + size = 2**12 + x = torch.rand(size, device='cuda') + y = torch.rand(size, device='cuda') + output_triton = add(x, y) From e4784f4bf26a55fc2c648d6e54c8ec7d8fa593c5 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 04:18:49 +0000 Subject: [PATCH 10/29] update --- .../TritonToTritonGPU/TritonToTritonGPUPass.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 7945aafd279b..5721b22bbb5e 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -546,7 +546,6 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, - GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, @@ -558,7 +557,13 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, TritonFuncOpPattern>(typeConverter, context); } - +// Proton patterns +void populateProtonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add>(typeConverter, + context); +} // // SCF patterns // @@ -773,6 +778,7 @@ class ConvertTritonToTritonGPU populateArithPatternsAndLegality(typeConverter, patterns, target); populateMathPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns, numCTAs); + populateProtonPatterns(typeConverter, patterns); // TODO: can we use // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? populateSCFPatterns(typeConverter, patterns); From 754fdd1701d08001156ced345c75bcda1c857910 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 04:22:58 +0000 Subject: [PATCH 11/29] update --- third_party/proton/proton/language.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py index 6092399afe36..612ea2a749b4 100644 --- a/third_party/proton/proton/language.py +++ b/third_party/proton/proton/language.py @@ -1,6 +1,3 @@ -from __future__ import annotations # remove after python 3.11 -import warnings - from triton._C.libtriton import ir from triton.language import core as tl From d2f9c0bd3486780ee2a3842b6c755af7a21ef3d0 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sat, 28 Dec 2024 15:03:17 +0000 Subject: [PATCH 12/29] update --- third_party/proton/proton/profile.py | 1 - third_party/proton/test/test_proton_record.py | 22 +++++++------------ 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index f4c3f5b56c2a..cd8ab85b2389 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -5,7 +5,6 @@ from triton._C.libproton import proton as libproton from triton.language import core as tl from triton.language.core import builtin -import triton.language from .hook import register_triton_hook, unregister_triton_hook from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index 7daf48b09517..957f31787e1f 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -6,8 +6,6 @@ import triton.language as tl import triton.profiler as proton -DEVICE = triton.runtime.driver.active.get_active_torch_device() - def test_proton_record(tmp_path: pathlib.Path): @@ -26,22 +24,18 @@ def add_kernel( x = tl.load(x_ptr + offsets, mask=mask) proton.record(True, 0) y = tl.load(y_ptr + offsets, mask=mask) + proton.record(False, 0) output = x + y tl.store(output_ptr + offsets, output, mask=mask) - def add(x: torch.Tensor, y: torch.Tensor): - output = torch.empty_like(x) - assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE - n_elements = output.numel() - BLOCK = 4096 - grid = (int(n_elements / BLOCK), 1, 1) - pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) - ttir = pgm.asm['ttir'] - assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir - return output - torch.manual_seed(0) size = 2**12 x = torch.rand(size, device='cuda') y = torch.rand(size, device='cuda') - output_triton = add(x, y) + output = torch.empty_like(x) + n_elements = output.numel() + grid = (1, 1, 1) + pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + ttir = pgm.asm['ttir'] + assert "proton.record() {isStart = true, regionId = 0 : i32}" in ttir + assert "proton.record() {isStart = false, regionId = 0 : i32}" in ttir From 7b0485d10ef7bd5bc8e522f8a8865e9c437c6b28 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 02:10:58 +0000 Subject: [PATCH 13/29] move some paths into proton dialect --- .../PatternTritonGPUOpToLLVM.h | 6 ------ .../TritonGPUToLLVM/RecordOpToLLVM.cpp | 1 + .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 2 ++ .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 ++ .../PatternTritonProtonOpToLLVM.h | 20 +++++++++++++++++++ 5 files changed, 25 insertions(+), 6 deletions(-) create mode 100644 third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index f1fc8420170b..0d331c06de77 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -120,12 +120,6 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); -namespace proton { -void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - const TargetInfoBase &targetInfo, - PatternBenefit benefit); -} // namespace proton } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp index 04a38316e6c5..9b0b08ed730b 100644 --- a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" namespace { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 7e2d6d6aab80..10534115fd35 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -24,6 +24,8 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" + namespace mlir::triton { #define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 0e54c5858234..cb976e8ec4a0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -22,6 +22,8 @@ #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h" + namespace mlir { namespace triton { #define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM diff --git a/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h b/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h new file mode 100644 index 000000000000..47d9f4bf5a6d --- /dev/null +++ b/third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h @@ -0,0 +1,20 @@ +#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +namespace proton { +void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +} // namespace proton +} // namespace triton +} // namespace mlir + +#endif From 36f1861242526882ad1f73e8edc80419fb874e7d Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 02:11:46 +0000 Subject: [PATCH 14/29] update --- .../triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 0d331c06de77..c5c78e6d5b20 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -120,6 +120,7 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit); + } // namespace triton } // namespace mlir From 3897f8130008a7aa26b0ac1209dc45a08f052a7e Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 02:38:10 +0000 Subject: [PATCH 15/29] move more things into proton dialect --- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 - .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 1 + .../lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt | 1 + third_party/proton/dialect/lib/CMakeLists.txt | 1 + .../lib/TritonProtonToLLVM/CMakeLists.txt | 20 +++++++++++++++++++ .../TritonProtonToLLVM}/RecordOpToLLVM.cpp | 0 6 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt rename {lib/Conversion/TritonGPUToLLVM => third_party/proton/dialect/lib/TritonProtonToLLVM}/RecordOpToLLVM.cpp (100%) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 96b60bf01ac7..65949672f1dc 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -20,7 +20,6 @@ add_triton_library(TritonGPUToLLVM SPMDOpToLLVM.cpp DecomposeUnsupportedConversions.cpp PrintOpToLLVM.cpp - RecordOpToLLVM.cpp DEPENDS TritonGPUConversionPassIncGen diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index e2465f17b622..d101035a4ba4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM LINK_LIBS PUBLIC TritonGPUToLLVM TritonAMDGPUIR + TritonProtonToLLVM ) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index 96727b357106..a3d8a87290fb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -25,4 +25,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM LINK_LIBS PUBLIC TritonGPUToLLVM + TritonProtonToLLVM ) diff --git a/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/proton/dialect/lib/CMakeLists.txt index 0ca0f41c5af4..a224fd6f21f4 100644 --- a/third_party/proton/dialect/lib/CMakeLists.txt +++ b/third_party/proton/dialect/lib/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Dialect) +add_subdirectory(TritonProtonToLLVM) diff --git a/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt b/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..2d86b8259686 --- /dev/null +++ b/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(TritonProtonToLLVM + RecordOpToLLVM.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + ProtonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUTransforms +) diff --git a/lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp b/third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp similarity index 100% rename from lib/Conversion/TritonGPUToLLVM/RecordOpToLLVM.cpp rename to third_party/proton/dialect/lib/TritonProtonToLLVM/RecordOpToLLVM.cpp From bf1011eb47fb1d32d00190ffdaae3aad261dc8bb Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 14:54:24 +0000 Subject: [PATCH 16/29] update cmake --- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 - lib/Conversion/TritonToTritonGPU/CMakeLists.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 65949672f1dc..d6cc4387f79e 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -33,7 +33,6 @@ add_triton_library(TritonGPUToLLVM MLIRGPUTransforms TritonAnalysis TritonIR - ProtonIR TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 1438ec75c47d..1b629ba1639f 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -12,5 +12,4 @@ add_triton_library(TritonToTritonGPU TritonIR TritonGPUIR TritonGPUTransforms - ProtonIR ) From 97b92b268f08bc0406637be3f5c06fb1d109bad3 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 16:07:25 +0000 Subject: [PATCH 17/29] update cmake --- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + lib/Conversion/TritonToTritonGPU/CMakeLists.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index d6cc4387f79e..39b1f93f5ae5 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -33,6 +33,7 @@ add_triton_library(TritonGPUToLLVM MLIRGPUTransforms TritonAnalysis TritonIR + # ProtonIR TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 1b629ba1639f..fb5f7156f9aa 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -10,6 +10,7 @@ add_triton_library(TritonToTritonGPU MLIRPass MLIRTransforms TritonIR + ProtonIR TritonGPUIR TritonGPUTransforms ) From 6a0dc4cd6f2a6b2b47576fa724a2b0a060d3d048 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 16:07:56 +0000 Subject: [PATCH 18/29] update cmake --- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 39b1f93f5ae5..d6cc4387f79e 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -33,7 +33,6 @@ add_triton_library(TritonGPUToLLVM MLIRGPUTransforms TritonAnalysis TritonIR - # ProtonIR TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms From f1911e4a732c3e366b0a6778bf0227271f2a18e8 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 20:23:40 +0000 Subject: [PATCH 19/29] add dev module --- third_party/proton/proton/__init__.py | 2 +- third_party/proton/proton/dev/record.py | 13 +++++++++++++ third_party/proton/proton/profile.py | 12 +++++++++--- third_party/proton/test/test_proton_record.py | 4 ++-- 4 files changed, 25 insertions(+), 6 deletions(-) create mode 100644 third_party/proton/proton/dev/record.py diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index e3766965fd18..18bdd3e8013f 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -7,6 +7,6 @@ deactivate, finalize, profile, - record, + dev, DEFAULT_PROFILE_NAME, ) diff --git a/third_party/proton/proton/dev/record.py b/third_party/proton/proton/dev/record.py new file mode 100644 index 000000000000..9f1cf243c5a3 --- /dev/null +++ b/third_party/proton/proton/dev/record.py @@ -0,0 +1,13 @@ +import functools +import triton +import os + +from triton._C.libproton import proton as libproton +from triton.language import core as tl +from triton.language.core import builtin +from .. import language + +@builtin +def record(isStart: bool, regionId: int, _builder=None): + return language.proton_record(isStart, regionId, _builder) + diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index cd8ab85b2389..2975dcce4658 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -9,13 +9,19 @@ from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional from . import language +import warnings DEFAULT_PROFILE_NAME = "proton" -@builtin -def record(isStart: bool, regionId: int, _builder=None): - return language.proton_record(isStart, regionId, _builder) +class dev: + + @builtin + def record(isStart: bool, regionId: int, _builder=None): + warnings.warn( + "\nWarning the dev module within Proton contains under development features that are not intended to be used outside of the core development team" + ) + return language.proton_record(isStart, regionId, _builder) def _select_backend() -> str: diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index 957f31787e1f..0b9c72c8eba8 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -22,9 +22,9 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) - proton.record(True, 0) + proton.dev.record(True, 0) y = tl.load(y_ptr + offsets, mask=mask) - proton.record(False, 0) + proton.dev.record(False, 0) output = x + y tl.store(output_ptr + offsets, output, mask=mask) From b085d750c419209e6906052db9824b6398359d3e Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 29 Dec 2024 20:27:33 +0000 Subject: [PATCH 20/29] formatting --- third_party/proton/proton/dev/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/proton/proton/dev/record.py b/third_party/proton/proton/dev/record.py index 9f1cf243c5a3..fad04772e26f 100644 --- a/third_party/proton/proton/dev/record.py +++ b/third_party/proton/proton/dev/record.py @@ -7,7 +7,7 @@ from triton.language.core import builtin from .. import language + @builtin def record(isStart: bool, regionId: int, _builder=None): return language.proton_record(isStart, regionId, _builder) - From 5cb4ad2ced56753c2b90afbb7b6543eb4e0f4007 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 30 Dec 2024 01:52:57 +0000 Subject: [PATCH 21/29] update --- third_party/proton/proton/dev/record.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 third_party/proton/proton/dev/record.py diff --git a/third_party/proton/proton/dev/record.py b/third_party/proton/proton/dev/record.py deleted file mode 100644 index fad04772e26f..000000000000 --- a/third_party/proton/proton/dev/record.py +++ /dev/null @@ -1,13 +0,0 @@ -import functools -import triton -import os - -from triton._C.libproton import proton as libproton -from triton.language import core as tl -from triton.language.core import builtin -from .. import language - - -@builtin -def record(isStart: bool, regionId: int, _builder=None): - return language.proton_record(isStart, regionId, _builder) From 190d26ca46686535df89aec3a025bbecb0f3cad0 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 30 Dec 2024 19:39:44 +0000 Subject: [PATCH 22/29] update --- .../dialect/lib/TritonProtonToLLVM/CMakeLists.txt | 14 -------------- third_party/proton/proton/__init__.py | 4 +++- third_party/proton/proton/language.py | 12 +++++++++++- third_party/proton/proton/profile.py | 13 ------------- third_party/proton/test/test_proton_record.py | 7 +++---- 5 files changed, 17 insertions(+), 33 deletions(-) diff --git a/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt b/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt index 2d86b8259686..84b134fda39d 100644 --- a/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt +++ b/third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt @@ -1,20 +1,6 @@ add_triton_library(TritonProtonToLLVM RecordOpToLLVM.cpp - DEPENDS - TritonGPUConversionPassIncGen - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRGPUDialect - MLIRGPUToNVVMTransforms - MLIRGPUToROCDLTransforms - MLIRGPUTransforms - TritonAnalysis - TritonIR ProtonIR - TritonGPUIR - TritonGPUTransforms - TritonNvidiaGPUTransforms ) diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index 18bdd3e8013f..13bf59e252c5 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -7,6 +7,8 @@ deactivate, finalize, profile, - dev, DEFAULT_PROFILE_NAME, ) +from .language import ( + record +) diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py index 612ea2a749b4..d95a2e9c1a73 100644 --- a/third_party/proton/proton/language.py +++ b/third_party/proton/proton/language.py @@ -1,6 +1,16 @@ from triton._C.libtriton import ir from triton.language import core as tl - +from triton.language.core import builtin +import warnings def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) + +@builtin +def record(isStart: bool, regionId: int, _builder=None): + warnings.warn( + "\nWarning the proton language module within Proton contains under development features that are not intended to be used outside of the core development team" + ) + return proton_record(isStart, regionId, _builder) + + diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 2975dcce4658..903743f95f45 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -4,26 +4,13 @@ from triton._C.libproton import proton as libproton from triton.language import core as tl -from triton.language.core import builtin from .hook import register_triton_hook, unregister_triton_hook from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional -from . import language -import warnings DEFAULT_PROFILE_NAME = "proton" -class dev: - - @builtin - def record(isStart: bool, regionId: int, _builder=None): - warnings.warn( - "\nWarning the dev module within Proton contains under development features that are not intended to be used outside of the core development team" - ) - return language.proton_record(isStart, regionId, _builder) - - def _select_backend() -> str: backend = triton.runtime.driver.active.get_current_target().backend if backend == "cuda": diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index 0b9c72c8eba8..04531fe4c264 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -4,8 +4,7 @@ import triton import triton.language as tl -import triton.profiler as proton - +import triton.profiler.language as pl def test_proton_record(tmp_path: pathlib.Path): @@ -22,9 +21,9 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) - proton.dev.record(True, 0) + pl.record(True, 0) y = tl.load(y_ptr + offsets, mask=mask) - proton.dev.record(False, 0) + pl.record(False, 0) output = x + y tl.store(output_ptr + offsets, output, mask=mask) From 3c726816bd74aa28ebb025d0bc22263edabee1ff Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 30 Dec 2024 19:52:53 +0000 Subject: [PATCH 23/29] update --- third_party/proton/proton/__init__.py | 4 +--- third_party/proton/proton/language.py | 8 ++------ third_party/proton/test/test_proton_record.py | 1 + 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index 13bf59e252c5..11fad20ebfe9 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -9,6 +9,4 @@ profile, DEFAULT_PROFILE_NAME, ) -from .language import ( - record -) +from .language import (record) diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py index d95a2e9c1a73..99f8d997d653 100644 --- a/third_party/proton/proton/language.py +++ b/third_party/proton/proton/language.py @@ -3,14 +3,10 @@ from triton.language.core import builtin import warnings -def proton_record(isStart: bool, regionId: int, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) @builtin -def record(isStart: bool, regionId: int, _builder=None): +def record(isStart: bool, regionId: int, builder=None): warnings.warn( "\nWarning the proton language module within Proton contains under development features that are not intended to be used outside of the core development team" ) - return proton_record(isStart, regionId, _builder) - - + return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_proton_record.py index 04531fe4c264..0c623c3784ed 100644 --- a/third_party/proton/test/test_proton_record.py +++ b/third_party/proton/test/test_proton_record.py @@ -6,6 +6,7 @@ import triton.language as tl import triton.profiler.language as pl + def test_proton_record(tmp_path: pathlib.Path): @triton.jit From d980aeab4ed3b35a14e843620c8457ea0bf2d36e Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 30 Dec 2024 20:22:34 +0000 Subject: [PATCH 24/29] update --- third_party/proton/test/{test_proton_record.py => test_record.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename third_party/proton/test/{test_proton_record.py => test_record.py} (100%) diff --git a/third_party/proton/test/test_proton_record.py b/third_party/proton/test/test_record.py similarity index 100% rename from third_party/proton/test/test_proton_record.py rename to third_party/proton/test/test_record.py From 134f1026db0cbc4313bdc206c4b48818b77f29d6 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 30 Dec 2024 20:31:45 +0000 Subject: [PATCH 25/29] update --- third_party/proton/proton/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index 11fad20ebfe9..65f58e5da64b 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -9,4 +9,4 @@ profile, DEFAULT_PROFILE_NAME, ) -from .language import (record) +from .language import record From 6a64945df84736e2128ba07500bf23b9422f74a5 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 30 Dec 2024 20:48:05 +0000 Subject: [PATCH 26/29] update --- third_party/proton/proton/language.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py index 99f8d997d653..d923f60c6a01 100644 --- a/third_party/proton/proton/language.py +++ b/third_party/proton/proton/language.py @@ -5,8 +5,8 @@ @builtin -def record(isStart: bool, regionId: int, builder=None): +def record(isStart: bool, regionId: int, _builder=None): warnings.warn( "\nWarning the proton language module within Proton contains under development features that are not intended to be used outside of the core development team" ) - return tl.tensor(builder.create_proton_record(isStart, regionId), tl.void) + return tl.tensor(_builder.create_proton_record(isStart, regionId), tl.void) From ca63a070c1976a7da33ffacc11b1d29db078f190 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Tue, 31 Dec 2024 01:38:18 +0000 Subject: [PATCH 27/29] update --- lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 5721b22bbb5e..5159890468a8 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -558,6 +558,10 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonFuncOpPattern>(typeConverter, context); } // Proton patterns +// NOTE: Because Proton's inputs are scalars and not tensors this conversion +// isn't strictly nessessary however you could envision a case where we pass in +// tensors in for Triton object specific tracing operations in which case we +// would need to fill in the OpConversionPattern void populateProtonPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); From 3184568c246315871bd8b282acda4edc6930dad7 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Tue, 31 Dec 2024 02:40:32 +0000 Subject: [PATCH 28/29] update --- python/src/ir.cc | 9 +++++---- third_party/proton/proton/__init__.py | 1 - third_party/proton/proton/profile.py | 2 -- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index 1ab97e4023fd..531f1444ec46 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1606,10 +1606,6 @@ void init_triton_ir(py::module &&m) { llvm::StringRef(prefix)); self.create(prefixAttr, hex, values, isSigned); }) - .def("create_proton_record", - [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void { - self.create(isStart, regionId); - }) .def("create_assert", [](TritonOpBuilder &self, Value &condition, const std::string &message) -> void { @@ -1661,6 +1657,11 @@ void init_triton_ir(py::module &&m) { std::vector &tensorShape) -> Value { return self.create(base, shape, strides, tensorShape); + }) + // Proton Ops + .def("create_proton_record", + [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void { + self.create(isStart, regionId); }); py::class_(m, "pass_manager", py::module_local()) diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index 65f58e5da64b..ded8b01142af 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -9,4 +9,3 @@ profile, DEFAULT_PROFILE_NAME, ) -from .language import record diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 903743f95f45..f639060f54c8 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -3,14 +3,12 @@ import os from triton._C.libproton import proton as libproton -from triton.language import core as tl from .hook import register_triton_hook, unregister_triton_hook from .flags import set_profiling_off, set_profiling_on, is_command_line from typing import Optional DEFAULT_PROFILE_NAME = "proton" - def _select_backend() -> str: backend = triton.runtime.driver.active.get_current_target().backend if backend == "cuda": From a61a21eb9a8ff769e39d47b0d72f4f1336ac0161 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Tue, 31 Dec 2024 02:48:49 +0000 Subject: [PATCH 29/29] formatting --- third_party/proton/proton/profile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index f639060f54c8..575c85b0cac8 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -9,6 +9,7 @@ DEFAULT_PROFILE_NAME = "proton" + def _select_backend() -> str: backend = triton.runtime.driver.active.get_current_target().backend if backend == "cuda":