Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Remove memdesc from tt.trans and implements ttg.memdesc_trans #5194

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/triton/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs)
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)

set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td)
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)

add_public_tablegen_target(TritonTableGen)
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpInterfaces.h"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
Expand Down
21 changes: 21 additions & 0 deletions include/triton/Dialect/Triton/IR/OpInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef TRITON_IR_OP_INTERFACES_H_
#define TRITON_IR_OP_INTERFACES_H_

#include "mlir/IR/OpDefinition.h"

namespace mlir::triton::impl {

/// Verification entry point for operations that implement
/// TransposeOpInterface. The interface's `verify` body (declared in
/// TritonOpInterfaces.td) forwards the operation to this function and returns
/// its result.
LogicalResult verifyTransposeOpInterface(Operation *op);

} // namespace mlir::triton::impl

#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc"

#endif // TRITON_IR_OP_INTERFACES_H_
36 changes: 36 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef TRITON_OP_INTERFACES
#define TRITON_OP_INTERFACES

include "mlir/IR/OpBase.td"


// Op interface shared by transpose-like operations. It exposes two accessors
// (`getSrc` and `getOrder`) and attaches a common verifier.
def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
let description = [{
This interface is implemented by operations that perform a transpose.
It provides methods to access common properties such as the order attribute and the source operand.
}];

// Namespace in which the generated C++ interface class is placed.
let cppNamespace = "::mlir::triton";

// Both methods take no arguments and have no default implementation here, so
// an implementing op must provide `getSrc()` and `getOrder()` itself.
let methods = [
InterfaceMethod<
/*desc=*/[{
Get the source operand of the transposition.
}],
/*retType=*/"::mlir::Value",
/*methodName=*/"getSrc",
/*args=*/(ins)>,
InterfaceMethod<
/*desc=*/[{
Get the order of the transposition.
}],
/*retType=*/"::mlir::ArrayRef<int32_t>",
/*methodName=*/"getOrder",
/*args=*/(ins)>
];

// Verification is delegated to a free function; `$_op` is the operation being
// verified.
let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
}


#endif // TRITON_OP_INTERFACES
67 changes: 30 additions & 37 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"


//
Expand All @@ -44,8 +41,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Cast int64 to pointer";

let arguments = (ins TT_I64Like:$src);
Expand All @@ -58,8 +54,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Cast pointer to int64";

let arguments = (ins TT_PtrLike:$src);
Expand All @@ -73,8 +68,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Cast between types of the same bitwidth";

let arguments = (ins TT_Type:$src);
Expand All @@ -89,8 +83,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
Pure]> {
let summary = "Floating point casting for custom types";

let description = [{
Expand Down Expand Up @@ -118,8 +111,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
//

def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Clamp operation for floating point types";

let description = [{
Expand Down Expand Up @@ -149,8 +142,8 @@ def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
//

def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Precise sqrt for floating point types";

let description = [{
Expand All @@ -165,8 +158,8 @@ def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
}

def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Precise div for floating point types";

let description = [{
Expand All @@ -181,8 +174,8 @@ def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
}

def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
SameOperandsAndResultType,
Pure]> {
SameOperandsAndResultType,
Pure]> {
let summary = "Most significant N bits of the 2N-bit product of two integers";

let description = [{
Expand All @@ -200,12 +193,12 @@ def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
[Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);

let results = (outs TT_PtrLike:$result);
Expand Down Expand Up @@ -546,6 +539,7 @@ def TT_SplitOp : TT_Op<"split", [
}

def TT_TransOp : TT_Op<"trans", [Pure,
TransposeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {

Expand Down Expand Up @@ -579,16 +573,15 @@ def TT_TransOp : TT_Op<"trans", [Pure,
}];

let arguments = (
ins TT_TensorOrMemDesc:$src,
ins TT_Tensor:$src,
DenseI32ArrayAttr:$order
);

let results = (outs TT_TensorOrMemDesc:$result);
let results = (outs TT_Tensor:$result);

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";

let hasFolder = 1;
let hasVerifier = 1;
}

//
Expand Down Expand Up @@ -677,10 +670,10 @@ def TT_DotOp : TT_Op<"dot", [Pure,
// DotScaled Op
//
def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
AttrSizedOperandSegments,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
AttrSizedOperandSegments,
DotLike,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot_scaled";

let description = [{
Expand Down Expand Up @@ -783,10 +776,10 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
// External Elementwise op
//
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ConditionallySpeculatable]> {
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ConditionallySpeculatable]> {

let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"

#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc"
#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"

namespace mlir {

Expand Down
26 changes: 26 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
Expand Down Expand Up @@ -221,6 +222,31 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
let hasVerifier = 1;
}

// Transpose of a memory descriptor. The op takes a descriptor and a static
// permutation (`order`) and yields another descriptor; it implements
// TransposeOpInterface and declares the InferTypeOpInterface methods.
def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
TransposeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "transpose the descriptor";

let description = [{
This operation returns a new descriptor
representing a transposed view of the buffer.
}];

// `order` is a DenseI32ArrayAttr (an attribute, not an SSA operand). An
// earlier, overridden assignment that declared it as `Variadic<I32>` was
// removed: only the last `let arguments` took effect, and the stale one
// contradicted it.
let arguments = (
ins TT_MemDescType:$src,
DenseI32ArrayAttr:$order
);

let results = (outs TT_MemDescType:$result);

let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

let hasFolder = 1;
}

def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Load a buffer from local memory into a distributed tensor";

Expand Down
5 changes: 2 additions & 3 deletions lib/Analysis/Alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
if (isa<triton::gpu::LocalAllocOp>(op)) {
aliasInfo.insert(result);
pessimistic = false;
} else if (isa<triton::gpu::MemDescSubviewOp, triton::TransOp>(op)) {
// extract_slice %src
// trans %src
} else if (isa<triton::gpu::MemDescSubviewOp, triton::gpu::MemDescTransOp>(
op)) {
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else {
Expand Down
45 changes: 29 additions & 16 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,27 +269,38 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern<ExpandDimsOp> {
return success();
}
};
/// Lowers MemDescTransOp to LLVM.
///
/// The source shared-memory object is unpacked from the adaptor's `src`
/// struct, and a new SharedMemoryObject is built that reuses the same base and
/// base element type while permuting the strides and offsets by the op's
/// `order`. The op is then replaced by the struct packed from that object.
struct MemDescTransOpConversion
    : public ConvertOpToLLVMPattern<MemDescTransOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto resultTy = cast<TensorOrMemDesc>(op.getType());
    // The encoding attribute is never read afterwards. The cast is retained
    // only for the check it performs on the result encoding, so its value is
    // discarded explicitly instead of being bound to an unused local.
    (void)cast<SharedEncodingAttr>(resultTy.getEncoding());

    auto llvmElemTy =
        getTypeConverter()->convertType(resultTy.getElementType());
    auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                      llvmElemTy, rewriter);
    // Same base pointer and element type; only the per-dimension metadata is
    // reordered.
    auto dstSmemObj = SharedMemoryObject(
        srcSmemObj.base, srcSmemObj.baseElemType,
        /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
        /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
    auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
    rewriter.replaceOp(op, retVal);
    return success();
  }
};

struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = cast<TensorOrMemDesc>(op.getType());
if (auto enc = dyn_cast<SharedEncodingAttr>(resultTy.getEncoding())) {
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
auto dstSmemObj = SharedMemoryObject(
srcSmemObj.base, srcSmemObj.baseElemType,
/*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
/*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
} else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(
resultTy.getEncoding())) {
auto resultTy = cast<RankedTensorType>(op.getType());
if (auto enc =
mlir::dyn_cast<BlockedEncodingAttr>(resultTy.getEncoding())) {
// If the dst encoding is blocked, then TransOp::inferReturnTypes
// ensures that:
// - the src encoding is also blocked, and
Expand All @@ -302,9 +313,10 @@ struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
rewriter.replaceOp(op, ret);
return success();
}
return emitOptionalError(loc, "unsupported encoding for TransOp");
return emitOptionalError(loc, "unsupported encoding for MemDescTransOp");
}
};

struct BroadcastOpConversion
: public ConvertOpToLLVMPattern<triton::BroadcastOp> {
using ConvertOpToLLVMPattern<triton::BroadcastOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -407,6 +419,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<JoinOpConversion>(typeConverter, benefit);
patterns.add<SplitOpConversion>(typeConverter, benefit);
patterns.add<MemDescTransOpConversion>(typeConverter, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
Expand Down
Loading
Loading