Skip to content

Commit

Permalink
[MemoryBanking] Adjust default dimension to be the innermost dimensio…
Browse files Browse the repository at this point in the history
…n that has shape > 1 (#8132)
  • Loading branch information
jiahanxie353 authored Jan 27, 2025
1 parent b951ce7 commit 0ab632d
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 40 deletions.
6 changes: 3 additions & 3 deletions include/circt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ std::unique_ptr<mlir::Pass> createStripDebugInfoWithPredPass(
std::unique_ptr<mlir::Pass> createMaximizeSSAPass();
std::unique_ptr<mlir::Pass> createInsertMergeBlocksPass();
std::unique_ptr<mlir::Pass> createPrintOpCountPass();
std::unique_ptr<mlir::Pass> createMemoryBankingPass(
std::optional<unsigned> bankingFactor = std::nullopt,
std::optional<unsigned> bankingDimension = std::nullopt);
std::unique_ptr<mlir::Pass>
createMemoryBankingPass(std::optional<unsigned> bankingFactor = std::nullopt,
std::optional<int> bankingDimension = std::nullopt);
std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass();

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions include/circt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def MemoryBanking : Pass<"memory-banking", "::mlir::func::FuncOp"> {
let summary = "Partition the memories used in affine parallel loops into banks";
let constructor = "circt::createMemoryBankingPass()";
let options = [
Option<"bankingFactor", "banking-factor", "unsigned", /*default=*/"1",
Option<"bankingFactor", "factor", "unsigned", /*default=*/"1",
"Use this banking factor for all memories being partitioned">,
Option<"bankingDimension", "dimension", "unsigned", /*default=*/"0",
"The dimension along which to bank the memory. For rank=1, must be 0.">
Option<"bankingDimension", "dimension", "int", /*default=*/"-1",
"The dimension along which to bank the memory. If -1, the innermost dimension with size > 1 is used.">
];
let dependentDialects = ["mlir::memref::MemRefDialect, mlir::scf::SCFDialect, mlir::affine::AffineDialect"];
}
Expand Down
58 changes: 49 additions & 9 deletions lib/Transforms/MemoryBanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct MemoryBankingPass
MemoryBankingPass(const MemoryBankingPass &other) = default;
explicit MemoryBankingPass(
std::optional<unsigned> bankingFactor = std::nullopt,
std::optional<unsigned> bankingDimension = std::nullopt) {}
std::optional<int> bankingDimension = std::nullopt) {}

void runOnOperation() override;

Expand Down Expand Up @@ -200,9 +200,40 @@ SmallVector<Value, 4> handleGetGlobalOp(memref::GetGlobalOp getGlobalOp,
return banks;
}

unsigned getBankingDimension(std::optional<int> bankingDimensionOpt,
int64_t rank, ArrayRef<int64_t> shape) {
// If the banking dimension is already specified, return it.
// Note, the banking dimension will always be nonempty because TableGen will
// assign it with a default value -1 if it's not specified by the user. Thus,
// -1 is the sentinel value to indicate the default behavior, which is the
// innermost dimension with shape greater than 1.
if (bankingDimensionOpt.has_value() && *bankingDimensionOpt >= 0) {
return static_cast<unsigned>(*bankingDimensionOpt);
}

// Otherwise, find the innermost dimension with size > 1.
// For example, [[1], [2], [3], [4]] with `bankingFactor`=2 will be banked to
// [[1], [3]] and [[2], [4]].
int bankingDimension = -1;
for (int dim = rank - 1; dim >= 0; --dim) {
if (shape[dim] > 1) {
bankingDimension = dim;
break;
}
}

assert(bankingDimension >= 0 && "No eligible dimension for banking");
return static_cast<unsigned>(bankingDimension);
}

SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
unsigned bankingDimension) {
std::optional<int> bankingDimensionOpt) {
MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
unsigned rank = originalMemRefType.getRank();
ArrayRef<int64_t> shape = originalMemRefType.getShape();

auto bankingDimension = getBankingDimension(bankingDimensionOpt, rank, shape);

MemRefType newMemRefType = computeBankedMemRefType(
originalMemRefType, bankingFactor, bankingDimension);
SmallVector<Value, 4> banks;
Expand Down Expand Up @@ -254,11 +285,11 @@ SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
struct BankAffineLoadPattern
: public OpRewritePattern<mlir::affine::AffineLoadOp> {
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
unsigned bankingDimension,
std::optional<int> bankingDimensionOpt,
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Value> &oldMemRefVals)
: OpRewritePattern<mlir::affine::AffineLoadOp>(context),
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
bankingFactor(bankingFactor), bankingDimensionOpt(bankingDimensionOpt),
memoryToBanks(memoryToBanks), oldMemRefVals(oldMemRefVals) {}

LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
Expand All @@ -267,6 +298,11 @@ struct BankAffineLoadPattern
auto banks = memoryToBanks[loadOp.getMemref()];
auto loadIndices = loadOp.getIndices();
int64_t memrefRank = loadOp.getMemRefType().getRank();
ArrayRef<int64_t> shape = loadOp.getMemRefType().getShape();

auto bankingDimension =
getBankingDimension(bankingDimensionOpt, memrefRank, shape);

auto modMap = AffineMap::get(
/*dimCount=*/memrefRank, /*symbolCount=*/0,
{rewriter.getAffineDimExpr(bankingDimension) % bankingFactor});
Expand Down Expand Up @@ -320,7 +356,7 @@ struct BankAffineLoadPattern

private:
uint64_t bankingFactor;
unsigned bankingDimension;
std::optional<int> bankingDimensionOpt;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Value> &oldMemRefVals;
};
Expand All @@ -329,13 +365,13 @@ struct BankAffineLoadPattern
struct BankAffineStorePattern
: public OpRewritePattern<mlir::affine::AffineStoreOp> {
BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor,
unsigned bankingDimension,
std::optional<int> bankingDimensionOpt,
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Operation *> &opsToErase,
DenseSet<Operation *> &processedOps,
DenseSet<Value> &oldMemRefVals)
: OpRewritePattern<mlir::affine::AffineStoreOp>(context),
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
bankingFactor(bankingFactor), bankingDimensionOpt(bankingDimensionOpt),
memoryToBanks(memoryToBanks), opsToErase(opsToErase),
processedOps(processedOps), oldMemRefVals(oldMemRefVals) {}

Expand All @@ -348,6 +384,10 @@ struct BankAffineStorePattern
auto banks = memoryToBanks[storeOp.getMemref()];
auto storeIndices = storeOp.getIndices();
int64_t memrefRank = storeOp.getMemRefType().getRank();
ArrayRef<int64_t> shape = storeOp.getMemRefType().getShape();

auto bankingDimension =
getBankingDimension(bankingDimensionOpt, memrefRank, shape);

auto modMap = AffineMap::get(
/*dimCount=*/memrefRank, /*symbolCount=*/0,
Expand Down Expand Up @@ -397,7 +437,7 @@ struct BankAffineStorePattern

private:
uint64_t bankingFactor;
unsigned bankingDimension;
std::optional<int> bankingDimensionOpt;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Operation *> &opsToErase;
DenseSet<Operation *> &processedOps;
Expand Down Expand Up @@ -540,7 +580,7 @@ void MemoryBankingPass::runOnOperation() {
namespace circt {
std::unique_ptr<mlir::Pass>
createMemoryBankingPass(std::optional<unsigned> bankingFactor,
std::optional<unsigned> bankingDimension) {
std::optional<int> bankingDimension) {
return std::make_unique<MemoryBankingPass>(bankingFactor, bankingDimension);
}
} // namespace circt
8 changes: 4 additions & 4 deletions test/Transforms/memory_banking.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix UNROLL-BY-2
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=1" | FileCheck %s --check-prefix UNROLL-BY-1
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=8" | FileCheck %s --check-prefix UNROLL-BY-8
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix ALLOC-UNROLL-2
// RUN: circt-opt %s -split-input-file -memory-banking="factor=2" | FileCheck %s --check-prefix UNROLL-BY-2
// RUN: circt-opt %s -split-input-file -memory-banking="factor=1" | FileCheck %s --check-prefix UNROLL-BY-1
// RUN: circt-opt %s -split-input-file -memory-banking="factor=8" | FileCheck %s --check-prefix UNROLL-BY-8
// RUN: circt-opt %s -split-input-file -memory-banking="factor=2" | FileCheck %s --check-prefix ALLOC-UNROLL-2

// -----

Expand Down
2 changes: 1 addition & 1 deletion test/Transforms/memory_banking_invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=0" -verify-diagnostics
// RUN: circt-opt %s -split-input-file -memory-banking="factor=0" -verify-diagnostics

// expected-error@+1 {{banking factor must be greater than 1}}
func.func @bank_one_dim_unroll0(%arg0: memref<8xf32>, %arg1: memref<8xf32>) -> (memref<8xf32>) {
Expand Down
33 changes: 13 additions & 20 deletions test/Transforms/memory_banking_multi_dim.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2 dimension=1" | FileCheck %s --check-prefix RANK2-BANKDIM1
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix GETGLOBAL
// RUN: circt-opt %s -memory-banking="factor=2 dimension=1" | FileCheck %s --check-prefix RANK2-BANKDIM1
// RUN: circt-opt %s -split-input-file -memory-banking="factor=2" | FileCheck %s --check-prefix GETGLOBAL

// RANK2-BANKDIM1: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1 mod 2)>
// RANK2-BANKDIM1: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d1 floordiv 2)>
Expand Down Expand Up @@ -76,10 +76,10 @@ func.func @rank_two_bank_dim1(%arg0: memref<8x6xf32>, %arg1: memref<8x6xf32>) ->

// -----

// GETGLOBAL-LABEL: memref.global "private" constant @__constant_4x8xf32_bank_0 : memref<2x8xf32> = dense<{{\[\[}}8.000000e+00, -2.000000e+00, -2.000000e+00, -1.000000e+00, -3.000000e+00, -2.000000e+00, 3.000000e+00, 6.000000e+00], [9.000000e+00, -1.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00, -1.000000e+00, -2.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @__constant_4x8xf32_bank_1 : memref<2x8xf32> = dense<{{\[\[}}1.000000e+00, -3.000000e+00, -2.000000e+00, -1.000000e+00, 5.000000e+00, -3.000000e+00, -1.000000e+00, -2.000000e+00], [2.000000e+00, -7.000000e+00, 3.000000e+00, 1.000000e+00, -2.000000e+00, 2.000000e+00, -9.000000e+00, -1.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @__constant_8x6xf32_bank_0 : memref<4x6xf32> = dense<{{\[\[}}2.000000e+00, -2.000000e+00, -4.000000e+00, -1.000000e+00, -3.000000e+00, 3.000000e+00], [2.000000e+00, -2.000000e+00, 1.000000e+00, -1.000000e+00, 1.000000e+00, -8.000000e+00], [3.000000e+00, -3.000000e+00, -4.000000e+00, -3.000000e+00, -2.000000e+00, 1.000000e+00], [2.000000e+00, -9.000000e+00, 2.000000e+00, -3.000000e+00, -2.000000e+00, 1.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @__constant_8x6xf32_bank_1 : memref<4x6xf32> = dense<{{\[\[}}1.000000e+00, 1.000000e+00, 1.000000e+00, -7.000000e+00, 3.000000e+00, -2.000000e+00], [3.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00, 3.000000e+00, 1.000000e+00], [1.000000e+00, 3.000000e+00, -2.000000e+00, -2.000000e+00, 2.000000e+00, -1.000000e+00], [8.000000e+00, -1.000000e+00, 2.000000e+00, 2.000000e+00, -2.000000e+00, -2.000000e+00]]>
// GETGLOBAL-LABEL: memref.global "private" constant @__constant_4x8xf32_bank_0 : memref<4x4xf32> = dense<{{\[\[}}8.000000e+00, -2.000000e+00, -3.000000e+00, 3.000000e+00], [1.000000e+00, -2.000000e+00, 5.000000e+00, -1.000000e+00], [9.000000e+00, -2.000000e+00, -2.000000e+00, -1.000000e+00], [2.000000e+00, 3.000000e+00, -2.000000e+00, -9.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @__constant_4x8xf32_bank_1 : memref<4x4xf32> = dense<{{\[\[}}-2.000000e+00, -1.000000e+00, -2.000000e+00, 6.000000e+00], [-3.000000e+00, -1.000000e+00, -3.000000e+00, -2.000000e+00], [-1.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00], [-7.000000e+00, 1.000000e+00, 2.000000e+00, -1.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @__constant_8x1xf32_bank_0 : memref<4x1xf32> = dense<{{\[\[}}2.000000e+00], [-1.000000e+00], [3.000000e+00], [2.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @__constant_8x1xf32_bank_1 : memref<4x1xf32> = dense<{{\[\[}}-7.000000e+00], [3.000000e+00], [1.000000e+00], [8.000000e+00]]>

module {
memref.global "private" constant @__constant_4x8xf32 : memref<4x8xf32> = dense<[
Expand All @@ -88,25 +88,18 @@ module {
[9.0, -1.0, -2.0, -2.0, -2.0, -2.0, -1.0, -2.0],
[2.0, -7.0, 3.0, 1.0, -2.0, 2.0, -9.0, -1.0]
]>
memref.global "private" constant @__constant_8x6xf32 : memref<8x6xf32> = dense<[
[2.0, -2.0, -4.0, -1.0, -3.0, 3.0],
[1.0, 1.0, 1.0, -7.0, 3.0, -2.0],
[2.0, -2.0, 1.0, -1.0, 1.0, -8.0],
[3.0, -2.0, -2.0, -2.0, 3.0, 1.0],
[3.0, -3.0, -4.0, -3.0, -2.0, 1.0],
[1.0, 3.0, -2.0, -2.0, 2.0, -1.0],
[2.0, -9.0, 2.0, -3.0, -2.0, 1.0],
[8.0, -1.0, 2.0, 2.0, -2.0, -2.0]
memref.global "private" constant @__constant_8x1xf32 : memref<8x1xf32> = dense<[
[2.0], [-7.0], [-1.0], [3.0], [3.0], [1.0], [2.0], [8.0]
]>
func.func @main() {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.get_global @__constant_8x6xf32 : memref<8x6xf32>
%0 = memref.get_global @__constant_8x1xf32 : memref<8x1xf32>
%2 = memref.get_global @__constant_4x8xf32 : memref<4x8xf32>
%alloc = memref.alloc() : memref<6x8xf32>
affine.parallel (%arg2) = (0) to (6) {
%alloc = memref.alloc() : memref<1x8xf32>
affine.parallel (%arg2) = (0) to (1) {
affine.parallel (%arg3) = (0) to (8) {
%4 = affine.load %0[%arg3, %arg2] : memref<8x6xf32>
affine.store %4, %alloc[%arg2, %arg3] : memref<6x8xf32>
%4 = affine.load %0[%arg3, %arg2] : memref<8x1xf32>
affine.store %4, %alloc[%arg2, %arg3] : memref<1x8xf32>
}
}
%alloc_5 = memref.alloc() : memref<8x4xf32>
Expand Down

0 comments on commit 0ab632d

Please sign in to comment.