OpenXLA-specific changes
jax-triton-dev authored and vwbaker committed Mar 20, 2024
1 parent 92ac0aa commit 84f9d9d
Showing 28 changed files with 1,622 additions and 56 deletions.
966 changes: 966 additions & 0 deletions BUILD

Large diffs are not rendered by default.

@@ -68,6 +68,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
axisAnalysisPass(axisAnalysisPass) {}

// True if elements allocated to a thread are contiguous within the axis. This
// is not the case in MMA-like encodings, where a thread might have elements
// (0,0),(0,1) and (8,0),(8,1), for example. The problem with this is that the
// deduplication mechanism assumes that, e.g., with constancy=4 and
// elements/thread=4, all of a thread's elements are constant.
bool contiguouslyMapped(Attribute encoding) const {
if (auto slice = encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return contiguouslyMapped(slice.getParent());
}
return encoding.isa<triton::gpu::BlockedEncodingAttr>();
}

// Try to deduplicate the resultVals based on the
// constancy properties of the result discovered by
// the axis analysis pass. If possible, redundant
@@ -93,8 +105,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
if (!encoding)
// encoding not available
return resultVals;
if (!encoding.dyn_cast<BlockedEncodingAttr>() &&
!encoding.dyn_cast<SliceEncodingAttr>()) {
if (!contiguouslyMapped(encoding)) {
// TODO: constraining the encoding type here is necessary to avoid
// crashes in the getElemsPerThread call below, which happen in
// test_core::test_fp8_dot_acc
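For illustration (not part of this commit), a minimal standalone C++ sketch of the deduplication assumption the new contiguouslyMapped check protects: when a thread's elements are contiguous along the axis, constancy equal to elements/thread implies all of that thread's values are identical, so keeping a single copy is safe; with an MMA-like mapping the thread's elements straddle different rows and the shortcut is unsound.

#include <cassert>
#include <vector>

// Hypothetical helper: keep one value per `constancy`-sized group, assuming the
// thread's values are contiguous along the axis.
std::vector<int> dedupContiguous(const std::vector<int> &vals, int constancy) {
  std::vector<int> out;
  for (size_t i = 0; i < vals.size(); i += constancy) {
    for (size_t j = i; j < i + constancy && j < vals.size(); ++j)
      assert(vals[j] == vals[i]); // holds only for contiguously mapped layouts
    out.push_back(vals[i]);
  }
  return out;
}

int main() {
  // Blocked-like mapping: the thread's four elements lie in one constant run.
  assert(dedupContiguous({7, 7, 7, 7}, /*constancy=*/4).size() == 1);
  // MMA-like mapping: elements (0,0),(0,1),(8,0),(8,1) mix rows 0 and 8, so the
  // four values need not be equal and the shortcut above would misfire.
  return 0;
}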
16 changes: 0 additions & 16 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -24,22 +24,6 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
}

class TTG_Type<string name, string typeMnemonic,
list<Trait> traits = []> : TypeDef<TritonGPU_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

def TTG_AsyncToken : TTG_Type<"AsyncToken",
"async.token", []> {
let summary = "async token type";
let description = [{
`ttg.async.token` is a type returned by an asynchronous operation.
It is used to establish an SSA-based link between async operations
and operations that group or synchronize the async operations.
}];
}


def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape,
SameOperandsAndResultElementType,
9 changes: 9 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
@@ -23,4 +23,13 @@ def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
let skipDefaultBuilders = 1;
}

def TTG_AsyncToken : TTG_TypeDef<"AsyncToken",
"async.token", []> {
let summary = "async token type";
let description = [{
`ttg.async.token` is a type returned by an asynchronous operation.
It is used to establish an SSA-based link between async operations
and operations that group or synchronize the async operations.
}];
}
#endif
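The TTG_AsyncToken definition above is the same `ttg.async.token` type removed from TritonGPUOps.td earlier in this commit, now declared alongside the other TritonGPU types. Assuming the standard ODS-generated API for a parameterless TypeDef (the class name and header below are assumptions, not shown in this diff), it could be constructed from C++ roughly as:

#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Sketch only: assumes ODS generates triton::gpu::AsyncTokenType with a
// context-only get(), as is typical for parameterless MLIR types.
mlir::Type makeAsyncTokenType(mlir::MLIRContext *ctx) {
  return mlir::triton::gpu::AsyncTokenType::get(ctx);
}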
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -527,7 +527,8 @@ bool supportMMA(triton::DotOp op, int version) {
auto aElemTy = op.getA().getType().getElementType();
auto bElemTy = op.getB().getType().getElementType();
if (version == 3) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
// TODO(b/311157761): enable mma_v3
if (!triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
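This change makes MMA v3 opt-in via ENABLE_MMA_V3 rather than opt-out via DISABLE_MMA_V3; the same flip appears in FenceInsertion.cpp further down. The exact parsing done by triton::tools::getBoolEnv is not shown in this diff, so the snippet below is a hedged stand-in that only illustrates the new gating:

#include <cstdlib>
#include <string>

// Stand-in only: the real triton::tools::getBoolEnv may accept other spellings.
static bool boolEnvSketch(const char *name) {
  const char *v = std::getenv(name);
  if (!v)
    return false;
  std::string s(v);
  return s == "1" || s == "true" || s == "on";
}

// Before this commit: MMA v3 was used unless DISABLE_MMA_V3 was set.
// After this commit:  MMA v3 is used only if ENABLE_MMA_V3 is set.
static bool mmaV3Allowed() { return boolEnvSketch("ENABLE_MMA_V3"); }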
8 changes: 5 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -23,7 +23,7 @@ using ttg::SliceEncodingAttr;
// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
int baseVersion = 0;
if (computeCapability < 75) {
if (computeCapability < 80) {
baseVersion = 1;
} else if (computeCapability < 90) {
baseVersion = 2;
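With the threshold raised from 75 to 80, compute capability 7.5 (Turing) now selects MMA v1 instead of v2. Below is a hedged reconstruction of the selection logic; only the two thresholds shown in the diff are certain, while the capability-9.0 branch and any later supportMMA validation are assumptions.

// Sketch, not the verbatim function from AccelerateMatmul.cpp.
static int getMMAVersionSketch(int computeCapability) {
  if (computeCapability < 80) // raised from 75 by this commit
    return 1;
  if (computeCapability < 90)
    return 2;
  return 3; // assumed Hopper path, gated elsewhere by ENABLE_MMA_V3
}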
@@ -307,8 +307,10 @@ class BlockedToMMA : public mlir::RewritePattern {
} else {

// convert operands
int minBitwidth =
std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
// TODO(b/296812125): Fix minBitwidth issue upstream and uncomment.
// int minBitwidth =
// std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
int minBitwidth = 0;
Type minType = IntegerType::get(ctx, minBitwidth);
// convert A operand
auto newAEncoding = ttg::DotOperandEncodingAttr::get(
16 changes: 15 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -7,7 +7,19 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <algorithm>
#include <cstdlib>
#include <cctype>
#include <memory>
#include <string>

inline bool isPipeliningEnabled() {
const char *s = std::getenv("ENABLE_PIPELINING");
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}

namespace {

@@ -329,7 +341,9 @@ class TritonGPUOptimizeDotOperandsPass

mlir::RewritePatternSet patterns(context);
patterns.add<SwizzleShmemConvert>(context);
if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
// TODO(b/291216607): Fix crashes and enable by default.
if (isPipeliningEnabled() &&
triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
patterns.add<HoistLayoutConversion>(context);
patterns.add<FuseTransHopper>(context);
patterns.add<MMAV3UseRegOperand>(context);
@@ -55,7 +55,7 @@ createSchedule(scf::ForOp forOp, int numStages) {
static void hoistAllocAndConst(scf::ForOp forOp) {
SmallVector<Operation *> toHoist;
for (Operation &op : forOp.getBody()->without_terminator()) {
if (isa<ttg::LocalAllocOp, arith::ConstantOp>(op))
if (isa<arith::ConstantOp>(op))
toHoist.push_back(&op);
}
for (Operation *op : toHoist) {
@@ -38,6 +38,8 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
return op;
if (isa<ttg::LocalLoadOp>(op))
return op;
if (isa<ttg::LocalAllocOp>(op))
return op;
if (auto asyncCopyOp = dyn_cast<ttg::AsyncCopyGlobalToLocalOp>(op)) {
rewriter.setInsertionPoint(asyncCopyOp);
Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(),
3 changes: 2 additions & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
@@ -40,7 +40,8 @@ struct FenceInsertionPass
// Only insert fences for compute capability 9.0
if (computeCapability < 90)
return;
if (::triton::tools::getBoolEnv("DISABLE_MMA_V3"))
// TODO(b/311157761): enable mma_v3
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return;
ModuleOp mod = getOperation();
mod.walk([&](Operation *op) {
91 changes: 91 additions & 0 deletions python/BUILD
@@ -0,0 +1,91 @@
# NOTE: Do not depend on any targets from this directory,
# but use //third_party/py/triton instead.

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_applicable_licenses = ["//:license"],
default_visibility = [
"//third_party/py/triton:__pkg__",
"//third_party/triton/python:__subpackages__",
],
)

cc_library(
name = "passes",
hdrs = ["src/passes.h"],
includes = ["src"],
visibility = ["//third_party/triton/third_party:__subpackages__"],
)

pybind_extension(
name = "libtriton",
srcs = [
"src/interpreter.cc",
"src/ir.cc",
"src/llvm.cc",
"src/main.cc",
"src/passes.cc",
],
copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
deps = [
":passes",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:InstCombine",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"//:TritonAnalysis",
"//:TritonDialects",
"//:TritonGPUToLLVM",
"//:TritonGPUTransforms",
"//:TritonHSACO",
"//:TritonLLVMIR",
"//:TritonNvidiaGPUTransforms",
"//:TritonPTX",
"//:TritonToTritonGPU",
"//:TritonTools",
"//:TritonTransforms",
"//third_party/triton/third_party/nvidia:triton_nvidia",
],
)

pybind_extension(
name = "triton_launcher",
srcs = [
"triton/compiler/triton_launcher.c",
],
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
deps = [
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cuda_runtime",
],
)

filegroup(
name = "files",
srcs = glob(
include = ["triton/**/*.py"],
),
)
27 changes: 27 additions & 0 deletions python/test/regression/BUILD
@@ -0,0 +1,27 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
name = "tests",
size = "large",
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests = glob(
include = ["test_*.py"],

# TODO(b/321005767): fix failing test
exclude = [
"test_performance.py",
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)
107 changes: 107 additions & 0 deletions python/test/unit/BUILD
@@ -0,0 +1,107 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
name = "hopper",
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests = glob(
include = ["hopper/**/test_*.py"],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)

pytest_multi_tests(
name = "language",
size = "large",
srcs = [
"conftest.py",
"language/conftest.py",
"language/test_core.py",
],
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests = glob(
include = ["language/**/test_*.py"],
exclude = [
"language/test_subprocess.py", # TODO(b/320224484): fix failing test
"language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)

pytest_multi_tests(
name = "operators",
size = "large",
srcs = ["conftest.py"],
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests = glob(
[
"operators/**/test_*.py",
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)

pytest_multi_tests(
name = "runtime",
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests =
glob(
include = ["runtime/**/test_*.py"],
exclude = [
"runtime/test_launch.py", #TODO(b/320226169): fix failing tests
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)

pytest_multi_tests(
name = "tools",
size = "large",
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests =
glob(
include = ["tools/**/test_*.py"],
exclude = [
"tools/test_aot.py", # TODO(b/320224484): fix failing test
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)
2 changes: 1 addition & 1 deletion python/triton/_C/include
