Record narrow static M/N sizes in EncodingAttr and rationalize MaterializeEncoding for narrow shapes. (#15431)

This changes how we approach narrow matmul tile size selection (in
particular, vecmat/matvec), from "we don't really care that much, so
let's derive narrow tiles from general ones by just truncation", to "we
actually care, at least in specific cases, about freely controlling
narrow matmul tiles independently of the general wide matmul case."

There are 2 immediate needs for this: @dcaballe was doing something
comparable in #15421 to generally unlock better AVX-512 codegen for
`f32` `vecmat`, and I have a specific need for this in #15158 for some
`s16 x u4` quantized `vecmat` case.

The solution proposed here is more general than the one in #15421 in
that it is not only about `vecmat` and `matvec`: it supports any
narrow-M / narrow-N case. Like #15421, it does so by extending
`EncodingAttr`. Unlike #15421, it does so by adding optional
narrow-M / narrow-N integer attributes, instead of extending the `user`
enum.
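
For illustration, here is a rough standalone model of the extension
(hypothetical C++, not the actual MLIR TableGen definition, which appears
in the diff below): the new parameters are just two optional integers that
are absent in the general wide-matmul case.

#include <cstdint>
#include <optional>

// Models the two parameters added to EncodingAttr in this commit.
// std::nullopt corresponds to the optional attribute being omitted from the
// printed IR; a present value prints as, e.g., `matmul_narrow_N = 1 : index`.
struct EncodingModel {
  // ... existing user/role/element_types/original_type parameters elided ...
  std::optional<int64_t> matmulNarrowM; // static narrow M size, e.g. 1 for vecmat
  std::optional<int64_t> matmulNarrowN; // static narrow N size, e.g. 1 for matvec
};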

Along the way, this rationalizes the MaterializeEncoding code around
tile-size selection. Narrow tile sizes are now explicitly enumerated, and
the enumeration of tile sizes is clearly decoupled from the choice among
the enumerated tile sizes, as sketched below.
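
A minimal sketch of that decoupling (hypothetical C++ with illustrative
names and tile values; the actual per-target enumeration and choice happen
in CPUMaterializeEncodingPass):

#include <cstdint>
#include <optional>
#include <vector>

struct TileMxNxK { int64_t M = 1, N = 1, K = 1; };

// Enumeration step: the general tile first, then explicit narrow-M variants
// in decreasing order of M. The values here are made up for illustration.
std::vector<TileMxNxK> enumerateTiles() {
  return {{8, 8, 4}, {4, 8, 4}, {2, 8, 4}, {1, 8, 4}};
}

// Choice step, kept separate: pick the smallest enumerated tile that still
// covers the narrow M size, falling back to the general tile otherwise.
TileMxNxK chooseTile(const std::vector<TileMxNxK> &tiles,
                     std::optional<int64_t> narrowM) {
  TileMxNxK chosen = tiles.front();
  if (narrowM) {
    for (const TileMxNxK &t : tiles) {
      if (t.M >= *narrowM) chosen = t; // last hit = smallest tile covering M
    }
  }
  return chosen;
}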

Another change made along the way: the tile shape convention here changes
from MxKxN to MxNxK, bringing it in line with the convention used in
ukernels. The motivation is that the MxN part is particularly important as
the output tile shape, so it helps that the MxNxK convention keeps it as a
contiguous subset.
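
For instance, under MxNxK the output tile shape can be read off as a
contiguous prefix of the tile triple (illustrative sketch):

#include <array>
#include <cstdint>

struct TileMxNxK { int64_t M = 1, N = 1, K = 1; };

// With MxNxK ordering, the RESULT (output) tile shape {M, N} is a contiguous
// prefix of the triple, matching the ukernel convention; under the old MxKxN
// ordering it straddled the K component.
std::array<int64_t, 2> resultTileShape(const TileMxNxK &t) {
  return {t.M, t.N};
}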

To avoid useless redundancy, since the narrow-N case is almost 100%
symmetrical to the narrow-M case, the enumeration only goes over narrow-M
cases, and the handling of narrow-N is deferred to the choose function,
which transposes the problem to derive narrow-N tiles from narrow-M tiles.
For `vecmat`/`matvec`, the symmetry is perfect: the accumulator tile is 1D
in these cases, so there is no difference at all. For other non-vector
narrow cases, there could conceivably someday be a difference motivating
decoupling narrow-N from narrow-M, but this is sufficiently far-fetched
that it's best to leave it to be dealt with when a concrete use case
arises, and to enjoy the benefit of smaller code until then.
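
A sketch of that deferral (hypothetical C++, assuming the enumeration
covers only narrow-M variants): when only N is narrow, the choose step
transposes the problem, selects a narrow-M tile, and swaps M and N back.

#include <cstdint>
#include <optional>
#include <utility>

struct TileMxNxK { int64_t M = 1, N = 1, K = 1; };

// Hypothetical choose step: narrow-N is handled by transposition rather than
// by enumerating separate narrow-N tiles.
TileMxNxK chooseTile(std::optional<int64_t> narrowM,
                     std::optional<int64_t> narrowN) {
  if (!narrowM && narrowN) {
    // Transpose the problem: treat narrow-N as narrow-M, then swap back.
    TileMxNxK t = chooseTile(/*narrowM=*/narrowN, /*narrowN=*/std::nullopt);
    std::swap(t.M, t.N);
    return t;
  }
  // Narrow-M (or general) enumeration and selection elided; placeholder only.
  return {narrowM.value_or(8), 8, 4};
}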

bjacob authored Nov 8, 2023
1 parent d509096 commit 7ab3509
Showing 9 changed files with 342 additions and 195 deletions.

Large diffs are not rendered by default.

@@ -233,28 +233,28 @@ func.func @matvec_lowering_f32f32f32_aarch64() attributes {
} {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readonly:tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32]>>>
: !flow.dispatch.tensor<readonly:tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readonly:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32]>>>
: !flow.dispatch.tensor<readonly:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0)
: !flow.dispatch.tensor<readwrite:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>>
: !flow.dispatch.tensor<readwrite:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32]>>>
-> tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32]>>
: !flow.dispatch.tensor<readonly:tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
-> tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16, 1], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32]>>>
-> tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32]>>
: !flow.dispatch.tensor<readonly:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
-> tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [16, 1], strides = [1, 1]
: !flow.dispatch.tensor<readwrite:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>>
-> tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>
: !flow.dispatch.tensor<readwrite:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
-> tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>
%6 = linalg.matmul
ins(%3, %4 : tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32]>>,
tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32]>>)
outs(%5 : tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>)
-> tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>
ins(%3, %4 : tensor<16x16xf32, #iree_linalg_ext.encoding<user = MATMUL, role = LHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>,
tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RHS, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>)
outs(%5 : tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>)
-> tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [16, 1], strides = [1, 1]
: tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>
-> !flow.dispatch.tensor<readwrite:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>>>
: tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>
-> !flow.dispatch.tensor<readwrite:tensor<16x1xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32], matmul_narrow_N = 1 : index>>>
return
}
// CHECK: func @matvec_lowering_f32f32f32_aarch64()
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
@@ -14,10 +14,6 @@
namespace mlir {
namespace iree_compiler {

void adjustTileSizesToNarrowStaticShape(
IREE::LinalgExt::MaterializeEncodingInfo &encodingInfo,
ArrayRef<int64_t> shape);

IREE::LinalgExt::MaterializeEncodingValueFn
getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr);

@@ -269,53 +269,6 @@ struct MaterializeFlowDispatchTensorStoreOp

} // namespace

void adjustTileSizesToNarrowStaticShape(MaterializeEncodingInfo &encodingInfo,
ArrayRef<int64_t> shape) {
for (size_t i = 0; i < encodingInfo.innerDimsPos.size(); i++) {
int64_t size = shape[encodingInfo.innerDimsPos[i]];
// Dynamic sizes are assumed to be large enough, not to be candidates for
// narrow kernels.
if (ShapedType::isDynamic(size))
continue;
int64_t &tileSize = encodingInfo.innerTileSizes[i];
// Let's not try to handle any dynamic tile sizes here. We could handle the
// case where size==1 (as whatever is the runtime value of tileSize, it
// can't be less than that, so it should be OK to replace it with 1) but
// in general, adjusting dynamic tile sizes has to be done by the
// materializeEncodingValueFn which we obtain those tileSizes from.
if (ShapedType::isDynamic(tileSize))
continue;
// Adjust tile sizes for narrow cases: ensure that narrow sizes (those that
// are less than the normal tileSize) don't get padded to more than the
// next power of two, or tileSize, whichever is smaller.
//
// For example, if size==1, always adjust tileSize to be 1, so that
// matrix-times-vector problems remain that, instead of becoming more
// general matrix-times-matrix.
//
// Another example, if tileSize==6, then:
//
// Original tensor size | adjusted tileSize
// -------------------- | -----------------
// 1 | 1
// 2 | 2
// 3 | 4
// 4 | 4
// 5 | 6
// >= 6 | 6
//
// Note: this implies that microkernels that implement a code path for
// a given `tileSize` value should also implement alternative code paths
// for all powers of two smaller than `tileSize`, as those could end up
// being selected here, and would fall back on slow generic code if no
// optimized code path is provided.
for (int po2 = 1; po2 < tileSize; po2 *= 2) {
if (size <= po2 && tileSize >= po2)
tileSize = po2;
}
}
}

static FailureOr<MaterializeEncodingValueInfo>
chooseDynamicEncodingInfoVMVXMicrokernels(RankedTensorType tensorType,
OpBuilder &builder, Location loc) {
93 changes: 65 additions & 28 deletions compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -76,19 +77,46 @@ static Value pad(OpBuilder &builder, Location loc, Value source,
lowPad, highPad, zero);
}

static Value setEncoding(OpBuilder &builder, Location loc, Value source,
IREE::LinalgExt::EncodingAttr encodingAttr) {
Value setEncoding(OpBuilder &builder, Location loc, Value source,
IREE::LinalgExt::EncodingAttr encodingAttr) {
auto sourceType = source.getType().cast<RankedTensorType>();
auto resultType = RankedTensorType::get(
sourceType.getShape(), sourceType.getElementType(), encodingAttr);
return builder.create<IREE::LinalgExt::SetEncodingOp>(loc, resultType,
source);
};

struct MatmulNarrowSizes {
std::optional<int64_t> M, N;
};

// Returns the narrow M and/or N sizes of a matmul, based on the static
// dimension sizes of its output type, when they fall below the threshold.
static MatmulNarrowSizes getMatmulNarrowSizes(ShapedType outType) {
int64_t M = outType.getDimSize(0);
int64_t N = outType.getDimSize(1);
MatmulNarrowSizes narrow;
// Threshold below which an M/N size is considered "narrow", making it
// eligible for a narrow tile size during materialization. This value should
// be at least as large as the actual M/N tile sizes that we choose on any
// target in CPUMaterializeEncodingPass. If it is smaller, we will miss
// opportunities to select optimized narrow tiles for narrow matmuls.
// If it is larger, everything will work fine, but the IR will be a bit more
// verbose as more matmul_narrow_{M,N} optional parameters will be specified.
const int64_t kNarrowThreshold = 16;
if (!ShapedType::isDynamic(M) && M < kNarrowThreshold) {
narrow.M = M;
}
if (!ShapedType::isDynamic(N) && N < kNarrowThreshold) {
narrow.N = N;
}
return narrow;
}

static IREE::LinalgExt::EncodingAttr
makeEncoding(OpBuilder &builder, IREE::LinalgExt::EncodingUser user,
IREE::LinalgExt::EncodingRole role, TypeRange operandTypes,
Type originalType) {
Type originalType, MatmulNarrowSizes narrow) {
auto *context = builder.getContext();
auto userAttr = IREE::LinalgExt::EncodingUserAttr::get(context, user);
auto roleAttr = IREE::LinalgExt::EncodingRoleAttr::get(context, role);
@@ -100,19 +128,24 @@ makeEncoding(OpBuilder &builder, IREE::LinalgExt::EncodingUser user,
auto operandElemTypesAttr = ArrayAttr::get(context, elemTypeAttrs);
auto originalTypeAttr =
originalType ? TypeAttr::get(originalType) : TypeAttr{};
auto getAttr = [&](std::optional<int64_t> x) {
return x ? builder.getIndexAttr(*x) : IntegerAttr();
};
return IREE::LinalgExt::EncodingAttr::get(
context, userAttr, roleAttr, operandElemTypesAttr, originalTypeAttr);
context, userAttr, roleAttr, operandElemTypesAttr, originalTypeAttr,
getAttr(narrow.M), getAttr(narrow.N));
}

static Value padAndSetEncoding(OpBuilder &builder, Location loc, Value source,
IREE::LinalgExt::EncodingUser user,
IREE::LinalgExt::EncodingRole role,
TypeRange operandTypes) {
// No need to specify original_type in the encoding poadded to pad(), because
TypeRange operandTypes,
MatmulNarrowSizes narrow) {
// No need to specify original_type in the encoding passed to pad(), because
// the operand there is the `source` tensor, so it will default to reading its
// original shape.
auto encodingForPad =
makeEncoding(builder, user, role, operandTypes, /*originalType=*/Type{});
auto encodingForPad = makeEncoding(builder, user, role, operandTypes,
/*originalType=*/Type{}, narrow);
Value padded = pad(builder, loc, source, encodingForPad);
// For setEncoding() below, we potentially need to specify an encoding with an
// explicit original_type, because the operand there is the padded tensor
Expand All @@ -122,8 +155,8 @@ static Value padAndSetEncoding(OpBuilder &builder, Location loc, Value source,
// the tensor type that the encoding is applied to.
auto encodingForSetEncoding = encodingForPad;
if (padded.getType() != source.getType()) {
encodingForSetEncoding =
makeEncoding(builder, user, role, operandTypes, source.getType());
encodingForSetEncoding = makeEncoding(builder, user, role, operandTypes,
source.getType(), narrow);
}
return setEncoding(builder, loc, padded, encodingForSetEncoding);
}
@@ -193,17 +226,19 @@ struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
}

IREE::LinalgExt::EncodingUser user = IREE::LinalgExt::EncodingUser::MATMUL;
MatmulNarrowSizes narrowSizes =
getMatmulNarrowSizes(origOut.getType().cast<ShapedType>());
Location loc = matmulOp.getLoc();
TypeRange operandTypes = matmulOp->getOperandTypes();
Value encodedLhs =
padAndSetEncoding(rewriter, loc, origLhs, user,
IREE::LinalgExt::EncodingRole::LHS, operandTypes);
Value encodedRhs =
padAndSetEncoding(rewriter, loc, origRhs, user,
IREE::LinalgExt::EncodingRole::RHS, operandTypes);
Value encodedOut =
padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT, operandTypes);
Value encodedLhs = padAndSetEncoding(rewriter, loc, origLhs, user,
IREE::LinalgExt::EncodingRole::LHS,
operandTypes, narrowSizes);
Value encodedRhs = padAndSetEncoding(rewriter, loc, origRhs, user,
IREE::LinalgExt::EncodingRole::RHS,
operandTypes, narrowSizes);
Value encodedOut = padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT,
operandTypes, narrowSizes);

Value matmulTiled = rewriter
.create<linalg::MatmulOp>(
@@ -272,17 +307,19 @@ struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {

IREE::LinalgExt::EncodingUser user =
IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
MatmulNarrowSizes narrowSizes =
getMatmulNarrowSizes(origOut.getType().cast<ShapedType>());
Location loc = matmulOp.getLoc();
TypeRange operandTypes = matmulOp->getOperandTypes();
Value encodedLhs =
padAndSetEncoding(rewriter, loc, origLhs, user,
IREE::LinalgExt::EncodingRole::LHS, operandTypes);
Value encodedRhs =
padAndSetEncoding(rewriter, loc, origRhs, user,
IREE::LinalgExt::EncodingRole::RHS, operandTypes);
Value encodedOut =
padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT, operandTypes);
Value encodedLhs = padAndSetEncoding(rewriter, loc, origLhs, user,
IREE::LinalgExt::EncodingRole::LHS,
operandTypes, narrowSizes);
Value encodedRhs = padAndSetEncoding(rewriter, loc, origRhs, user,
IREE::LinalgExt::EncodingRole::RHS,
operandTypes, narrowSizes);
Value encodedOut = padAndSetEncoding(rewriter, loc, origOut, user,
IREE::LinalgExt::EncodingRole::RESULT,
operandTypes, narrowSizes);

Value matmulTiled = rewriter
.create<linalg::BatchMatmulOp>(
@@ -102,7 +102,10 @@ def EncodingAttr :
AttrParameter<"EncodingUserAttr", "kind of operation using this tensor">:$user,
AttrParameter<"EncodingRoleAttr", "role of this tensor as an operand">:$role,
AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types,
OptionalParameter<"TypeAttr", "type of the original tensor type before padding">:$original_type
OptionalParameter<"TypeAttr", "type of the original tensor type before padding">:$original_type,
// TODO(#15466): generalize matmul_narrow_{M,N} into a list?
OptionalParameter<"IntegerAttr", "optional M narrow dimension size (only for MATMUL and BATCH_MATMUL users)">:$matmul_narrow_M,
OptionalParameter<"IntegerAttr", "optional N narrow dimension size (only for MATMUL and BATCH_MATMUL users)">:$matmul_narrow_N
);

let genVerifyDecl = 0;
@@ -21,15 +21,15 @@ bool isMatmulEncodingUser(EncodingUser user);
// Check if encoding user is one of batch matmul encodings.
bool isBatchMatmulEncodingUser(EncodingUser user);

struct MatmulTileParams {
struct TileMxNxK {
int64_t M = 1;
int64_t K = 1;
int64_t N = 1;
int64_t K = 1;
};

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams);
MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingUser user,
EncodingRole role,
TileMxNxK tileMxNxK);

} // namespace LinalgExt
} // namespace IREE
@@ -116,7 +116,7 @@ chooseEncodingInfo(RankedTensorType tensorType) {
case EncodingUser::MATMUL:
case EncodingUser::BATCH_MATMUL:
if (tensorType.getElementType().isF32()) {
return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
return getEncodingInfoForMatmul(user, role, /*tileMxNxK=*/{8, 8, 4});
}
}
return failure();
@@ -19,21 +19,21 @@ bool isBatchMatmulEncodingUser(EncodingUser user) {
return user == EncodingUser::BATCH_MATMUL;
}

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingUser user,
EncodingRole role,
TileMxNxK tileMxNxK) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;

MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
switch (role) {
case (EncodingRole::LHS): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
encodingInfo.innerTileSizes = {tileMxNxK.M, tileMxNxK.K};
break;
}
case (EncodingRole::RHS): {
encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
encodingInfo.innerTileSizes = {tileMxNxK.N, tileMxNxK.K};
encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase};
encodingInfo.outerDimsPerm =
llvm::to_vector(llvm::seq<int64_t>(0, matmulDimBase));
Expand All @@ -42,7 +42,7 @@ chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
break;
}
case (EncodingRole::RESULT): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
encodingInfo.innerTileSizes = {tileMxNxK.M, tileMxNxK.N};
break;
}
default: {