Change LinalgExt::EncodingAttr from enum to structured.
bjacob committed Jul 7, 2023
1 parent edf18a5 commit 1d2d79a
Showing 29 changed files with 592 additions and 784 deletions.
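At a glance: this commit replaces the single enum value carried by #iree_linalg_ext.encoding with a structured attribute that has two separate parameters, op_kind (the matmul element-type combination) and role (LHS, RHS, or RESULT). The before/after below is excerpted from the updated test at the end of this page:

    // Before: one fused enum value.
    #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>
    // After: structured parameters, queried in C++ via getOpKind() and getRole().
    #iree_linalg_ext.encoding<op_kind=MATMUL_F32F32F32, role=RESULT>

Accordingly, the helper enums MatmulType and MatmulOperandRole (and the getEncoding/getMatmulType/getMatmulOperandRole utilities from EncodingUtils.h) give way to the dialect's own EncodingOpKind and EncodingRole throughout the diffs below. Changed lines are marked with -/+; file headers lost in this capture are noted as such.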
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h

@@ -7,8 +7,8 @@
#ifndef IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGINFO_H_
#define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGINFO_H_

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"

namespace mlir {
@@ -25,7 +25,7 @@ void adjustTileSizesToNarrowStaticShape(
ArrayRef<int64_t> shape);

IREE::LinalgExt::MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(MatmulType type, MatmulOperandRole operandRole,
+chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingRole role,
MatmulTileParams tileParams);

IREE::LinalgExt::MaterializeEncodingValueFn

(file name not captured)

@@ -9,6 +9,7 @@
//===---------------------------------------------------------------------===//

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Common/EncodingInfo.h"
#include "iree/compiler/Codegen/Common/PassDetail.h"
@@ -246,22 +247,21 @@ struct MaterializeFlowDispatchTensorStoreOp
} // namespace

IREE::LinalgExt::MaterializeEncodingInfo
-chooseEncodingInfoForMatmul(MatmulType type, MatmulOperandRole operandRole,
-MatmulTileParams tileParams) {
+chooseEncodingInfoForMatmul(EncodingRole role, MatmulTileParams tileParams) {
MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {0, 1};
-switch (operandRole) {
-case (MatmulOperandRole::LHS): {
+switch (role) {
+case (EncodingRole::LHS): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
break;
}
-case (MatmulOperandRole::RHS): {
+case (EncodingRole::RHS): {
encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
encodingInfo.innerDimsPos = {1, 0};
encodingInfo.outerDimsPerm = {1, 0};
break;
}
-case (MatmulOperandRole::RESULT): {
+case (EncodingRole::RESULT): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
break;
}
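A usage sketch of the new signature (hypothetical values; the MatmulTileParams field names M, K, N appear in this diff, but the struct layout is assumed):

    // Ask for the RHS layout of a matmul tiled with M=8, K=4, N=8.
    MatmulTileParams tileParams;
    tileParams.M = 8;
    tileParams.K = 4;
    tileParams.N = 8;
    auto info = chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingRole::RHS,
                                            tileParams);
    // Per the switch above: info.innerTileSizes == {8, 4} (N, K), and
    // info.innerDimsPos == info.outerDimsPerm == {1, 0}, i.e. the RHS is
    // transposed so N becomes the outer dimension before tiling.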
(file name not captured)

@@ -23,7 +23,10 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[PAD_VAL:[A-Za-z0-9]+]]:
-// CHECK: %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[PAD:.+]] = tensor.pad %[[IN]] low[%[[C0]], %[[C0]]] high[%[[C3]], %[[C1]]]
// CHECK: tensor.yield %[[PAD_VAL]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xf32>
// CHECK: %[[TRANS:.+]] = linalg.transpose ins(%[[PAD]] : tensor<8x2xf32>) outs(%[[EMPTY:.+]] : tensor<8x2xf32>) permutation = [0, 1]

(file name not captured)

@@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"

#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp

@@ -8,7 +8,6 @@

#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

(file name not captured)

@@ -4,13 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/UKernelOps.h"
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
namespace mlir {
namespace iree_compiler {

@@ -364,36 +365,34 @@ matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op) {
if (tensorType.getRank() != 2) {
return rewriter.notifyMatchFailure(op, "only the 2D case is implemented");
}
-auto encoding = getEncoding(tensorType);
+auto encoding = tensorType.getEncoding()
+.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
if (!encoding) {
return rewriter.notifyMatchFailure(op, "no TensorEncoding attribute");
}
-auto matmulType = getMatmulType(*encoding);
-auto matmulOperandRole = getMatmulOperandRole(*encoding);
-if (!matmulType || !matmulOperandRole) {
-return rewriter.notifyMatchFailure(op, "unhandled TensorEncoding");
-}
+auto opKind = encoding.getOpKind().getValue();
+auto role = encoding.getRole().getValue();
uint32_t flags = 0;
-if (*matmulType == MatmulType::F32F32F32) {
+if (opKind == IREE::LinalgExt::EncodingOpKind::MATMUL_F32F32F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32;
-} else if (*matmulType == MatmulType::I8I8I32) {
+} else if (opKind == IREE::LinalgExt::EncodingOpKind::MATMUL_I8I8I32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
-} else if (*matmulType == MatmulType::F16F16F32) {
+} else if (opKind == IREE::LinalgExt::EncodingOpKind::MATMUL_F16F16F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F32;
-} else if (*matmulType == MatmulType::F16F16F16) {
+} else if (opKind == IREE::LinalgExt::EncodingOpKind::MATMUL_F16F16F16) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F16;
-} else if (*matmulType == MatmulType::BF16BF16F32) {
+} else if (opKind == IREE::LinalgExt::EncodingOpKind::MATMUL_BF16BF16F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16F32;
-} else if (*matmulType == MatmulType::BF16BF16BF16) {
+} else if (opKind == IREE::LinalgExt::EncodingOpKind::MATMUL_BF16BF16BF16) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16BF16;
} else {
return failure();
}
-if (*matmulOperandRole == MatmulOperandRole::LHS) {
+if (role == IREE::LinalgExt::EncodingRole::LHS) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
-} else if (*matmulOperandRole == MatmulOperandRole::RHS) {
+} else if (role == IREE::LinalgExt::EncodingRole::RHS) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
-} else if (*matmulOperandRole == MatmulOperandRole::RESULT) {
+} else if (role == IREE::LinalgExt::EncodingRole::RESULT) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
} else {
return failure();
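Two details worth noting in the rewritten matcher. First, the double accessor: getOpKind() and getRole() return the generated enum-attribute wrappers, so .getValue() is needed to reach the raw EncodingOpKind/EncodingRole enums (this reading follows the usual MLIR enum-attribute pattern; the attribute definition itself is not part of this page). Second, the ukernel-facing contract is unchanged: both parameters still collapse into one flags word, e.g. for an f32 matmul LHS:

    uint32_t flags = 0;
    flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32; // from op_kind
    flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;           // from role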
(file name not captured)

@@ -30,20 +30,21 @@ namespace {
static MatmulTileParams chooseMatmulTileParamsGeneric() { return {8, 4, 8}; }

static MatmulTileParams
-chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
-switch (type) {
-case MatmulType::F32F32F32:
-case MatmulType::F16F16F32:
-case MatmulType::F16F16F16:
-case MatmulType::BF16BF16F32:
-case MatmulType::BF16BF16BF16:
+chooseMatmulTileParamsAArch64(EncodingOpKind opKind,
+ExecutableTargetAttr target) {
+switch (opKind) {
+case EncodingOpKind::MATMUL_F32F32F32:
+case EncodingOpKind::MATMUL_F16F16F32:
+case EncodingOpKind::MATMUL_F16F16F16:
+case EncodingOpKind::MATMUL_BF16BF16F32:
+case EncodingOpKind::MATMUL_BF16BF16BF16:
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16.
return {8, 1, 8};
-case MatmulType::I8I8I32:
+case EncodingOpKind::MATMUL_I8I8I32:
if (hasFeature(target, "+i8mm")) {
// Aim to use SMMLA.
return {8, 8, 8};
@@ -60,13 +61,14 @@ chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
}

static MatmulTileParams
-chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
-switch (type) {
-case MatmulType::F32F32F32:
-case MatmulType::F16F16F32:
-case MatmulType::F16F16F16:
-case MatmulType::BF16BF16F32:
-case MatmulType::BF16BF16BF16:
+chooseMatmulTileParamsX86_64(EncodingOpKind opKind,
+ExecutableTargetAttr target) {
+switch (opKind) {
+case EncodingOpKind::MATMUL_F32F32F32:
+case EncodingOpKind::MATMUL_F16F16F32:
+case EncodingOpKind::MATMUL_F16F16F16:
+case EncodingOpKind::MATMUL_BF16BF16F32:
+case EncodingOpKind::MATMUL_BF16BF16BF16:
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
@@ -84,7 +86,7 @@ chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
}
// SSE fallback.
return {8, 1, 4};
-case MatmulType::I8I8I32:
+case EncodingOpKind::MATMUL_I8I8I32:
if (hasFeature(target, "+avx512vnni")) {
// Aim to use VPDPWSSD. This is the same tile size as with VPMADDWD
// as the only difference is that VPDPWSSD accumulates. VPDPBUSD would
@@ -107,13 +109,13 @@ chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
}
}

-static MatmulTileParams chooseMatmulTileParams(MatmulType type,
+static MatmulTileParams chooseMatmulTileParams(EncodingOpKind opKind,
ExecutableTargetAttr target) {
if (isAArch64(target)) {
-return chooseMatmulTileParamsAArch64(type, target);
+return chooseMatmulTileParamsAArch64(opKind, target);
}
if (isX86_64(target)) {
-return chooseMatmulTileParamsX86_64(type, target);
+return chooseMatmulTileParamsX86_64(opKind, target);
}
return chooseMatmulTileParamsGeneric();
}
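A quick sketch of calling the dispatcher (hypothetical call site; targetAttr is the ExecutableTargetAttr already in scope):

    // On AArch64 with +i8mm this yields {8, 8, 8} (aiming for SMMLA, per the
    // case above); on unrecognized architectures it falls back to the generic
    // {8, 4, 8}.
    MatmulTileParams params =
        chooseMatmulTileParams(EncodingOpKind::MATMUL_I8I8I32, targetAttr);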
@@ -139,19 +141,15 @@ void LLVMCPUMaterializeEncodingPass::runOnOperation() {
MaterializeEncodingTypeConverter typeConverter(
[targetAttr](
RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
-std::optional<TensorEncoding> encoding = getEncoding(tensorType);
+EncodingAttr encoding =
+tensorType.getEncoding().dyn_cast_or_null<EncodingAttr>();
if (!encoding)
return failure();

-auto matmulType = getMatmulType(*encoding);
-auto matmulOperandRole = getMatmulOperandRole(*encoding);
-if (!matmulType || !matmulOperandRole) {
-return failure();
-}
+auto opKind = encoding.getOpKind().getValue();
+auto role = encoding.getRole().getValue();
MatmulTileParams tileParams =
-chooseMatmulTileParams(*matmulType, targetAttr);
+chooseMatmulTileParams(opKind, targetAttr);
-auto encodingInfo = chooseEncodingInfoForMatmul(
-*matmulType, *matmulOperandRole, tileParams);
+auto encodingInfo = chooseEncodingInfoForMatmul(role, tileParams);
adjustTileSizesToNarrowStaticShape(encodingInfo, tensorType.getShape());
return encodingInfo;
});
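The last step, adjustTileSizesToNarrowStaticShape (declared in EncodingInfo.h above), clamps the chosen tiles against narrow static dimensions; a plausible reading of its effect, stated as an assumption since its body is outside this diff:

    // e.g. an LHS of static shape 1x40 with innerTileSizes = {8, 4} would be
    // adjusted to innerTileSizes = {1, 4} so a matvec is not padded to 8 rows.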

(file name not captured)

@@ -435,6 +435,6 @@ func.func @unpack_f32f32_transpose_inner_and_outer(%arg0 : tensor<?x?x7x8xf32>,
func.func @query_tile_sizes_2d() -> (index, index) attributes {
hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
-%result:2 = iree_codegen.query_tile_sizes tensor<?x?xf32, #iree_linalg_ext.encoding<MATMUL_F32F32F32_RESULT>> -> index, index
+%result:2 = iree_codegen.query_tile_sizes tensor<?x?xf32, #iree_linalg_ext.encoding<op_kind=MATMUL_F32F32F32, role=RESULT>> -> index, index
return %result#0, %result#1 : index, index
}