Skip to content

Commit

Permalink
[MooreToCore] Lower moore operators into comb or hw.
Browse files Browse the repository at this point in the history
Co-authored-by: Fabian Schuiki <[email protected]>
  • Loading branch information
hailongSun2000 and fabianschuiki committed Apr 22, 2024
1 parent 61a18fe commit 1bccc51
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 11 deletions.
209 changes: 199 additions & 10 deletions lib/Conversion/MooreToCore/MooreToCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,51 @@ static Value adjustIntegerWidth(OpBuilder &builder, Value value,
return builder.create<comb::MuxOp>(loc, isZero, lo, max, false);
}

/// Due to the result type of the `lt`, or `le`, or `gt`, or `ge` ops are
/// always unsigned, estimating their operands type.
static bool isSignedType(Operation *op) {
return TypeSwitch<Operation *, bool>(op)
.template Case<LtOp, LeOp, GtOp, GeOp>([&](auto op) -> bool {
return cast<UnpackedType>(op->getOperand(0).getType())
.castToSimpleBitVector()
.isSigned() &&
cast<UnpackedType>(op->getOperand(1).getType())
.castToSimpleBitVector()
.isSigned();
})
.Default([&](auto op) -> bool {
return cast<UnpackedType>(op->getResult(0).getType())
.castToSimpleBitVector()
.isSigned();
});
}

/// Not define the predicate for `relation` and `equality` operations in the
/// MooreDialect, but comb needs it. Return a correct `comb::ICmpPredicate`
/// corresponding to different moore `relation` and `equality` operations.
static comb::ICmpPredicate getCombPredicate(Operation *op) {
using comb::ICmpPredicate;
return TypeSwitch<Operation *, ICmpPredicate>(op)
.Case<LtOp>([&](auto op) {
return isSignedType(op) ? ICmpPredicate::slt : ICmpPredicate::ult;
})
.Case<LeOp>([&](auto op) {
return isSignedType(op) ? ICmpPredicate::sle : ICmpPredicate::ule;
})
.Case<GtOp>([&](auto op) {
return isSignedType(op) ? ICmpPredicate::sgt : ICmpPredicate::ugt;
})
.Case<GeOp>([&](auto op) {
return isSignedType(op) ? ICmpPredicate::sge : ICmpPredicate::uge;
})
.Case<EqOp>([&](auto op) { return ICmpPredicate::eq; })
.Case<NeOp>([&](auto op) { return ICmpPredicate::ne; })
.Case<CaseEqOp>([&](auto op) { return ICmpPredicate::ceq; })
.Case<CaseNeOp>([&](auto op) { return ICmpPredicate::cne; })
.Case<WildcardEqOp>([&](auto op) { return ICmpPredicate::weq; })
.Case<WildcardNeOp>([&](auto op) { return ICmpPredicate::wne; });
}

//===----------------------------------------------------------------------===//
// Expression Conversion
//===----------------------------------------------------------------------===//
Expand All @@ -82,6 +127,124 @@ struct ConcatOpConversion : public OpConversionPattern<ConcatOp> {
}
};

struct ReplicateOpConversion : public OpConversionPattern<ReplicateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(op.getResult().getType());

rewriter.replaceOpWithNewOp<comb::ReplicateOp>(op, resultType,
adaptor.getValue());
return success();
}
};

struct ExtractOpConversion : public OpConversionPattern<ExtractOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(op.getResult().getType());
auto width = typeConverter->convertType(op.getInput().getType())
.getIntOrFloatBitWidth();
Value amount =
adjustIntegerWidth(rewriter, adaptor.getLowBit(), width, op->getLoc());
Value value =
rewriter.create<comb::ShrUOp>(op->getLoc(), adaptor.getInput(), amount);

rewriter.replaceOpWithNewOp<comb::ExtractOp>(op, resultType, value, 0);
return success();
}
};

template <typename OpTy>
struct UnaryOpConversion : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;

LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<comb::ParityOp>(op, adaptor.getInput());
return success();
}
};

struct NotOpConversion : public OpConversionPattern<NotOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(NotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType =
ConversionPattern::typeConverter->convertType(op.getResult().getType());
Value max = rewriter.create<hw::ConstantOp>(op.getLoc(), resultType, -1);

rewriter.replaceOpWithNewOp<comb::XorOp>(op, adaptor.getInput(), max);
return success();
}
};

template <typename SourceOp, typename TargetOp>
struct BinaryOpConversion : public OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename SourceOp::Adaptor;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (isa<DivOp>(op) && isSignedType(op)) {
rewriter.replaceOpWithNewOp<comb::DivSOp>(op, adaptor.getLhs(),
adaptor.getRhs(), false);
return success();
}
if (isa<ModOp>(op) && isSignedType(op)) {
rewriter.replaceOpWithNewOp<comb::ModSOp>(op, adaptor.getLhs(),
adaptor.getRhs(), false);
return success();
}

rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(),
adaptor.getRhs(), false);
return success();
}
};

template <typename SourceOp>
struct ICmpOpConversion : public OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename SourceOp::Adaptor;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adapter,
ConversionPatternRewriter &rewriter) const override {
Type resultType =
ConversionPattern::typeConverter->convertType(op.getResult().getType());
comb::ICmpPredicate pred = getCombPredicate(op);

rewriter.replaceOpWithNewOp<comb::ICmpOp>(
op, resultType, pred, adapter.getLhs(), adapter.getRhs());
return success();
}
};

struct ConversionOpConversion : public OpConversionPattern<ConversionOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(ConversionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType = typeConverter->convertType(op.getResult().getType());
Value amount =
adjustIntegerWidth(rewriter, adaptor.getInput(),
resultType.getIntOrFloatBitWidth(), op->getLoc());

rewriter.replaceOpWithNewOp<hw::BitcastOp>(op, resultType, amount);
return success();
}
};

//===----------------------------------------------------------------------===//
// Statement Conversion
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -271,6 +434,9 @@ static void populateTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion([&](UnpackedType type) -> std::optional<Type> {
if (auto sbv = type.getSimpleBitVectorOrNull())
return mlir::IntegerType::get(type.getContext(), sbv.size);
if (isa<UnpackedRangeDim, PackedRangeDim>(type))
return mlir::IntegerType::get(type.getContext(),
type.getBitSize().value());
return std::nullopt;
});

Expand All @@ -283,16 +449,39 @@ static void populateOpConversion(RewritePatternSet &patterns,
auto *context = patterns.getContext();
// clang-format off
patterns.add<
ConstantOpConv,
ConcatOpConversion,
ReturnOpConversion,
CondBranchOpConversion,
BranchOpConversion,
CallOpConversion,
ShlOpConversion,
ShrOpConversion,
AShrOpConversion,
UnrealizedConversionCastConversion
// Patterns of miscellaneous operations.
ConstantOpConv, ConcatOpConversion, ReplicateOpConversion,
ExtractOpConversion, ConversionOpConversion,

// Patterns of unary operations.
UnaryOpConversion<BoolCastOp>, UnaryOpConversion<ReduceAndOp>,
UnaryOpConversion<ReduceOrOp>, UnaryOpConversion<ReduceXorOp>,
NotOpConversion,

// Patterns of binary operations.
BinaryOpConversion<AddOp, comb::AddOp>,
BinaryOpConversion<SubOp, comb::SubOp>,
BinaryOpConversion<MulOp, comb::MulOp>,
BinaryOpConversion<DivOp, comb::DivUOp>,
BinaryOpConversion<ModOp, comb::ModUOp>,
BinaryOpConversion<AndOp, comb::AndOp>,
BinaryOpConversion<OrOp, comb::OrOp>,
BinaryOpConversion<XorOp, comb::XorOp>,

// Patterns of relational operations.
ICmpOpConversion<LtOp>, ICmpOpConversion<LeOp>, ICmpOpConversion<GtOp>,
ICmpOpConversion<GeOp>, ICmpOpConversion<EqOp>, ICmpOpConversion<NeOp>,
ICmpOpConversion<CaseEqOp>, ICmpOpConversion<CaseNeOp>,
ICmpOpConversion<WildcardEqOp>, ICmpOpConversion<WildcardNeOp>,

// Patterns of shifting operations.
ShrOpConversion, ShlOpConversion, AShrOpConversion,

// Patterns of branch operations.
CondBranchOpConversion, BranchOpConversion,

// Patterns of other operations outside Moore dialect.
ReturnOpConversion, CallOpConversion, UnrealizedConversionCastConversion
>(typeConverter, context);
// clang-format on
mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
Expand Down
100 changes: 99 additions & 1 deletion test/Conversion/MooreToCore/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,25 @@ func.func @UnrealizedConversionCast(%arg0: !moore.byte) -> !moore.shortint {
}

// CHECK-LABEL: func @Expressions
func.func @Expressions(%arg0: !moore.bit, %arg1: !moore.logic, %arg2: !moore.packed<range<bit, 5:0>>, %arg3: !moore.packed<range<bit<signed>, 4:0>>) {
func.func @Expressions(%arg0: !moore.bit, %arg1: !moore.logic, %arg2: !moore.packed<range<bit, 5:0>>, %arg3: !moore.packed<range<bit<signed>, 4:0>>, %arg4: !moore.bit<signed>) {
// CHECK-NEXT: %0 = comb.concat %arg0, %arg0 : i1, i1
// CHECK-NEXT: %1 = comb.concat %arg1, %arg1 : i1, i1
moore.concat %arg0, %arg0 : (!moore.bit, !moore.bit) -> !moore.packed<range<bit, 1:0>>
moore.concat %arg1, %arg1 : (!moore.logic, !moore.logic) -> !moore.packed<range<logic, 1:0>>

// CHECK-NEXT: comb.replicate %arg0 : (i1) -> i2
// CHECK-NEXT: comb.replicate %arg1 : (i1) -> i2
moore.replicate %arg0 : (!moore.bit) -> !moore.packed<range<bit, 1:0>>
moore.replicate %arg1 : (!moore.logic) -> !moore.packed<range<logic, 1:0>>

// CHECK-NEXT: %c12_i32 = hw.constant 12 : i32
// CHECK-NEXT: %c3_i6 = hw.constant 3 : i6
moore.constant 12 : !moore.int
moore.constant 3 : !moore.packed<range<bit, 5:0>>

// CHECK-NEXT: hw.bitcast %arg0 : (i1) -> i1
moore.conversion %arg0 : !moore.bit -> !moore.logic

// CHECK-NEXT: [[V0:%.+]] = hw.constant 0 : i5
// CHECK-NEXT: [[V1:%.+]] = comb.concat [[V0]], %arg0 : i5, i1
// CHECK-NEXT: comb.shl %arg2, [[V1]] : i6
Expand Down Expand Up @@ -83,6 +96,91 @@ func.func @Expressions(%arg0: !moore.bit, %arg1: !moore.logic, %arg2: !moore.pac
// CHECK-NEXT: comb.shrs %arg3, [[V15]] : i5
moore.ashr %arg3, %arg2 : !moore.packed<range<bit<signed>, 4:0>>, !moore.packed<range<bit, 5:0>>

// CHECK-NEXT: %c2_i32 = hw.constant 2 : i32
%2 = moore.constant 2 : !moore.int

// CHECK-NEXT: [[V16:%.+]] = comb.extract %c2_i32 from 6 : (i32) -> i26
// CHECK-NEXT: %c0_i26 = hw.constant 0 : i26
// CHECK-NEXT: [[V17:%.+]] = comb.icmp eq [[V16]], %c0_i26 : i26
// CHECK-NEXT: [[V18:%.+]] = comb.extract %c2_i32 from 0 : (i32) -> i6
// CHECK-NEXT: %c-1_i6 = hw.constant -1 : i6
// CHECK-NEXT: [[V19:%.+]] = comb.mux [[V17]], [[V18]], %c-1_i6 : i6
// CHECK-NEXT: [[V20:%.+]] = comb.shru %arg2, [[V19]] : i6
// CHECK-NEXT: comb.extract [[V20]] from 0 : (i6) -> i2
moore.extract %arg2 from %2 : !moore.packed<range<bit, 5:0>>, !moore.int -> !moore.packed<range<bit, 3:2>>

// CHECK-NEXT: [[V21:%.+]] = comb.extract %c2_i32 from 6 : (i32) -> i26
// CHECK-NEXT: %c0_i26_3 = hw.constant 0 : i26
// CHECK-NEXT: [[V22:%.+]] = comb.icmp eq [[V21]], %c0_i26_3 : i26
// CHECK-NEXT: [[V23:%.+]] = comb.extract %c2_i32 from 0 : (i32) -> i6
// CHECK-NEXT: %c-1_i6_4 = hw.constant -1 : i6
// CHECK-NEXT: [[V24:%.+]] = comb.mux [[V22]], [[V23]], %c-1_i6_4 : i6
// CHECK-NEXT: [[V25:%.+]] = comb.shru %arg2, [[V24]] : i6
// CHECK-NEXT: comb.extract [[V25]] from 0 : (i6) -> i1
moore.extract %arg2 from %2 : !moore.packed<range<bit, 5:0>>, !moore.int -> !moore.bit

// CHECK-NEXT: comb.parity %arg0 : i1
// CHECK-NEXT: comb.parity %arg0 : i1
// CHECK-NEXT: comb.parity %arg1 : i1
moore.reduce_and %arg0 : !moore.bit -> !moore.bit
moore.reduce_or %arg0 : !moore.bit -> !moore.bit
moore.reduce_xor %arg1 : !moore.logic -> !moore.logic

// CHECK-NEXT: comb.parity %arg2 : i6
moore.bool_cast %arg2 : !moore.packed<range<bit, 5:0>> -> !moore.bit

// CHECK-NEXT: [[V26:%.+]] = hw.constant -1 : i6
// CHECK-NEXT: comb.xor %arg2, [[V26]] : i6
moore.not %arg2 : !moore.packed<range<bit, 5:0>>

// CHECK-NEXT: comb.add %arg1, %arg1 : i1
// CHECK-NEXT: comb.sub %arg1, %arg1 : i1
// CHECK-NEXT: comb.mul %arg1, %arg1 : i1
// CHECK-NEXT: comb.divu %arg0, %arg0 : i1
// CHECK-NEXT: comb.modu %arg0, %arg0 : i1
// CHECK-NEXT: comb.and %arg0, %arg0 : i1
// CHECK-NEXT: comb.or %arg0, %arg0 : i1
// CHECK-NEXT: comb.xor %arg0, %arg0 : i1
moore.add %arg1, %arg1 : !moore.logic
moore.sub %arg1, %arg1 : !moore.logic
moore.mul %arg1, %arg1 : !moore.logic
moore.div %arg0, %arg0 : !moore.bit
moore.mod %arg0, %arg0 : !moore.bit
moore.and %arg0, %arg0 : !moore.bit
moore.or %arg0, %arg0 : !moore.bit
moore.xor %arg0, %arg0 : !moore.bit

// CHECK-NEXT: comb.icmp ult %arg1, %arg1 : i1
// CHECK-NEXT: comb.icmp ule %arg0, %arg0 : i1
// CHECK-NEXT: comb.icmp ugt %arg0, %arg0 : i1
// CHECK-NEXT: comb.icmp uge %arg0, %arg0 : i1
moore.lt %arg1, %arg1 : !moore.logic -> !moore.logic
moore.le %arg0, %arg0 : !moore.bit -> !moore.bit
moore.gt %arg0, %arg0 : !moore.bit -> !moore.bit
moore.ge %arg0, %arg0 : !moore.bit -> !moore.bit

// CHECK-NEXT: comb.icmp slt %arg4, %arg4 : i1
// CHECK-NEXT: comb.icmp sle %arg4, %arg4 : i1
// CHECK-NEXT: comb.icmp sgt %arg4, %arg4 : i1
// CHECK-NEXT: comb.icmp sge %arg4, %arg4 : i1
moore.lt %arg4, %arg4 : !moore.bit<signed> -> !moore.bit
moore.le %arg4, %arg4 : !moore.bit<signed> -> !moore.bit
moore.gt %arg4, %arg4 : !moore.bit<signed> -> !moore.bit
moore.ge %arg4, %arg4 : !moore.bit<signed> -> !moore.bit

// CHECK-NEXT: comb.icmp eq %arg1, %arg1 : i1
// CHECK-NEXT: comb.icmp ne %arg0, %arg0 : i1
// CHECK-NEXT: comb.icmp ceq %arg0, %arg0 : i1
// CHECK-NEXT: comb.icmp cne %arg0, %arg0 : i1
// CHECK-NEXT: comb.icmp weq %arg0, %arg0 : i1
// CHECK-NEXT: comb.icmp wne %arg0, %arg0 : i1
moore.eq %arg1, %arg1 : !moore.logic -> !moore.logic
moore.ne %arg0, %arg0 : !moore.bit -> !moore.bit
moore.case_eq %arg0, %arg0 : !moore.bit
moore.case_ne %arg0, %arg0 : !moore.bit
moore.wildcard_eq %arg0, %arg0 : !moore.bit -> !moore.bit
moore.wildcard_ne %arg0, %arg0 : !moore.bit -> !moore.bit

// CHECK-NEXT: return
return
}

0 comments on commit 1bccc51

Please sign in to comment.