Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Change LinalgExt::EncodingAttr from enum to structured. #14336

Merged
merged 2 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#ifndef IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGINFO_H_
#define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGINFO_H_

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"

namespace mlir {
Expand All @@ -25,7 +25,7 @@ void adjustTileSizesToNarrowStaticShape(
ArrayRef<int64_t> shape);

IREE::LinalgExt::MaterializeEncodingInfo
chooseEncodingInfoForMatmul(MatmulType type, MatmulOperandRole operandRole,
chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingRole role,
MatmulTileParams tileParams);

IREE::LinalgExt::MaterializeEncodingValueFn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===---------------------------------------------------------------------===//

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree/compiler/Codegen/Common/EncodingInfo.h"
#include "iree/compiler/Codegen/Common/PassDetail.h"
Expand Down Expand Up @@ -246,22 +247,21 @@ struct MaterializeFlowDispatchTensorStoreOp
} // namespace

IREE::LinalgExt::MaterializeEncodingInfo
chooseEncodingInfoForMatmul(MatmulType type, MatmulOperandRole operandRole,
MatmulTileParams tileParams) {
chooseEncodingInfoForMatmul(EncodingRole role, MatmulTileParams tileParams) {
MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {0, 1};
switch (operandRole) {
case (MatmulOperandRole::LHS): {
switch (role) {
case (EncodingRole::LHS): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.K};
break;
}
case (MatmulOperandRole::RHS): {
case (EncodingRole::RHS): {
encodingInfo.innerTileSizes = {tileParams.N, tileParams.K};
encodingInfo.innerDimsPos = {1, 0};
encodingInfo.outerDimsPerm = {1, 0};
break;
}
case (MatmulOperandRole::RESULT): {
case (EncodingRole::RESULT): {
encodingInfo.innerTileSizes = {tileParams.M, tileParams.N};
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ func.func @non_perfect_tiling_unpack() {
%c512 = arith.constant 512 : index
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%0:2 = iree_codegen.query_tile_sizes tensor<16x16xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>> -> index, index
%0:2 = iree_codegen.query_tile_sizes tensor<16x16xi32, #iree_linalg_ext.encoding<user = MATMUL_I8I8I32, role = RESULT>> -> index, index
%1 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#0]
%2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c512) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2242,7 +2242,7 @@ hal.executable private @dynamic_unpack_fusion {
%cst = arith.constant dense<[-918, -4433, 87, -234, -21393, 7738, 529, -8835, -16817, -375, -199, 572, 5082, 15569, -186, 4955]> : tensor<16xi32>
%c12544 = arith.constant 12544 : index
%c16 = arith.constant 16 : index
%0:2 = iree_codegen.query_tile_sizes tensor<12544x16xi32, #iree_linalg_ext.encoding<MATMUL_I8I8I32_RESULT>> -> index, index
%0:2 = iree_codegen.query_tile_sizes tensor<12544x16xi32, #iree_linalg_ext.encoding<user = MATMUL_I8I8I32, role = RESULT>> -> index, index
%1 = affine.apply affine_map<()[s0] -> (12544 ceildiv s0)>()[%0#0]
%2 = affine.apply affine_map<()[s0] -> (16 ceildiv s0)>()[%0#1]
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c200960) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%1, %2, %0#0, %0#1}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"

#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/UKernelOps.h"
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {

Expand Down Expand Up @@ -354,6 +355,34 @@ matchDAGForUKernel(RewriterBase &rewriter, tensor::UnPackOp op) {
genericMicroKernelOp.getOperation());
}

// Maps a LinalgExt matmul encoding user to the corresponding
// IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_* flag expected by the
// query_tile_sizes ukernel.
static uint32_t flagForUser(IREE::LinalgExt::EncodingUser user) {
  switch (user) {
  case IREE::LinalgExt::EncodingUser::MATMUL_F32F32F32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32;
  case IREE::LinalgExt::EncodingUser::MATMUL_I8I8I32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
  case IREE::LinalgExt::EncodingUser::MATMUL_F16F16F32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F32;
  case IREE::LinalgExt::EncodingUser::MATMUL_F16F16F16:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F16;
  case IREE::LinalgExt::EncodingUser::MATMUL_BF16BF16F32:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16F32;
  case IREE::LinalgExt::EncodingUser::MATMUL_BF16BF16BF16:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16BF16;
  }
  // The switch above covers every EncodingUser value, so this point is only
  // reachable with an out-of-range enum. Returning 0 keeps the function
  // well-defined (flowing off the end of a value-returning function is UB)
  // while avoiding a `default:` label that would defeat -Wswitch coverage
  // checking when new users are added.
  return 0;
}

// Maps a LinalgExt encoding operand role (LHS/RHS/RESULT) to the
// corresponding IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_* flag expected
// by the query_tile_sizes ukernel.
static uint32_t flagForRole(IREE::LinalgExt::EncodingRole role) {
  switch (role) {
  case IREE::LinalgExt::EncodingRole::LHS:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
  case IREE::LinalgExt::EncodingRole::RHS:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
  case IREE::LinalgExt::EncodingRole::RESULT:
    return IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
  }
  // The switch above covers every EncodingRole value, so this point is only
  // reachable with an out-of-range enum. Returning 0 keeps the function
  // well-defined (flowing off the end of a value-returning function is UB)
  // while avoiding a `default:` label that would defeat -Wswitch coverage
  // checking when new roles are added.
  return 0;
}

static FailureOr<IREE::Codegen::UKernelOpInterface>
matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op) {
auto tensorType = op.getTensorType().dyn_cast<RankedTensorType>();
Expand All @@ -364,47 +393,22 @@ matchDAGForUKernel(RewriterBase &rewriter, IREE::Codegen::QueryTileSizesOp op) {
if (tensorType.getRank() != 2) {
return rewriter.notifyMatchFailure(op, "only the 2D case is implemented");
}
auto encoding = getEncoding(tensorType);
auto encoding = tensorType.getEncoding()
.dyn_cast_or_null<IREE::LinalgExt::EncodingAttr>();
if (!encoding) {
return rewriter.notifyMatchFailure(op, "no TensorEncoding attribute");
}
auto matmulType = getMatmulType(*encoding);
auto matmulOperandRole = getMatmulOperandRole(*encoding);
if (!matmulType || !matmulOperandRole) {
return rewriter.notifyMatchFailure(op, "unhandled TensorEncoding");
}
uint32_t flags = 0;
if (*matmulType == MatmulType::F32F32F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F32F32F32;
} else if (*matmulType == MatmulType::I8I8I32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32;
} else if (*matmulType == MatmulType::F16F16F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F32;
} else if (*matmulType == MatmulType::F16F16F16) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_F16F16F16;
} else if (*matmulType == MatmulType::BF16BF16F32) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16F32;
} else if (*matmulType == MatmulType::BF16BF16BF16) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_BF16BF16BF16;
} else {
return failure();
}
if (*matmulOperandRole == MatmulOperandRole::LHS) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_LHS;
} else if (*matmulOperandRole == MatmulOperandRole::RHS) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RHS;
} else if (*matmulOperandRole == MatmulOperandRole::RESULT) {
flags |= IREE_UK_FLAG_QUERY_TILE_SIZES_OPERAND_ROLE_RESULT;
} else {
return failure();
}
SmallVector<Type> resultTypes(tensorType.getRank(), rewriter.getIndexType());
SmallVector<Value> inputValues;
Location loc = op.getLoc();
for (int64_t i : tensorType.getShape()) {
inputValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
}
inputValues.push_back(rewriter.create<arith::ConstantIntOp>(loc, flags, 32));
inputValues.push_back(rewriter.create<arith::ConstantIntOp>(
loc,
flagForUser(encoding.getUser().getValue()) |
flagForRole(encoding.getRole().getValue()),
32));
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
auto fn = getFnNameAndDefAttrs("query_tile_sizes.2d", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ namespace {
// Fallback matmul tile parameters used when the target architecture has no
// dedicated tile-size heuristic.
static MatmulTileParams chooseMatmulTileParamsGeneric() {
  MatmulTileParams genericTileParams = {8, 4, 8};
  return genericTileParams;
}

static MatmulTileParams
chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
switch (type) {
case MatmulType::F32F32F32:
case MatmulType::F16F16F32:
case MatmulType::F16F16F16:
case MatmulType::BF16BF16F32:
case MatmulType::BF16BF16BF16:
chooseMatmulTileParamsAArch64(EncodingUser user, ExecutableTargetAttr target) {
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
// reconsider when taking advantage of native f16/bf16 arithmetic when the
// accumulator itself is f16/bf16.
return {8, 1, 8};
case MatmulType::I8I8I32:
case EncodingUser::MATMUL_I8I8I32:
if (hasFeature(target, "+i8mm")) {
// Aim to use SMMLA.
return {8, 8, 8};
Expand All @@ -60,13 +60,13 @@ chooseMatmulTileParamsAArch64(MatmulType type, ExecutableTargetAttr target) {
}

static MatmulTileParams
chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
switch (type) {
case MatmulType::F32F32F32:
case MatmulType::F16F16F32:
case MatmulType::F16F16F16:
case MatmulType::BF16BF16F32:
case MatmulType::BF16BF16BF16:
chooseMatmulTileParamsX86_64(EncodingUser user, ExecutableTargetAttr target) {
switch (user) {
case EncodingUser::MATMUL_F32F32F32:
case EncodingUser::MATMUL_F16F16F32:
case EncodingUser::MATMUL_F16F16F16:
case EncodingUser::MATMUL_BF16BF16F32:
case EncodingUser::MATMUL_BF16BF16BF16:
// Note: 16-bit floating point types currently use the same tile size as
// f32. This makes sense when either (1) the accumulator is f32, or (2)
// the arithmetic will have to expand f16 to f32 in registers. We may
Expand All @@ -84,7 +84,7 @@ chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
}
// SSE fallback.
return {8, 1, 4};
case MatmulType::I8I8I32:
case EncodingUser::MATMUL_I8I8I32:
if (hasFeature(target, "+avx512vnni")) {
// Aim to use VPDPWSSD. This is the same tile size as with VPMADDWD
// as the only difference is that VPDPWSSD accumulates. VPDPBUSD would
Expand All @@ -107,13 +107,13 @@ chooseMatmulTileParamsX86_64(MatmulType type, ExecutableTargetAttr target) {
}
}

// Selects matmul tile parameters for the given encoding user, dispatching to
// the architecture-specific heuristic for AArch64 and x86-64 and falling
// back to generic parameters otherwise.
//
// NOTE(review): the scraped diff interleaved the removed pre-change lines
// (MatmulType-based signature and calls) with the added post-change lines,
// leaving two conflicting signatures in the same function; only the merged
// post-change (EncodingUser-based) version is kept here.
static MatmulTileParams chooseMatmulTileParams(EncodingUser user,
                                               ExecutableTargetAttr target) {
  if (isAArch64(target)) {
    return chooseMatmulTileParamsAArch64(user, target);
  }
  if (isX86_64(target)) {
    return chooseMatmulTileParamsX86_64(user, target);
  }
  return chooseMatmulTileParamsGeneric();
}
Expand All @@ -139,19 +139,14 @@ void LLVMCPUMaterializeEncodingPass::runOnOperation() {
MaterializeEncodingTypeConverter typeConverter(
[targetAttr](
RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
std::optional<TensorEncoding> encoding = getEncoding(tensorType);
auto encoding =
tensorType.getEncoding().dyn_cast_or_null<EncodingAttr>();
if (!encoding)
return failure();

auto matmulType = getMatmulType(*encoding);
auto matmulOperandRole = getMatmulOperandRole(*encoding);
if (!matmulType || !matmulOperandRole) {
return failure();
}
MatmulTileParams tileParams =
chooseMatmulTileParams(*matmulType, targetAttr);
auto encodingInfo = chooseEncodingInfoForMatmul(
*matmulType, *matmulOperandRole, tileParams);
auto user = encoding.getUser().getValue();
auto role = encoding.getRole().getValue();
MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr);
auto encodingInfo = chooseEncodingInfoForMatmul(role, tileParams);
adjustTileSizesToNarrowStaticShape(encodingInfo, tensorType.getShape());
return encodingInfo;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,6 @@ func.func @unpack_f32f32_transpose_inner_and_outer(%arg0 : tensor<?x?x7x8xf32>,
// Queries runtime tile sizes for a dynamic f32 tensor carrying the
// structured encoding attribute (user/role form). The scraped diff kept both
// the removed enum-form encoding line and the added structured-form line;
// only the post-change line is kept here.
func.func @query_tile_sizes_2d() -> (index, index) attributes {
  hal.executable.target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
} {
  %result:2 = iree_codegen.query_tile_sizes tensor<?x?xf32, #iree_linalg_ext.encoding<user=MATMUL_F32F32F32, role=RESULT>> -> index, index
  return %result#0, %result#1 : index, index
}
Loading