From 1bccc51d1131fb9b0024e7db3adb04f4e749896b Mon Sep 17 00:00:00 2001 From: hailong Date: Mon, 22 Apr 2024 15:56:21 +0800 Subject: [PATCH] [MooreToCore] Lower moore operators into comb or hw. Co-authored-by: Fabian Schuiki --- lib/Conversion/MooreToCore/MooreToCore.cpp | 209 ++++++++++++++++++++- test/Conversion/MooreToCore/basic.mlir | 100 +++++++++- 2 files changed, 298 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/MooreToCore/MooreToCore.cpp b/lib/Conversion/MooreToCore/MooreToCore.cpp index 6357897f322d..8a4860c05efe 100644 --- a/lib/Conversion/MooreToCore/MooreToCore.cpp +++ b/lib/Conversion/MooreToCore/MooreToCore.cpp @@ -56,6 +56,51 @@ static Value adjustIntegerWidth(OpBuilder &builder, Value value, return builder.create(loc, isZero, lo, max, false); } +/// Due to the result type of the `lt`, or `le`, or `gt`, or `ge` ops are +/// always unsigned, estimating their operands type. +static bool isSignedType(Operation *op) { + return TypeSwitch(op) + .template Case([&](auto op) -> bool { + return cast(op->getOperand(0).getType()) + .castToSimpleBitVector() + .isSigned() && + cast(op->getOperand(1).getType()) + .castToSimpleBitVector() + .isSigned(); + }) + .Default([&](auto op) -> bool { + return cast(op->getResult(0).getType()) + .castToSimpleBitVector() + .isSigned(); + }); +} + +/// Not define the predicate for `relation` and `equality` operations in the +/// MooreDialect, but comb needs it. Return a correct `comb::ICmpPredicate` +/// corresponding to different moore `relation` and `equality` operations. +static comb::ICmpPredicate getCombPredicate(Operation *op) { + using comb::ICmpPredicate; + return TypeSwitch(op) + .Case([&](auto op) { + return isSignedType(op) ? ICmpPredicate::slt : ICmpPredicate::ult; + }) + .Case([&](auto op) { + return isSignedType(op) ? ICmpPredicate::sle : ICmpPredicate::ule; + }) + .Case([&](auto op) { + return isSignedType(op) ? ICmpPredicate::sgt : ICmpPredicate::ugt; + }) + .Case([&](auto op) { + return isSignedType(op) ? ICmpPredicate::sge : ICmpPredicate::uge; + }) + .Case([&](auto op) { return ICmpPredicate::eq; }) + .Case([&](auto op) { return ICmpPredicate::ne; }) + .Case([&](auto op) { return ICmpPredicate::ceq; }) + .Case([&](auto op) { return ICmpPredicate::cne; }) + .Case([&](auto op) { return ICmpPredicate::weq; }) + .Case([&](auto op) { return ICmpPredicate::wne; }); +} + //===----------------------------------------------------------------------===// // Expression Conversion //===----------------------------------------------------------------------===// @@ -82,6 +127,124 @@ struct ConcatOpConversion : public OpConversionPattern { } }; +struct ReplicateOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ReplicateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = typeConverter->convertType(op.getResult().getType()); + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); + return success(); + } +}; + +struct ExtractOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = typeConverter->convertType(op.getResult().getType()); + auto width = typeConverter->convertType(op.getInput().getType()) + .getIntOrFloatBitWidth(); + Value amount = + adjustIntegerWidth(rewriter, adaptor.getLowBit(), width, op->getLoc()); + Value value = + rewriter.create(op->getLoc(), adaptor.getInput(), amount); + + rewriter.replaceOpWithNewOp(op, resultType, value, 0); + return success(); + } +}; + +template +struct UnaryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + + LogicalResult + matchAndRewrite(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getInput()); + return success(); + } +}; + +struct NotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(NotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + ConversionPattern::typeConverter->convertType(op.getResult().getType()); + Value max = rewriter.create(op.getLoc(), resultType, -1); + + rewriter.replaceOpWithNewOp(op, adaptor.getInput(), max); + return success(); + } +}; + +template +struct BinaryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (isa(op) && isSignedType(op)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs(), false); + return success(); + } + if (isa(op) && isSignedType(op)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs(), false); + return success(); + } + + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs(), false); + return success(); + } +}; + +template +struct ICmpOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SourceOp::Adaptor; + + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + ConversionPattern::typeConverter->convertType(op.getResult().getType()); + comb::ICmpPredicate pred = getCombPredicate(op); + + rewriter.replaceOpWithNewOp( + op, resultType, pred, adapter.getLhs(), adapter.getRhs()); + return success(); + } +}; + +struct ConversionOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ConversionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = typeConverter->convertType(op.getResult().getType()); + Value amount = + adjustIntegerWidth(rewriter, adaptor.getInput(), + resultType.getIntOrFloatBitWidth(), op->getLoc()); + + rewriter.replaceOpWithNewOp(op, resultType, amount); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Statement Conversion //===----------------------------------------------------------------------===// @@ -271,6 +434,9 @@ static void populateTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion([&](UnpackedType type) -> std::optional { if (auto sbv = type.getSimpleBitVectorOrNull()) return mlir::IntegerType::get(type.getContext(), sbv.size); + if (isa(type)) + return mlir::IntegerType::get(type.getContext(), + type.getBitSize().value()); return std::nullopt; }); @@ -283,16 +449,39 @@ static void populateOpConversion(RewritePatternSet &patterns, auto *context = patterns.getContext(); // clang-format off patterns.add< - ConstantOpConv, - ConcatOpConversion, - ReturnOpConversion, - CondBranchOpConversion, - BranchOpConversion, - CallOpConversion, - ShlOpConversion, - ShrOpConversion, - AShrOpConversion, - UnrealizedConversionCastConversion + // Patterns of miscellaneous operations. + ConstantOpConv, ConcatOpConversion, ReplicateOpConversion, + ExtractOpConversion, ConversionOpConversion, + + // Patterns of unary operations. + UnaryOpConversion, UnaryOpConversion, + UnaryOpConversion, UnaryOpConversion, + NotOpConversion, + + // Patterns of binary operations. + BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, + BinaryOpConversion, + + // Patterns of relational operations. + ICmpOpConversion, ICmpOpConversion, ICmpOpConversion, + ICmpOpConversion, ICmpOpConversion, ICmpOpConversion, + ICmpOpConversion, ICmpOpConversion, + ICmpOpConversion, ICmpOpConversion, + + // Patterns of shifting operations. + ShrOpConversion, ShlOpConversion, AShrOpConversion, + + // Patterns of branch operations. + CondBranchOpConversion, BranchOpConversion, + + // Patterns of other operations outside Moore dialect. + ReturnOpConversion, CallOpConversion, UnrealizedConversionCastConversion >(typeConverter, context); // clang-format on mlir::populateFunctionOpInterfaceTypeConversionPattern( diff --git a/test/Conversion/MooreToCore/basic.mlir b/test/Conversion/MooreToCore/basic.mlir index f08e36f3200f..babb2936830d 100644 --- a/test/Conversion/MooreToCore/basic.mlir +++ b/test/Conversion/MooreToCore/basic.mlir @@ -46,12 +46,25 @@ func.func @UnrealizedConversionCast(%arg0: !moore.byte) -> !moore.shortint { } // CHECK-LABEL: func @Expressions -func.func @Expressions(%arg0: !moore.bit, %arg1: !moore.logic, %arg2: !moore.packed>, %arg3: !moore.packed, 4:0>>) { +func.func @Expressions(%arg0: !moore.bit, %arg1: !moore.logic, %arg2: !moore.packed>, %arg3: !moore.packed, 4:0>>, %arg4: !moore.bit) { // CHECK-NEXT: %0 = comb.concat %arg0, %arg0 : i1, i1 // CHECK-NEXT: %1 = comb.concat %arg1, %arg1 : i1, i1 moore.concat %arg0, %arg0 : (!moore.bit, !moore.bit) -> !moore.packed> moore.concat %arg1, %arg1 : (!moore.logic, !moore.logic) -> !moore.packed> + // CHECK-NEXT: comb.replicate %arg0 : (i1) -> i2 + // CHECK-NEXT: comb.replicate %arg1 : (i1) -> i2 + moore.replicate %arg0 : (!moore.bit) -> !moore.packed> + moore.replicate %arg1 : (!moore.logic) -> !moore.packed> + + // CHECK-NEXT: %c12_i32 = hw.constant 12 : i32 + // CHECK-NEXT: %c3_i6 = hw.constant 3 : i6 + moore.constant 12 : !moore.int + moore.constant 3 : !moore.packed> + + // CHECK-NEXT: hw.bitcast %arg0 : (i1) -> i1 + moore.conversion %arg0 : !moore.bit -> !moore.logic + // CHECK-NEXT: [[V0:%.+]] = hw.constant 0 : i5 // CHECK-NEXT: [[V1:%.+]] = comb.concat [[V0]], %arg0 : i5, i1 // CHECK-NEXT: comb.shl %arg2, [[V1]] : i6 @@ -83,6 +96,91 @@ func.func @Expressions(%arg0: !moore.bit, %arg1: !moore.logic, %arg2: !moore.pac // CHECK-NEXT: comb.shrs %arg3, [[V15]] : i5 moore.ashr %arg3, %arg2 : !moore.packed, 4:0>>, !moore.packed> + // CHECK-NEXT: %c2_i32 = hw.constant 2 : i32 + %2 = moore.constant 2 : !moore.int + + // CHECK-NEXT: [[V16:%.+]] = comb.extract %c2_i32 from 6 : (i32) -> i26 + // CHECK-NEXT: %c0_i26 = hw.constant 0 : i26 + // CHECK-NEXT: [[V17:%.+]] = comb.icmp eq [[V16]], %c0_i26 : i26 + // CHECK-NEXT: [[V18:%.+]] = comb.extract %c2_i32 from 0 : (i32) -> i6 + // CHECK-NEXT: %c-1_i6 = hw.constant -1 : i6 + // CHECK-NEXT: [[V19:%.+]] = comb.mux [[V17]], [[V18]], %c-1_i6 : i6 + // CHECK-NEXT: [[V20:%.+]] = comb.shru %arg2, [[V19]] : i6 + // CHECK-NEXT: comb.extract [[V20]] from 0 : (i6) -> i2 + moore.extract %arg2 from %2 : !moore.packed>, !moore.int -> !moore.packed> + + // CHECK-NEXT: [[V21:%.+]] = comb.extract %c2_i32 from 6 : (i32) -> i26 + // CHECK-NEXT: %c0_i26_3 = hw.constant 0 : i26 + // CHECK-NEXT: [[V22:%.+]] = comb.icmp eq [[V21]], %c0_i26_3 : i26 + // CHECK-NEXT: [[V23:%.+]] = comb.extract %c2_i32 from 0 : (i32) -> i6 + // CHECK-NEXT: %c-1_i6_4 = hw.constant -1 : i6 + // CHECK-NEXT: [[V24:%.+]] = comb.mux [[V22]], [[V23]], %c-1_i6_4 : i6 + // CHECK-NEXT: [[V25:%.+]] = comb.shru %arg2, [[V24]] : i6 + // CHECK-NEXT: comb.extract [[V25]] from 0 : (i6) -> i1 + moore.extract %arg2 from %2 : !moore.packed>, !moore.int -> !moore.bit + + // CHECK-NEXT: comb.parity %arg0 : i1 + // CHECK-NEXT: comb.parity %arg0 : i1 + // CHECK-NEXT: comb.parity %arg1 : i1 + moore.reduce_and %arg0 : !moore.bit -> !moore.bit + moore.reduce_or %arg0 : !moore.bit -> !moore.bit + moore.reduce_xor %arg1 : !moore.logic -> !moore.logic + + // CHECK-NEXT: comb.parity %arg2 : i6 + moore.bool_cast %arg2 : !moore.packed> -> !moore.bit + + // CHECK-NEXT: [[V26:%.+]] = hw.constant -1 : i6 + // CHECK-NEXT: comb.xor %arg2, [[V26]] : i6 + moore.not %arg2 : !moore.packed> + + // CHECK-NEXT: comb.add %arg1, %arg1 : i1 + // CHECK-NEXT: comb.sub %arg1, %arg1 : i1 + // CHECK-NEXT: comb.mul %arg1, %arg1 : i1 + // CHECK-NEXT: comb.divu %arg0, %arg0 : i1 + // CHECK-NEXT: comb.modu %arg0, %arg0 : i1 + // CHECK-NEXT: comb.and %arg0, %arg0 : i1 + // CHECK-NEXT: comb.or %arg0, %arg0 : i1 + // CHECK-NEXT: comb.xor %arg0, %arg0 : i1 + moore.add %arg1, %arg1 : !moore.logic + moore.sub %arg1, %arg1 : !moore.logic + moore.mul %arg1, %arg1 : !moore.logic + moore.div %arg0, %arg0 : !moore.bit + moore.mod %arg0, %arg0 : !moore.bit + moore.and %arg0, %arg0 : !moore.bit + moore.or %arg0, %arg0 : !moore.bit + moore.xor %arg0, %arg0 : !moore.bit + + // CHECK-NEXT: comb.icmp ult %arg1, %arg1 : i1 + // CHECK-NEXT: comb.icmp ule %arg0, %arg0 : i1 + // CHECK-NEXT: comb.icmp ugt %arg0, %arg0 : i1 + // CHECK-NEXT: comb.icmp uge %arg0, %arg0 : i1 + moore.lt %arg1, %arg1 : !moore.logic -> !moore.logic + moore.le %arg0, %arg0 : !moore.bit -> !moore.bit + moore.gt %arg0, %arg0 : !moore.bit -> !moore.bit + moore.ge %arg0, %arg0 : !moore.bit -> !moore.bit + + // CHECK-NEXT: comb.icmp slt %arg4, %arg4 : i1 + // CHECK-NEXT: comb.icmp sle %arg4, %arg4 : i1 + // CHECK-NEXT: comb.icmp sgt %arg4, %arg4 : i1 + // CHECK-NEXT: comb.icmp sge %arg4, %arg4 : i1 + moore.lt %arg4, %arg4 : !moore.bit -> !moore.bit + moore.le %arg4, %arg4 : !moore.bit -> !moore.bit + moore.gt %arg4, %arg4 : !moore.bit -> !moore.bit + moore.ge %arg4, %arg4 : !moore.bit -> !moore.bit + + // CHECK-NEXT: comb.icmp eq %arg1, %arg1 : i1 + // CHECK-NEXT: comb.icmp ne %arg0, %arg0 : i1 + // CHECK-NEXT: comb.icmp ceq %arg0, %arg0 : i1 + // CHECK-NEXT: comb.icmp cne %arg0, %arg0 : i1 + // CHECK-NEXT: comb.icmp weq %arg0, %arg0 : i1 + // CHECK-NEXT: comb.icmp wne %arg0, %arg0 : i1 + moore.eq %arg1, %arg1 : !moore.logic -> !moore.logic + moore.ne %arg0, %arg0 : !moore.bit -> !moore.bit + moore.case_eq %arg0, %arg0 : !moore.bit + moore.case_ne %arg0, %arg0 : !moore.bit + moore.wildcard_eq %arg0, %arg0 : !moore.bit -> !moore.bit + moore.wildcard_ne %arg0, %arg0 : !moore.bit -> !moore.bit + // CHECK-NEXT: return return }