Add verifier for triton_gpu.blocked layout. (#2622)

Checks that:

- The rank of the layout matches the rank of the tensor it's applied to.
- The layout's total threads per warp, warps per CTA, and CTAs per CGA
  all match the module's attributes.
- The layout's fields (sizePerThread, threadsPerWarp, warpsPerCTA, order,
  and the CTA layout) all have the same rank.
- The layout's `order` and `CTAOrder` fields are permutations of
  0..(rank-1).

Unfortunately it seems we cannot unit-test the verifiers on the
attributes themselves: when one of these verifiers fails, we just get an
assert() failure. 🤷

Many lit tests ran afoul of this verifier.  I fixed most of them
manually, but I decided to delete some large tests (apparently generated
code) that had many issues, on the theory that the cost-benefit tradeoff
of fixing these by hand was unfavorable.  (Indeed in many cases it
wasn't clear what the test was intending to check, so I couldn't be sure
that I wasn't rendering the test useless with my changes.)
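
As a rough illustration of these checks (a hypothetical snippet, not part of
this commit, written in the style of the existing lit tests), the verifier now
rejects any op whose tensor carries a layout like the one below, because the
product of warpsPerCTA is 8 while the module declares only 4 warps:

// Hypothetical example: warpsPerCTA multiplies out to 4 * 2 = 8, but the
// module attribute triton_gpu.num-warps is 4, so ops using #bad fail to verify.
#bad = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func @example(%arg0: tensor<128x128xf32, #bad>) {
    tt.return
  }
}

The emitted diagnostic names the offending operand or result and the mismatched
counts; see verifyTensorLayouts in lib/Dialect/Triton/IR/Traits.cpp below.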
jlebar authored Nov 18, 2023
1 parent 984e93b commit db6935c
Showing 21 changed files with 556 additions and 187 deletions.
12 changes: 12 additions & 0 deletions include/triton/Dialect/Triton/IR/Traits.h
@@ -24,6 +24,7 @@ namespace impl {
int constexpr maxTensorNumElements = 1048576;

LogicalResult verifyTensorSize(Operation *op);
LogicalResult verifyTensorLayouts(Operation *op);

LogicalResult verifySameOperandsEncoding(Operation *op,
bool allowTensorPointerType = false);
@@ -48,6 +49,17 @@ class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
}
};

// Trait applied to all Triton MLIR ops. Checks that the layouts of tensors are
// valid.
template <class ConcreteType>
class VerifyTensorLayoutsTrait
: public TraitBase<ConcreteType, VerifyTensorLayoutsTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyTensorLayouts(op);
}
};

template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonInterfaces.td
@@ -4,6 +4,7 @@
include "mlir/IR/OpBase.td"

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -22,7 +22,8 @@ include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
// Op Base
//
class TT_Op<string mnemonic, list<Trait> traits = []> :
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
Op<Triton_Dialect, mnemonic,
!listconcat(traits, [TensorSizeTrait, VerifyTensorLayoutsTrait])> {
}

//
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -82,6 +82,8 @@ def CTALayoutAttr : TritonGPU_Attr<"CTALayout"> {
"Unsupported getTotalElemsPerThread in CTALayoutAttr.");
}
}];

let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
@@ -477,6 +479,7 @@ for
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
"CTALayoutAttr":$CTALayout
);
let genVerifyDecl = 1;

let builders = [
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
4 changes: 3 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -16,7 +16,9 @@ include "mlir/Interfaces/ViewLikeInterface.td"
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;

class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
Op<TritonGPU_Dialect, mnemonic,
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
}

def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape,
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -39,7 +39,9 @@ def Source1IsSharedEncoding: NativeOpTrait<"Source1IsSharedEncoding">;
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;

class TTNG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonNvidiaGPU_Dialect, mnemonic, traits>;
Op<TritonNvidiaGPU_Dialect, mnemonic,
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
}

// --------------------------------------------------------------------------------------------------
// MBarrier related Ops:
7 changes: 7 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -409,6 +409,13 @@ struct ConvertTritonGPUToLLVM
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

// Hack: WSMaterialization may have changed the effective number of warps,
// in a way that isn't reflected in triton_gpu.num-warps. If so, we have to
// respect that here.
if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) {
numWarps *= attr.cast<IntegerAttr>().getInt();
}

// Preprocess
decomposeFp8e4b15Convert(mod);
decomposeSplatToSharedLayout(mod, numWarps, threadsPerWarp, numCTAs);
113 changes: 113 additions & 0 deletions lib/Dialect/Triton/IR/Traits.cpp
@@ -1,9 +1,13 @@
#include "triton/Dialect/Triton/IR/Traits.h"

#include <numeric>

#include "mlir/IR/TypeUtilities.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
namespace ttg = mlir::triton::gpu;

static LogicalResult verifySameEncoding(Type typeA, Type typeB,
bool allowTensorPointerType) {
@@ -92,6 +96,115 @@ LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
return success();
}

template <typename T> static int64_t accumProduct(T &&container) {
return std::accumulate(container.begin(), container.end(), int64_t{1},
std::multiplies<int64_t>());
}

// Check that the Triton layouts on op's operands and return types are valid.
// For example, we check that the number of warps per block in a Triton GPU
// blocked layout matches that of its module.
//
// It's a little weird to check these properties of a layout only when the
// layout is used in an op, since most of the properties don't actually depend
// on the op. They do depend on the *module*, though, and a layout is attached
// to a module only by virtue of being used in one of the module's ops.
LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) {
auto module = op->getParentOfType<ModuleOp>();
auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult {
// Only ranked tensors can have layouts.
auto rankedTy = val.getType().dyn_cast<RankedTensorType>();
if (!rankedTy)
return success();

mlir::Attribute layout = rankedTy.getEncoding();
if (!layout)
return success();

// TODO(jlebar): Currently this only checks blocked layouts, but other
// layouts also have invariants!

// TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
if (auto blocked = layout.dyn_cast<ttg::BlockedEncodingAttr>()) {
// A different verifier should have checked that the layout itself is
// valid, including that threads-per-warp has the same rank as
// warps-per-block etc.
auto layoutRank = blocked.getThreadsPerWarp().size();
if (layoutRank != rankedTy.getRank()) {
return makeErr() << layout << ".\nLayout has rank " << layoutRank
<< ", but the tensor it's attached to has rank "
<< rankedTy.getRank() << ".";
}

int moduleThreadsPerWarp =
ttg::TritonGPUDialect::getThreadsPerWarp(module);
int64_t layoutThreadsPerWarp = accumProduct(blocked.getThreadsPerWarp());
if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
return makeErr() << layout << ".\nLayout has a total of "
<< layoutThreadsPerWarp
<< " threads per warp, but the module specifies "
<< moduleThreadsPerWarp << " threads per warp.";
}

int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module);
int64_t layoutWarpsPerCTA = accumProduct(blocked.getWarpsPerCTA());
if (layoutWarpsPerCTA != moduleWarpsPerCTA) {
return makeErr() << layout << ".\nLayout has a total of "
<< layoutWarpsPerCTA
<< " warps per CTA, but the module specifies "
<< moduleWarpsPerCTA << " warps per CTA.";
}

if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module);
int64_t layoutCTAsPerCGA =
accumProduct(blocked.getCTALayout().getCTAsPerCGA());
if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
return makeErr() << layout << ".\nLayout has a total of "
<< layoutCTAsPerCGA
<< " CTAs per CGA, but the module specifies "
<< moduleCTAsPerCGA << " CTAs per CGA.";
}
}
}

return success();
};

for (size_t i = 0; i < op->getNumOperands(); i++) {
auto operand = op->getOperand(i);
auto err = checkLayout(operand, [&]() {
// Stringify the operand using `printAsOperand`. This prints e.g. "%42"
// rather than the full definition.
std::string operandStr;
llvm::raw_string_ostream os(operandStr);
// If we don't assume verified, dump() will recursively call this
// function!
operand.printAsOperand(os, OpPrintingFlags().assumeVerified());

return op->emitError("Operand ")
<< i << " (" << operand << ") has an invalid layout: ";
});
if (!err.succeeded())
return err;
}

for (size_t i = 0; i < op->getNumResults(); i++) {
auto result = op->getResult(i);
auto err = checkLayout(result, [&]() {
if (op->getNumResults() == 1) {
return op->emitError("Result has an invalid layout: ");
} else {
return op->emitError("Result ") << i << " has an invalid layout: ";
}
});
if (!err.succeeded())
return err;
}

return success();
}

static ArrayRef<int64_t> getTypeShape(Type type) {
auto rankedType = type.dyn_cast<RankedTensorType>();
if (auto ptrType = type.dyn_cast<triton::PointerType>())
57 changes: 57 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -600,6 +600,63 @@ bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
return newTotalElemsPerThread < totalElemsPerThread;
}

// Is `vals` some permutation of the numbers 0..(vals.size()-1)?
static bool isPermutationOfIota(ArrayRef<unsigned> vals) {
SmallVector<unsigned, 4> sorted(vals.begin(), vals.end());
llvm::sort(sorted);
for (int i = 0; i < sorted.size(); i++) {
if (sorted[i] != i) {
return false;
}
}
return true;
}

LogicalResult CTALayoutAttr::verify(
function_ref<InFlightDiagnostic()> emitError, ArrayRef<unsigned> CTAsPerCGA,
ArrayRef<unsigned> CTASplitNum, ArrayRef<unsigned> CTAOrder) {
if (CTAsPerCGA.size() != CTASplitNum.size() ||
CTASplitNum.size() != CTAOrder.size()) {
return emitError() << "CTAsPerCTA, CTASplitNum, and CTAOrder must all have "

This comment has been minimized.

Copy link
@alexander-zinoviev

alexander-zinoviev Nov 18, 2023

Contributor

CTAsPerCGA

It would be nice to have the actual values in the error string too.

"the same rank.";
}

if (!isPermutationOfIota(CTAOrder)) {
return emitError()
<< "CTAOrder must be a permutation of 0..(rank-1), but was ["
<< CTAOrder << "]";
}
return success();
}

LogicalResult
BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<unsigned> sizePerThread,
ArrayRef<unsigned> threadsPerWarp,
ArrayRef<unsigned> warpsPerCTA,
ArrayRef<unsigned> order, CTALayoutAttr CTALayout) {
if (sizePerThread.size() != threadsPerWarp.size() ||
threadsPerWarp.size() != warpsPerCTA.size() ||
warpsPerCTA.size() != order.size()) {
return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
"order must all have the same rank.";
}

// Empty CTALayout is allowed, but if it's present its rank must match the
// BlockedEncodingAttr's rank.
if (CTALayout.getCTASplitNum().size() != 0 &&
sizePerThread.size() != CTALayout.getCTASplitNum().size()) {
return emitError() << "BlockedEncodingAttr and CTALayout's fields must "
"have the same rank.";
}
if (!isPermutationOfIota(order)) {
return emitError()
<< "order must be a permutation of 0..(rank-1), but was [" << order
<< "]";
}
return success();
}

} // namespace gpu
} // namespace triton
} // namespace mlir
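
For instance (a hypothetical attribute, not taken from this diff), the new
BlockedEncodingAttr verifier rejects an encoding whose order repeats an axis:

// Hypothetical: rank-2 encoding whose order is [0, 0], i.e. not a permutation
// of 0..1, so verification fails with "order must be a permutation of 0..(rank-1)".
#broken = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 2], order = [0, 0]}>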
58 changes: 44 additions & 14 deletions lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp
@@ -23,15 +23,15 @@

#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include <set>

#include "mlir/IR/OperationSupport.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"

#include <set>

using namespace mlir;
namespace ttg = triton::gpu;
namespace ttng = triton::nvidia_gpu;
@@ -682,10 +682,13 @@ void tryRegisterRealloc(ModuleOp mod) {
}
}

//===----------------------------------------------------------------------===//
// WSMaterializationPass
//===----------------------------------------------------------------------===//

// This pass adds top-level `if` statements to the module so that:
// - there's one group of four warps that handles memory operations, and
// - there are one or two groups of four warps handling math operations.
//
// If we use two groups for math operations, it's in "ping-pong" fashion:
// The memory warp group does loads/stores for group A while group B runs, then
// it switches to serving group B.
struct WSMaterializationPass
: public TritonGPUWSMaterializationBase<WSMaterializationPass> {
WSMaterializationPass() = default;
@@ -710,23 +713,50 @@ struct WSMaterializationPass
materializeMutexOperations(mod);
tryRegisterRealloc(mod);

// TODO: More flexible way to set num-warps
// One dma, one math warp group, set num-warps = 8
auto i32_ty = IntegerType::get(mod->getContext(), 32);
mod->setAttr("triton_gpu.num-warps",
IntegerAttr::get(i32_ty, llvm::APInt(32, 8)));

// The IR before this pass specifies 4 warps per CTA. But this pass splits
// things so that there are 4 warps per *group* (aka "agent"), and there are
// 2 or 3 groups. So from CUDA's perspective there are now 8 or 12 warps
// per CTA.
//
// A natural thing to do would be to change the module's num-warps property
// to be 8 or 12 to match the new reality. But it's actually better to keep
// num-warps as 4, because the 2 or 3 groups are not working
// collaboratively.
//
// For example, tensors are not distributed between the groups. A blocked
// layout with `warpsPerCta = [2,2]` (implying the data is distributed among
// 4 warps) makes sense, but `warpsPerCta = [4,2]` should not appear in the
// IR, because this would be saying that the data is distributed among 8
// warps, which would mean that its data is distributed between the groups.
// It's an invariant that the product of the warpsPerCta equals the module's
// num-warps, so this implies the module's num-warps should be 4.
//
// As another example, there is code that checks whether a load is
// "expensive" by comparing the number of elements loaded to the number of
// warps in the module. Here too we should compare the number of elements
// being loaded to 4 warps, because only the 4 warps from the load/store
// group are participating in the load.
//
// But at some point (at least when we launch the kernel!) we really do need
// to know that the CTA has 8 or 12 warps in it. So instead of modifying
// num-warps, we add a *new* attribute to the module that indicates how many
// warp groups there are, and we modify users that need to know the "true"
// number of warps to read it.
int32_t numWarpGroups = 2;
WalkResult result = mod->walk([&](scf::IfOp ifOp) {
if (ifOp->hasAttr("agent.num-roles")) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (result.wasInterrupted()) {
mod->setAttr("triton_gpu.num-warps",
IntegerAttr::get(i32_ty, llvm::APInt(32, 12)));
numWarpGroups = 3;
}
mod->removeAttr("async.num-agents");

auto builder = OpBuilder::atBlockBegin(mod.getBody());
mod->setAttr("triton_gpu.num-warp-groups-per-cta",
builder.getI32IntegerAttr(numWarpGroups));
}
};
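
To make the discussion in the comment above concrete, here is a hypothetical
sketch (not part of this diff) of the module attributes WSMaterialization
leaves behind; ConvertTritonGPUToLLVM (see the hunk earlier in this commit)
multiplies the two values to recover the launch-time warp count:

// Hypothetical post-WSMaterialization module: layouts still distribute data
// over 4 warps, while the lowering computes 4 * 2 = 8 warps for the launch.
module attributes {"triton_gpu.num-warps" = 4 : i32,
                   "triton_gpu.num-warp-groups-per-cta" = 2 : i32} {
}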
