Skip to content

Commit

Permalink
Change LinalgExt::EncodingAttr from enum to structured. (iree-org#14336)
Browse files Browse the repository at this point in the history
This has long been discussed, as the enum was just an initial design
shortcut, but the number of enum cases was already getting uncomfortable
due to the multiple dimensions there.

This is a step in a chain towards fixing
iree-org#11632. The reason is that in
order to properly specialize for narrow matmul cases in
`MaterializeEncoding`, selecting adequately narrow matmul kernels that
avoid widening the entire matmul problem at hand, we will need to know
the original (pre-padding) matrix shape. Since
`MaterializeEncoding` is a type-conversion, not just an ordinary rewrite
pattern, this information will need to be encoded in types --- we won't
just be able to walk from a value to its defining op to find the
pre-padding value, there just aren't values there. So I will want to add
the pre-padding shape (or type) to EncodingAttr. This is a step towards
that: by making EncodingAttr a data structure, it's easy then to add
another field. By contrast, if it's still an enum, the combinatorics get
out of hand.
  • Loading branch information
bjacob authored and nhasabni committed Aug 24, 2023
1 parent 1406fe5 commit 109dfdf
Show file tree
Hide file tree
Showing 33 changed files with 666 additions and 874 deletions.
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#ifndef IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGINFO_H_
#define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGINFO_H_

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"

namespace mlir {
Expand All @@ -25,7 +25,7 @@ void adjustTileSizesToNarrowStaticShape(
ArrayRef<int64_t> shape);

// Chooses the materialized tiling (inner tile sizes, inner-dim positions and
// outer-dim permutation) for a matmul operand with the given encoding `role`,
// using the target-specific `tileParams` (M/N/K tile sizes).
IREE::LinalgExt::MaterializeEncodingInfo
chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingRole role,
                            MatmulTileParams tileParams);

IREE::LinalgExt::MaterializeEncodingValueFn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===---------------------------------------------------------------------===//

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Common/EncodingInfo.h"
#include "iree/compiler/Codegen/Common/PassDetail.h"
Expand Down Expand Up @@ -246,22 +247,21 @@ struct MaterializeFlowDispatchTensorStoreOp
} // namespace

IREE::LinalgExt::MaterializeEncodingInfo
chooseEncodingInfoForMatmul(MatmulType type, MatmulOperandRole operandRole,
MatmulTileParams tileParams) {
chooseEncodingInfoForMatmul(EncodingRole role, MatmulTileParams tileParams) {
MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {0, 1};
switch (operandRole) {
case (MatmulOperandRole::LHS): {
switch (role) {
case (EncodingRole::LHS): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
break;
}
case (MatmulOperandRole::RHS): {
case (EncodingRole::RHS): {
encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
encodingInfo.innerDimsPos = {1, 0};
encodingInfo.outerDimsPerm = {1, 0};
break;
}
case (MatmulOperandRole::RESULT): {
case (EncodingRole::RESULT): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ func.func @non_perfect_tiling_unpack() {
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%0:2 = iree_codegen.query_tile_sizes tensor<16x16xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>> -> index, index
%0:2 = iree_codegen.query_tile_sizes tensor<16x16xi32, #iree_linalg_ext.encoding<user = MATMUL_I8I8I32, role = RESULT>> -> index, index
%1 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#0]
%2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2242,7 +2242,7 @@ hal.executable private @dynamic_unpack_fusion {
%cst = arith.constant dense<[-918, -4433, 87, -234, -21393, 7738, 529, -8835, -16817, -375, -199, 572, 5082, 15569, -186, 4955]> : tensor<16xi32>
%c12544 = arith.constant 12544 : index
%c16 = arith.constant 16 : index
%0:2 = iree_codegen.query_tile_sizes tensor<12544x16xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>> -> index, index
%0:2 = iree_codegen.query_tile_sizes tensor<12544x16xi32, #iree_linalg_ext.encoding<user = MATMUL_I8I8I32, role = RESULT>> -> index, index
%1 = affine.apply affine_map<()[s0] -> (12544 ceildiv s0)>()[%0#0]
%2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c200960) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"

#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/UKernelOps.h"
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {

Expand Down Expand Up @@ -354,6 +355,34 @@ matchDAGForUKernel(RewriterBase &rewriter, tensor::UnPackOp op) {
genericMicroKernelOp.getOperation());
}

// Maps an EncodingUser (the matmul element-type combination) to the
// corresponding IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_* ukernel flag bit.
static uint32_t flagForUser(IREE::LinalgExt::EncodingUser user) {
  switch (user) {
  case IREE::LinalgExt::EncodingUser::MATMUL_F32F32F32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32;
  case IREE::LinalgExt::EncodingUser::MATMUL_I8I8I32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
  case IREE::LinalgExt::EncodingUser::MATMUL_F16F16F32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F32;
  case IREE::LinalgExt::EncodingUser::MATMUL_F16F16F16:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F16;
  case IREE::LinalgExt::EncodingUser::MATMUL_BF16BF16F32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16F32;
  case IREE::LinalgExt::EncodingUser::MATMUL_BF16BF16BF16:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16BF16;
  }
  // The switch above covers every enumerator, but falling off the end of a
  // non-void function is UB if an out-of-range value ever reaches here
  // (and triggers -Wreturn-type / C4715). Return a benign value.
  return 0;
}

// Maps an EncodingRole (which matmul operand the tensor is) to the
// corresponding IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_* ukernel flag bit.
static uint32_t flagForRole(IREE::LinalgExt::EncodingRole role) {
  switch (role) {
  case IREE::LinalgExt::EncodingRole::LHS:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
  case IREE::LinalgExt::EncodingRole::RHS:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
  case IREE::LinalgExt::EncodingRole::RESULT:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
  }
  // The switch above covers every enumerator, but falling off the end of a
  // non-void function is UB if an out-of-range value ever reaches here
  // (and triggers -Wreturn-type / C4715). Return a benign value.
  return 0;
}

static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op) {
auto tensorType = op.getTensorType().dyn_cast<RankedTensorType>();
Expand All @@ -364,47 +393,22 @@ matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op) {
if (tensorType.getRank() != 2) {
return rewriter.notifyMatchFailure(op, "only the 2D case is implemented");
}
auto encoding = getEncoding(tensorType);
auto encoding = tensorType.getEncoding()
.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
if (!encoding) {
return rewriter.notifyMatchFailure(op, "no TensorEncoding attribute");
}
auto matmulType = getMatmulType(*encoding);
auto matmulOperandRole = getMatmulOperandRole(*encoding);
if (!matmulType || !matmulOperandRole) {
return rewriter.notifyMatchFailure(op, "unhandled TensorEncoding");
}
uint32_t flags = 0;
if (*matmulType == MatmulType::F32F32F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32;
} else if (*matmulType == MatmulType::I8I8I32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
} else if (*matmulType == MatmulType::F16F16F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F32;
} else if (*matmulType == MatmulType::F16F16F16) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F16;
} else if (*matmulType == MatmulType::BF16BF16F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16F32;
} else if (*matmulType == MatmulType::BF16BF16BF16) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16BF16;
} else {
return failure();
}
if (*matmulOperandRole == MatmulOperandRole::LHS) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
} else if (*matmulOperandRole == MatmulOperandRole::RHS) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
} else if (*matmulOperandRole == MatmulOperandRole::RESULT) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
} else {
return failure();
}
SmallVector<Type> resultTypes(tensorType.getRank(), rewriter.getIndexType());
SmallVector<Value> inputValues;
Location loc = op.getLoc();
for (int64_t i : tensorType.getShape()) {
inputValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
}
inputValues.push_back(rewriter.create<arith::ConstantIntOp>(loc, flags, 32));
inputValues.push_back(rewriter.create<arith::ConstantIntOp>(
loc,
flagForUser(encoding.getUser().getValue()) |
flagForRole(encoding.getRole().getValue()),
32));
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
auto fn = getFnNameAndDefAttrs("query_tile_sizes.2d", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ namespace {
// Fallback tile parameters used when no architecture-specific choice applies.
static MatmulTileParams chooseMatmulTileParamsGeneric() {
  MatmulTileParams genericParams{8, 4, 8};
  return genericParams;
}

static MatmulTileParams
chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
switch (type) {
case MatmulType::F32F32F32:
case MatmulType::F16F16F32:
case MatmulType::F16F16F16:
case MatmulType::BF16BF16F32:
case MatmulType::BF16BF16BF16:
chooseMatmulTileParamsAArch64(EncodingUser user, ExecutableTargetAttr target) {
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16.
return {8, 1, 8};
case MatmulType::I8I8I32:
case EncodingUser::MATMUL_I8I8I32:
if (hasFeature(target, "+i8mm")) {
// Aim to use SMMLA.
return {8, 8, 8};
Expand All @@ -60,13 +60,13 @@ chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
}

static MatmulTileParams
chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
switch (type) {
case MatmulType::F32F32F32:
case MatmulType::F16F16F32:
case MatmulType::F16F16F16:
case MatmulType::BF16BF16F32:
case MatmulType::BF16BF16BF16:
chooseMatmulTileParamsX86_64(EncodingUser user, ExecutableTargetAttr target) {
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
Expand All @@ -84,7 +84,7 @@ chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
}
// SSE fallback.
return {8, 1, 4};
case MatmulType::I8I8I32:
case EncodingUser::MATMUL_I8I8I32:
if (hasFeature(target, "+avx512vnni")) {
// Aim to use VPDPWSSD. This is the same tile size as with VPMADDWD
// as the only difference is that VPDPWSSD accumulates. VPDPBUSD would
Expand All @@ -107,13 +107,13 @@ chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
}
}

// Selects matmul tile parameters for the given encoding `user` (element-type
// combination), dispatching on the executable target's architecture and
// falling back to a generic choice for other targets.
static MatmulTileParams chooseMatmulTileParams(EncodingUser user,
                                               ExecutableTargetAttr target) {
  if (isAArch64(target)) {
    return chooseMatmulTileParamsAArch64(user, target);
  }
  if (isX86_64(target)) {
    return chooseMatmulTileParamsX86_64(user, target);
  }
  return chooseMatmulTileParamsGeneric();
}
Expand All @@ -139,19 +139,14 @@ void LLVMCPUMaterializeEncodingPass::runOnOperation() {
MaterializeEncodingTypeConverter typeConverter(
[targetAttr](
RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
std::optional<TensorEncoding> encoding = getEncoding(tensorType);
auto encoding =
tensorType.getEncoding().dyn_cast_or_null<EncodingAttr>();
if (!encoding)
return failure();

auto matmulType = getMatmulType(*encoding);
auto matmulOperandRole = getMatmulOperandRole(*encoding);
if (!matmulType || !matmulOperandRole) {
return failure();
}
MatmulTileParams tileParams =
chooseMatmulTileParams(*matmulType, targetAttr);
auto encodingInfo = chooseEncodingInfoForMatmul(
*matmulType, *matmulOperandRole, tileParams);
auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr);
auto encodingInfo = chooseEncodingInfoForMatmul(role, tileParams);
adjustTileSizesToNarrowStaticShape(encodingInfo, tensorType.getShape());
return encodingInfo;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,6 @@ func.func @unpack_f32f32_transpose_inner_and_outer(%arg0 : tensor<?x?x7x8xf32>,
// Queries dynamic tile sizes for a result tensor carrying the new structured
// encoding attribute (user + role fields replacing the old single enum).
func.func @query_tile_sizes_2d() -> (index, index) attributes {
  hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
  %result:2 = iree_codegen.query_tile_sizes tensor<?x?xf32, #iree_linalg_ext.encoding<user=MATMUL_F32F32F32, role=RESULT>> -> index, index
  return %result#0, %result#1 : index, index
}
Loading

0 comments on commit 109dfdf

Please sign in to comment.