Data-tiling encodings: take the element types out of the enums. (#15182)
This has been discussed for a while: ever since #14336 made encodings a
data structure, it was an odd remnant that we were still encoding the
element-types tuple in the user enum. This was cumbersome, and it
resurfaced in every design discussion because it clearly would not scale
to new data types. Concretely, it is good to fix this ahead of adding
`i16xi16` and `i16xi4` data-tiling support (#15158).
bjacob authored Oct 16, 2023
1 parent 2b5e61f commit 9d7a4ba
Showing 23 changed files with 749 additions and 805 deletions.
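
The user-visible effect is on the encoding attribute itself: the element-type tuple moves out of the user enum value (e.g. MATMUL_F32F32F32) and into a separate element_types parameter on the attribute, leaving only generic users such as MATMUL and BATCH_MATMUL. For example, from the test updates below:

  before: #iree_linalg_ext.encoding<user = MATMUL_F32F32F32, role = RESULT>
  after:  #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>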
@@ -15,8 +15,10 @@
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -30,114 +32,118 @@ using IREE::HAL::ExecutableTargetAttr;

namespace {

static MatmulTileParams
static FailureOr<MatmulTileParams>
chooseMatmulTileParamsGeneric(ExecutableTargetAttr target) {
if (isVMVXBackend(target) && hasMicrokernels(target)) {
// VMVX+ukernel uses dynamic tile shapes.
return {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};
return MatmulTileParams{ShapedType::kDynamic, ShapedType::kDynamic,
ShapedType::kDynamic};
} else {
// Some vaguely reasonable static tile shape.
return {8, 4, 8};
return MatmulTileParams{8, 4, 8};
}
}

static MatmulTileParams
chooseMatmulTileParamsAArch64(EncodingUser user, ExecutableTargetAttr target) {
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
case EncodingUser::BATCH_MATMUL_F16F16F16:
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
static FailureOr<MatmulTileParams>
chooseMatmulTileParamsAArch64(EncodingUser user, TypeRange elementTypes,
ExecutableTargetAttr target) {
if (user != EncodingUser::MATMUL && user != EncodingUser::BATCH_MATMUL) {
return failure();
}

assert(elementTypes.size() == 3);
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];

if (out.isF32() || out.isF16() || out.isBF16()) {
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16.
return {8, 1, 8};
case EncodingUser::MATMUL_I8I8I32:
case EncodingUser::BATCH_MATMUL_I8I8I32:
return MatmulTileParams{8, 1, 8};
}

if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
out.isSignlessInteger(32)) {
if (hasFeature(target, "+i8mm")) {
// Aim to use SMMLA.
return {8, 8, 8};
return MatmulTileParams{8, 8, 8};
}
if (hasFeature(target, "+dotprod")) {
// Aim to use SDOT.
return {8, 4, 8};
return MatmulTileParams{8, 4, 8};
}
return {8, 1, 8};
default:
assert(false);
return {};
return MatmulTileParams{8, 1, 8};
}

return failure();
}

static MatmulTileParams
chooseMatmulTileParamsX86_64(EncodingUser user, ExecutableTargetAttr target) {
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
case EncodingUser::BATCH_MATMUL_F32F32F32:
case EncodingUser::BATCH_MATMUL_F16F16F32:
case EncodingUser::BATCH_MATMUL_F16F16F16:
case EncodingUser::BATCH_MATMUL_BF16BF16F32:
case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
static FailureOr<MatmulTileParams>
chooseMatmulTileParamsX86_64(EncodingUser user, TypeRange elementTypes,
ExecutableTargetAttr target) {
if (user != EncodingUser::MATMUL && user != EncodingUser::BATCH_MATMUL) {
return failure();
}

assert(elementTypes.size() == 3);
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];

if (out.isF32() || out.isF16() || out.isBF16()) {
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16.
if (hasFeature(target, "+avx512f")) {
return {16, 1, 16};
return MatmulTileParams{16, 1, 16};
}
if (hasFeature(target, "+avx")) {
// Note: for good performance, most +avx users will also want to add
// +fma, but that's a local instruction selection detail and the tile
// layout is unaffected, as there are enough registers even with the
// need for intermediate product registers when +fma is not used.
return {8, 1, 8};
return MatmulTileParams{8, 1, 8};
}
// SSE fallback.
return {8, 1, 4};
case EncodingUser::MATMUL_I8I8I32:
case EncodingUser::BATCH_MATMUL_I8I8I32:
return MatmulTileParams{8, 1, 4};
}

if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
out.isSignlessInteger(32)) {
if (hasFeature(target, "+avx512vnni")) {
// Aim to use VPDPWSSD. This is the same tile size as with VPMADDWD
// as the only difference is that VPDPWSSD accumulates. VPDPBUSD would
// call for {16, 4, 16} but we can't use it because of its unsigned LHS.
return {16, 2, 16};
return MatmulTileParams{16, 2, 16};
}
if (hasFeature(target, "+avx512bw")) {
// Aim to use VPMADDWD (zmm).
return {16, 2, 16};
return MatmulTileParams{16, 2, 16};
}
if (hasFeature(target, "+avx2")) {
// Aim to use VPMADDWD (ymm).
return {8, 2, 8};
return MatmulTileParams{8, 2, 8};
}
// SSE fallback. Aim to use PMADDWD (xmm).
return {8, 2, 4};
default:
assert(false);
return {};
return MatmulTileParams{8, 2, 4};
}

return failure();
}

static MatmulTileParams chooseMatmulTileParams(EncodingUser user,
ExecutableTargetAttr target) {
static FailureOr<MatmulTileParams>
chooseMatmulTileParams(EncodingUser user, TypeRange elementTypes,
ExecutableTargetAttr target) {
if (isAArch64(target)) {
return chooseMatmulTileParamsAArch64(user, target);
return chooseMatmulTileParamsAArch64(user, elementTypes, target);
}
if (isX86_64(target)) {
return chooseMatmulTileParamsX86_64(user, target);
return chooseMatmulTileParamsX86_64(user, elementTypes, target);
}
return chooseMatmulTileParamsGeneric(target);
}
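
With element types carried separately, the per-architecture tile choosers inspect the type tuple directly and return failure for unsupported combinations, so adding a new data-type pairing (such as the i16xi16 and i16xi4 cases targeted by #15158) becomes one more type check rather than a new batch of enum values. A minimal, hypothetical sketch of what such an addition could look like inside chooseMatmulTileParamsX86_64 (the tile shapes are placeholders for illustration, not values any real change would pick):

  // Hypothetical sketch only: an i16 x i16 -> i32 case under the new scheme.
  if (lhs.isSignlessInteger(16) && rhs.isSignlessInteger(16) &&
      out.isSignlessInteger(32)) {
    if (hasFeature(target, "+avx512bw")) {
      return MatmulTileParams{16, 2, 16};  // placeholder tile shape
    }
    return MatmulTileParams{8, 2, 4};  // placeholder tile shape
  }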
@@ -184,9 +190,17 @@ materializeEncodingForTarget(RankedTensorType tensorType,
}
auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr);
auto elementTypes = llvm::to_vector(
llvm::map_range(encoding.getElementTypes().getValue(), [](Attribute a) {
return a.cast<TypeAttr>().getValue();
}));
FailureOr<MatmulTileParams> tileParams =
chooseMatmulTileParams(user, elementTypes, targetAttr);
if (failed(tileParams)) {
return failure();
}
auto encodingInfo =
IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, tileParams);
IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, *tileParams);
auto originalTypeAttr = encoding.getOriginalType();
auto originalType = originalTypeAttr
? originalTypeAttr.getValue().cast<RankedTensorType>()
@@ -751,7 +751,7 @@ func.func @non_perfect_tiling_unpack() {
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%0:2 = iree_codegen.query_tile_sizes tensor<16x16xi32, #iree_linalg_ext.encoding<user = MATMUL_I8I8I32, role = RESULT>> -> index, index
%0:2 = iree_codegen.query_tile_sizes tensor<16x16xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i8, i8, i32]>> -> index, index
%1 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#0]
%2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
@@ -2242,7 +2242,7 @@ hal.executable private @dynamic_unpack_fusion {
%cst = arith.constant dense<[-918, -4433, 87, -234, -21393, 7738, 529, -8835, -16817, -375, -199, 572, 5082, 15569, -186, 4955]> : tensor<16xi32>
%c12544 = arith.constant 12544 : index
%c16 = arith.constant 16 : index
%0:2 = iree_codegen.query_tile_sizes tensor<12544x16xi32, #iree_linalg_ext.encoding<user = MATMUL_I8I8I32, role = RESULT>> -> index, index
%0:2 = iree_codegen.query_tile_sizes tensor<12544x16xi32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [i8, i8, i32]>> -> index, index
%1 = affine.apply affine_map<()[s0] -> (12544 ceildiv s0)>()[%0#0]
%2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c200960) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
@@ -383,37 +383,46 @@ matchDAGForUKernel(RewriterBase &rewriter, tensor::UnPackOp op,
genericMicroKernelOp.getOperation());
}

static uint32_t flagForUser(IREE::LinalgExt::EncodingUser user) {
switch (user) {
case IREE::LinalgExt::EncodingUser::MATMUL_F32F32F32:
static uint32_t
getFlagForUserAndOperandTypes(IREE::LinalgExt::EncodingUser user,
ArrayRef<Attribute> operandTypes) {
if (user != IREE::LinalgExt::EncodingUser::MATMUL ||
operandTypes.size() != 3) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_NONE;
}

Type lhs = operandTypes[0].cast<TypeAttr>().getValue();
Type rhs = operandTypes[1].cast<TypeAttr>().getValue();
Type out = operandTypes[2].cast<TypeAttr>().getValue();

if (lhs.isF32() && rhs.isF32() && out.isF32()) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32;
case IREE::LinalgExt::EncodingUser::MATMUL_I8I8I32:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
case IREE::LinalgExt::EncodingUser::MATMUL_F16F16F32:
} else if (lhs.isF16() && rhs.isF16() && out.isF32()) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F32;
case IREE::LinalgExt::EncodingUser::MATMUL_F16F16F16:
} else if (lhs.isF16() && rhs.isF16() && out.isF16()) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F16;
case IREE::LinalgExt::EncodingUser::MATMUL_BF16BF16F32:
} else if (lhs.isBF16() && rhs.isBF16() && out.isF32()) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16F32;
case IREE::LinalgExt::EncodingUser::MATMUL_BF16BF16BF16:
} else if (lhs.isBF16() && rhs.isBF16() && out.isBF16()) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16BF16;
default: // Unreachable.
assert(false);
} else if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
out.isSignlessInteger(32)) {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
} else {
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_NONE;
}
}

static uint32_t flagForRole(IREE::LinalgExt::EncodingRole role) {
static uint32_t getFlagForRole(IREE::LinalgExt::EncodingRole role) {
switch (role) {
case IREE::LinalgExt::EncodingRole::LHS:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
case IREE::LinalgExt::EncodingRole::RHS:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
case IREE::LinalgExt::EncodingRole::RESULT:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
default: // Unreachable.
assert(false);
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
default:
return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_NONE;
}
}
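
The VMVX ukernel path applies the same idea: the query_tile_sizes operation flag is now derived from the (user, element types) pair, the role flag from the encoding role, and any combination the helpers above do not recognize maps to a NONE flag, so the rewrite bails out with notifyMatchFailure instead of asserting. A rough sketch of how the two helpers combine (the actual assembly appears in the hunk below; user, elementTypes, and role are shorthand for the values read off the encoding attribute):

  // Sketch: combine the operation flag and the operand-role flag.
  uint32_t opFlag = getFlagForUserAndOperandTypes(user, elementTypes);
  uint32_t roleFlag = getFlagForRole(role);
  if (!opFlag || !roleFlag)
    return rewriter.notifyMatchFailure(op, "unhandled encoding");
  uint32_t flags = opFlag | roleFlag;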

@@ -439,11 +448,14 @@ matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op,
for (int64_t i : tensorType.getShape()) {
inputValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
}
uint32_t flagForUserAndOperandTypes = getFlagForUserAndOperandTypes(
encoding.getUser().getValue(), encoding.getElementTypes().getValue());
uint32_t flagForRole = getFlagForRole(encoding.getRole().getValue());
if (!flagForUserAndOperandTypes || !flagForRole) {
return rewriter.notifyMatchFailure(op, "unhandled encoding");
}
inputValues.push_back(rewriter.create<arith::ConstantIntOp>(
loc,
flagForUser(encoding.getUser().getValue()) |
flagForRole(encoding.getRole().getValue()),
32));
loc, flagForUserAndOperandTypes | flagForRole, 32));
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
auto fn = getFnNameAndDefAttrs("query_tile_sizes.2d", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
@@ -500,6 +500,6 @@ func.func @unpack_f32f32_transpose_inner_and_outer(%arg0 : tensor<?x?x7x8xf32>,
func.func @query_tile_sizes_2d() -> (index, index) attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = true}>
} {
%result:2 = iree_codegen.query_tile_sizes tensor<?x?xf32, #iree_linalg_ext.encoding<user=MATMUL_F32F32F32, role=RESULT>> -> index, index
%result:2 = iree_codegen.query_tile_sizes tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT, element_types = [f32, f32, f32]>> -> index, index
return %result#0, %result#1 : index, index
}