Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update
  • Loading branch information
Jokeren committed Nov 20, 2024
1 parent e9a1d0f commit d1cbdb5
Show file tree
Hide file tree
Showing 25 changed files with 213 additions and 103 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc"
#include "triton/Dialect/Triton/IR/OpInterfaces.h"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
Expand Down
23 changes: 23 additions & 0 deletions include/triton/Dialect/Triton/IR/OpInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TRITON_IR_OP_INTERFACES_H_
#define TRITON_IR_OP_INTERFACES_H_

#include "mlir/IR/OpDefinition.h"

#include "triton/Dialect/Triton/IR/Utility.h"

namespace mlir {

namespace triton {

namespace impl {

// Shared verifier for ops implementing TransposeOpInterface: checks that the
// op's `order` attribute is a permutation whose size matches the rank of the
// source operand. Referenced from the interface's `verify` hook in
// TritonOpInterfaces.td, so it must be declared before the generated header
// below is included.
LogicalResult verifyTransposeOpInterface(Operation *op);

} // namespace impl

} // namespace triton
} // namespace mlir

// TableGen-generated interface declarations; the generated verifier calls
// verifyTransposeOpInterface above.
#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc"

#endif // TRITON_IR_OP_INTERFACES_H_
11 changes: 10 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
It provides methods to access common properties such as the order attribute and the source operand.
}];

let cppNamespace = "::mlir";
let cppNamespace = "::mlir::triton";

let methods = [
InterfaceMethod<
Expand All @@ -20,6 +20,13 @@ def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
/*retType=*/"::mlir::Value",
/*methodName=*/"getSrc",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Get the rank of the source operand.
}],
/*retType=*/"int64_t",
/*methodName=*/"getRank",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Get the order of the transposition.
Expand All @@ -28,6 +35,8 @@ def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
/*methodName=*/"getOrder",
/*args=*/(ins)>
];

let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
}


Expand Down
15 changes: 8 additions & 7 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
Expand Down Expand Up @@ -577,16 +573,21 @@ def TT_TransOp : TT_Op<"trans", [Pure,
}];

let arguments = (
ins TT_TensorOrMemDesc:$src,
ins TT_Tensor:$src,
DenseI32ArrayAttr:$order
);

let results = (outs TT_TensorOrMemDesc:$result);
let results = (outs TT_Tensor:$result);

let extraClassDeclaration = [{
int64_t getRank() {
return cast<RankedTensorType>(getSrc().getType()).getRank();
}
}];

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";

let hasFolder = 1;
let hasVerifier = 1;
}

//
Expand Down
15 changes: 13 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
let hasVerifier = 1;
}

def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, TransposeOpInterface]> {
def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
TransposeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "transpose the descriptor";

let description = [{
Expand All @@ -239,7 +242,15 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure, TransposeOpInterface]> {

let results = (outs TT_MemDescType:$result);

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
let extraClassDeclaration = [{
int64_t getRank() {
return cast<MemDescType>(getSrc().getType()).getRank();
}
}];

let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

let hasFolder = 1;
}

def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
Expand Down
45 changes: 29 additions & 16 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,27 +269,38 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern<ExpandDimsOp> {
return success();
}
};
// Lowers ttg.memdesc_trans to LLVM. A transpose of a shared-memory descriptor
// moves no data: it only permutes the descriptor's strides and offsets, so the
// lowering rebuilds the shared-memory object around the same base pointer.
struct MemDescTransOpConversion
    : public ConvertOpToLLVMPattern<MemDescTransOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto resultTy = cast<TensorOrMemDesc>(op.getType());
    // The cast asserts (in debug builds) that the result carries a
    // shared-memory encoding; `enc` is not otherwise used.
    auto enc = cast<SharedEncodingAttr>(resultTy.getEncoding());
    auto llvmElemTy =
        getTypeConverter()->convertType(resultTy.getElementType());
    auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                      llvmElemTy, rewriter);
    // Same base pointer and element type; only the strides and offsets are
    // permuted by the transpose order.
    auto dstSmemObj = SharedMemoryObject(
        srcSmemObj.base, srcSmemObj.baseElemType,
        /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
        /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
    auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = cast<TensorOrMemDesc>(op.getType());
if (auto enc = dyn_cast<SharedEncodingAttr>(resultTy.getEncoding())) {
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
auto dstSmemObj = SharedMemoryObject(
srcSmemObj.base, srcSmemObj.baseElemType,
/*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
/*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
} else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(
resultTy.getEncoding())) {
auto resultTy = cast<RankedTensorType>(op.getType());
if (auto enc =
mlir::dyn_cast<BlockedEncodingAttr>(resultTy.getEncoding())) {
// If the dst encoding is blocked, then TransOp::inferReturnTypes
// ensures that:
// - the src encoding is also blocked, and
Expand All @@ -302,9 +313,10 @@ struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
rewriter.replaceOp(op, ret);
return success();
}
return emitOptionalError(loc, "unsupported encoding for TransOp");
return emitOptionalError(loc, "unsupported encoding for MemDescTransOp");
}
};

struct BroadcastOpConversion
: public ConvertOpToLLVMPattern<triton::BroadcastOp> {
using ConvertOpToLLVMPattern<triton::BroadcastOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -407,6 +419,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<JoinOpConversion>(typeConverter, benefit);
patterns.add<SplitOpConversion>(typeConverter, benefit);
patterns.add<MemDescTransOpConversion>(typeConverter, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_triton_library(TritonIR
Ops.cpp
Traits.cpp
Types.cpp
OpInterfaces.cpp

DEPENDS
TritonTableGen
Expand Down
33 changes: 33 additions & 0 deletions lib/Dialect/Triton/IR/OpInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"

#include "triton/Dialect/Triton/IR/OpInterfaces.h"
#include "triton/Dialect/Triton/IR/Types.h"

namespace mlir {
namespace triton {
namespace impl {

// Verifies an op implementing TransposeOpInterface: its `order` attribute must
// be a permutation of [0, rank) where rank is the rank of the source operand.
// Installed as the interface-level verifier via the `verify` hook in
// TritonOpInterfaces.td.
LogicalResult verifyTransposeOpInterface(Operation *op) {
  auto transposeOp = cast<TransposeOpInterface>(op);
  int64_t rank = transposeOp.getRank();
  ArrayRef<int32_t> order = transposeOp.getOrder();
  // Compare like-signed values: `rank` is int64_t while size() is unsigned,
  // so the original direct comparison triggered -Wsign-compare.
  if (rank != static_cast<int64_t>(order.size())) {
    return op->emitError(
        "order must have the same size as the rank of the operand and result");
  }

  // A sorted permutation of [0, rank) is exactly the iota sequence.
  SmallVector<int32_t, 8> sortedOrder(order.begin(), order.end());
  llvm::sort(sortedOrder);
  for (int32_t i = 0, e = static_cast<int32_t>(sortedOrder.size()); i < e;
       ++i) {
    if (sortedOrder[i] != i) {
      return op->emitError("order must be a permutation of [0, ..., rank - 1]");
    }
  }

  return success();
}

} // namespace impl
} // namespace triton
} // namespace mlir
35 changes: 4 additions & 31 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,35 +223,8 @@ LogicalResult TransOp::inferReturnTypes(
return failure();
}
}
if (auto memDescTy = dyn_cast<MemDescType>(argTy)) {
inferredReturnTypes.push_back(MemDescType::get(
retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
memDescTy.getMutableMemory()));
} else {
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, retEltTy, retEncoding));
}
return success();
}

LogicalResult TransOp::verify() {
// Check that the op's `order` attribute is a permutation of the right length.
auto srcTy = getSrc().getType();

ArrayRef<int32_t> order = getOrder();
if (order.size() != srcTy.getRank()) {
return emitError("order must have the same size as the rank of the "
"operand and result");
}

SmallVector<int32_t, 8> sortedOrder(order);
llvm::sort(sortedOrder);
for (int32_t i = 0; i < sortedOrder.size(); i++) {
if (sortedOrder[i] != i) {
return emitError("order must be a permutation of [0, ..., rank - 1]");
}
}

inferredReturnTypes.push_back(
RankedTensorType::get(retShape, retEltTy, retEncoding));
return success();
}

Expand All @@ -266,8 +239,8 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
inferredReturnTypes.push_back(accTy);

// verify encodings
auto aEnc = cast<TensorOrMemDesc>(operands[0].getType()).getEncoding();
auto bEnc = cast<TensorOrMemDesc>(operands[1].getType()).getEncoding();
auto aEnc = cast<RankedTensorType>(operands[0].getType()).getEncoding();
auto bEnc = cast<RankedTensorType>(operands[1].getType()).getEncoding();
auto retEnc = accTy.getEncoding();
if (aEnc) {
assert(bEnc && retEnc);
Expand Down
45 changes: 45 additions & 0 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

Expand Down Expand Up @@ -131,4 +132,48 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
return success();
}

OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
  // An identity permutation makes the transpose a no-op: fold to the source.
  if (isIota(getOrder()))
    return getSrc();

  // Two stacked transposes collapse into one: rewrite this op in place to
  // read the producer's source through the composed permutation.
  auto producer = getSrc().getDefiningOp<MemDescTransOp>();
  if (!producer)
    return {};
  setOrder(applyPermutation(producer.getOrder(), getOrder()));
  setOperand(producer.getSrc());
  return getResult();
}

// Infers the result type of ttg.memdesc_trans: the source memdesc with its
// shape (and, when present, its encoding) permuted by `order`. Element type,
// memory space, and mutability carry over unchanged.
LogicalResult MemDescTransOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto argTy = cast<MemDescType>(operands[0].getType());
  auto order = properties.as<Properties *>()->order.asArrayRef();
  SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);

  auto retEltTy = argTy.getElementType();
  Attribute argEncoding = argTy.getEncoding();
  Attribute retEncoding;
  if (argEncoding) {
    Dialect &dialect = argEncoding.getDialect();
    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
    // Guard the dyn_cast result: the encoding's dialect may not implement the
    // layout-inference interface, and the unchecked call dereferenced null.
    if (!inferLayoutInterface ||
        inferLayoutInterface->inferTransOpEncoding(argEncoding, order,
                                                   retEncoding)
            .failed()) {
      return failure();
    }
  }
  // `argTy` is already a MemDescType; no re-cast needed to read its
  // memory-space and mutability fields.
  inferredReturnTypes.push_back(MemDescType::get(retShape, retEltTy,
                                                 retEncoding,
                                                 argTy.getMemorySpace(),
                                                 argTy.getMutableMemory()));
  return success();
}

} // namespace mlir::triton::gpu
Loading

0 comments on commit d1cbdb5

Please sign in to comment.