Skip to content

Commit

Permalink
[FIRRTL] no back-prop for width of mux selectors, support narrower (l…
Browse files Browse the repository at this point in the history
…lvm#6917)

* Allow mux selectors to be zero-width or 1-bit (for mux4).

This is legal per FIRRTL spec.

* InferWidths: mux no back-prop.

Fixes llvm#5444

* canonicalizers for small mux selectors

Co-authored-by: Schuyler Eldridge <[email protected]>
  • Loading branch information
2 people authored and cepheus69 committed Apr 22, 2024
1 parent 79299a5 commit ae819a7
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 22 deletions.
35 changes: 35 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def IntTypeWidthGEQ32 : Constraint<CPred<
def IntTypeWidthGT32 : Constraint<CPred<
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel() > type_cast<IntType>($1.getType()).getBitWidthOrSentinel()">>;

// sizeof(0) < X
class IntTypeWidthLTX<int X> : Constraint<CPred<
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel() >= 0 &&"
"type_cast<IntType>($0.getType()).getBitWidthOrSentinel()< " # X
>>;

// Constraint that enforces int types
def IntTypes : Constraint<CPred<"type_isa<IntType>($0.getType())">>;

Expand Down Expand Up @@ -562,6 +568,35 @@ def MuxNEQ : Pat<
(MoveNameHint $old, (MuxPrimOp (EQPrimOp $a, $b), $y, $x)),
[(EqualTypes $x, $y), (KnownWidth $x)]>;

// mux(cond : u0, a, b) -> mux(0 : u1, a, b)
def MuxPadSel : Pat<
(MuxPrimOp:$old $cond, $a, $b),
(MoveNameHint $old, (MuxPrimOp
(ConstantOp
(NativeCodeCall<"$_builder.getUI32IntegerAttr(0)">),
(returnType "$_builder.getType<UIntType>(1)")),
$a, $b)),
[(IntTypeWidthLTX<1> $cond)]>;

// mux2(cond : u0, a, b) -> mux2(0 : u1, a, b)
def Mux2PadSel : Pat<
(Mux2CellIntrinsicOp:$old $cond, $a, $b),
(MoveNameHint $old, (Mux2CellIntrinsicOp
(ConstantOp
(NativeCodeCall<"$_builder.getUI32IntegerAttr(0)">),
(returnType "$_builder.getType<UIntType>(1)")),
$a, $b)),
[(IntTypeWidthLTX<1> $cond)]>;

// mux4(cond : u0/u1, a, b) -> mux4(pad(cond -> u2), a, b)
def Mux4PadSel : Pat<
(Mux4CellIntrinsicOp:$old $cond, $a, $b, $c, $d),
(MoveNameHint $old, (Mux4CellIntrinsicOp
(PadPrimOp $cond,
(NativeCodeCall<"$_builder.getI32IntegerAttr(2)">)),
$a, $b, $c, $d)),
[(IntTypeWidthLTX<2> $cond)]>;

def CatDoubleConst : Pat <
(CatPrimOp:$old $cst1, (CatPrimOp $cst2, $v)),
(MoveNameHint $old, (CatPrimOp (CatPrimOp $cst1, (AsUIntPrimOp $cst2)), (AsUIntPrimOp $v))),
Expand Down
10 changes: 7 additions & 3 deletions include/circt/Dialect/FIRRTL/FIRRTLExpressions.td
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def HeadPrimOp : PrimOp<"head"> {
}

def MuxPrimOp : PrimOp<"mux"> {
let arguments = (ins UInt1OrUnsizedType:$sel, PassiveType:$high,
let arguments = (ins UIntLTE1OrUnsizedType:$sel, PassiveType:$high,
PassiveType:$low);
let results = (outs PassiveType:$result);

Expand Down Expand Up @@ -842,10 +842,12 @@ def Mux2CellIntrinsicOp : PrimOp<"int.mux2cell"> {
the inference process in the same way as a normal mux operation.
}];

let arguments = (ins UInt1OrUnsizedType:$sel, PassiveType:$high,
let arguments = (ins UIntLTE1OrUnsizedType:$sel, PassiveType:$high,
PassiveType:$low);
let results = (outs PassiveType:$result);

let hasCanonicalizer = true;

let assemblyFormat =
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
}
Expand All @@ -861,11 +863,13 @@ def Mux4CellIntrinsicOp : PrimOp<"int.mux4cell"> {
the inference process as a sugar of mux operation chains.
}];

let arguments = (ins UInt2OrUnsizedType:$sel, PassiveType:$v3,
let arguments = (ins UIntLTE2OrUnsizedType:$sel, PassiveType:$v3,
PassiveType:$v2, PassiveType:$v1,
PassiveType:$v0);
let results = (outs PassiveType:$result);

let hasCanonicalizer = true;

let assemblyFormat =
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
}
Expand Down
10 changes: 8 additions & 2 deletions include/circt/Dialect/FIRRTL/FIRRTLTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ class SizedUIntType<int width> : FIRRTLDialectType<
"type_cast<UIntType>($_self).getWidth() == " # width>,
width # "-bit uint", "::circt::firrtl::UIntType">;

class SizedUIntTypeLTE<int width> : FIRRTLDialectType<
CPred<"type_isa<UIntType>($_self) && "
"type_cast<UIntType>($_self).getWidth() <= " # width>,
"uint with width less than or equal to " # width # " bits",
"::circt::firrtl::UIntType">;

class NonConstSizedUIntType<int width> :
SizedUIntType<width>,
BuildableType<
Expand All @@ -138,8 +144,8 @@ def UInt2Type : SizedUIntType<2>;
def UInt32Type : SizedUIntType<32>;
def NonConstUInt1Type : NonConstSizedUIntType<1>;

def UInt1OrUnsizedType : AnyTypeOf<[UInt1Type, UnsizedUIntType]>;
def UInt2OrUnsizedType : AnyTypeOf<[UInt2Type, UnsizedUIntType]>;
def UIntLTE1OrUnsizedType : AnyTypeOf<[SizedUIntTypeLTE<1>, UnsizedUIntType]>;
def UIntLTE2OrUnsizedType : AnyTypeOf<[SizedUIntTypeLTE<2>, UnsizedUIntType]>;

//===----------------------------------------------------------------------===//
// FIRRTL Types Predicates
Expand Down
20 changes: 16 additions & 4 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,10 +1476,22 @@ class MuxSharedCond : public mlir::RewritePattern {

void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS>(context);
results
.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
patterns::MuxSameTrue, patterns::MuxSameFalse,
patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
context);
}

void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<patterns::Mux2PadSel>(context);
}

void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<patterns::Mux4PadSel>(context);
}

OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/FIRRTL/Transforms/InferWidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1596,12 +1596,12 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
})
.Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
auto *sel = getExpr(op.getSel());
constrainTypes(sel, solver.known(1));
constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
})
.Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
auto *sel = getExpr(op.getSel());
constrainTypes(sel, solver.known(2));
constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
Expand Down
13 changes: 12 additions & 1 deletion test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,9 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
out %out1: !firrtl.uint<1>,
out %out2: !firrtl.uint<0>,
out %out3: !firrtl.uint<1>,
out %out4: !firrtl.uint<4>) {
out %out4: !firrtl.uint<4>,
out %out5: !firrtl.uint<1>,
out %out6: !firrtl.uint<1>) {
// CHECK: firrtl.strictconnect %out, %in
%0 = firrtl.int.mux2cell (%cond, %in, %in) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.connect %out, %0 : !firrtl.uint<4>, !firrtl.uint<4>
Expand Down Expand Up @@ -560,6 +562,15 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
// CHECK-NEXT: [[V2:%.+]] = firrtl.mux(%cond
// CHECK-NEXT: firrtl.strictconnect %out4, [[V2]]
firrtl.connect %out4, %15 : !firrtl.uint<4>, !firrtl.uint<4>

// CHECK-NEXT: firrtl.strictconnect %out5, %val2
%16 = firrtl.mux (%val0, %val1, %val2) : (!firrtl.uint<0>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out5, %16 : !firrtl.uint<1>

// CHECK-NEXT: %[[SEL:.+]] = firrtl.pad %val1, 2 : (!firrtl.uint<1>) -> !firrtl.uint<2>
// CHECK-NEXT: mux4cell(%[[SEL]],
%17 = firrtl.int.mux4cell (%val1, %val1, %val2, %val1, %val2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out6, %17 : !firrtl.uint<1>
}

// CHECK-LABEL: firrtl.module @Pad
Expand Down
26 changes: 26 additions & 0 deletions test/Dialect/FIRRTL/infer-widths-errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,29 @@ firrtl.circuit "NoWidthEnum" {
firrtl.module @NoWidthEnum(out %o: !firrtl.enum<Some: uint>) {
}
}

// -----

firrtl.circuit "MuxSelBackProp" {
firrtl.module @MuxSelBackProp() {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
// expected-error @below {{uninferred width: wire is unconstrained}}
%0 = firrtl.wire : !firrtl.uint
%1 = firrtl.mux(%0, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
}
}

// -----

firrtl.circuit "MuxSelTooWide" {
firrtl.module @MuxSelTooWide() {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
// expected-error @below {{uninferred width: wire cannot satisfy all width requirements}}
%0 = firrtl.wire : !firrtl.uint
// expected-note @below {{width is constrained to be at most 1 here:}}
%1 = firrtl.mux(%0, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// expected-note @below {{width is constrained to be at least 2 here:}}
firrtl.connect %0, %c2_ui2 : !firrtl.uint, !firrtl.uint<2>
}
}
20 changes: 10 additions & 10 deletions test/Dialect/FIRRTL/infer-widths.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -364,20 +364,18 @@ firrtl.circuit "Foo" {
firrtl.module @MuxOp() {
// CHECK: %0 = firrtl.wire : !firrtl.uint<2>
// CHECK: %1 = firrtl.wire : !firrtl.uint<3>
// CHECK: %2 = firrtl.wire : !firrtl.uint<1>
// CHECK: %2 = firrtl.wire : !firrtl.uint<0>
// CHECK: %3 = firrtl.mux{{.*}} -> !firrtl.uint<3>
%0 = firrtl.wire : !firrtl.uint
%1 = firrtl.wire : !firrtl.uint
%2 = firrtl.wire : !firrtl.uint
%3 = firrtl.mux(%2, %0, %1) : (!firrtl.uint, !firrtl.uint, !firrtl.uint) -> !firrtl.uint
// CHECK: %4 = firrtl.wire : !firrtl.uint<1>
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%4 = firrtl.wire : !firrtl.uint
%5 = firrtl.mux(%4, %c1_ui1, %c1_ui1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>
%c2_ui3 = firrtl.constant 2 : !firrtl.uint<3>
%c0_ui0 = firrtl.constant 0 : !firrtl.uint<0>
firrtl.connect %0, %c1_ui2 : !firrtl.uint, !firrtl.uint<2>
firrtl.connect %1, %c2_ui3 : !firrtl.uint, !firrtl.uint<3>
firrtl.connect %2, %c0_ui0 : !firrtl.uint, !firrtl.uint<0>
}

// see https://github.com/llvm/circt/issues/3070
Expand Down Expand Up @@ -957,22 +955,24 @@ firrtl.circuit "Foo" {
firrtl.module @Property(in %a: !firrtl.string) { }

// CHECK-LABEL: module @MuxIntrinsics
// CHECK-SAME: %sel: !firrtl.uint<1>
// CHECK-SAME: %sel2: !firrtl.uint<2>
firrtl.module @MuxIntrinsics(in %sel: !firrtl.uint, in %sel2: !firrtl.uint, in %high: !firrtl.uint<1>, in %low: !firrtl.uint<1>, out %out1: !firrtl.uint, out %out2: !firrtl.uint) {
firrtl.module @MuxIntrinsics(in %sel_0w: !firrtl.uint<0>, in %sel_1w: !firrtl.uint<1>, in %high: !firrtl.uint<1>, in %low: !firrtl.uint<1>, out %out1: !firrtl.uint, out %out2: !firrtl.uint) {
%c3_ui4 = firrtl.constant 3 : !firrtl.uint<4>
%c3_ui3 = firrtl.constant 3 : !firrtl.uint<3>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
%c1 = firrtl.constant 0: !firrtl.uint
%sel = firrtl.wire : !firrtl.uint
firrtl.connect %sel, %sel_0w : !firrtl.uint, !firrtl.uint<0>
// CHECK: firrtl.int.mux2cell
// CHECK-SAME: (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-SAME: (!firrtl.uint<0>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
%0 = firrtl.int.mux2cell(%sel, %c0_ui1, %c1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint) -> !firrtl.uint
firrtl.connect %out1, %0: !firrtl.uint, !firrtl.uint
%sel2 = firrtl.wire : !firrtl.uint
firrtl.connect %sel2, %sel_1w : !firrtl.uint, !firrtl.uint<1>
// CHECK: firrtl.int.mux4cell
// CHECK-SAME: (!firrtl.uint<2>, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint<1>) -> !firrtl.uint<3>
// CHECK-SAME: (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint<1>) -> !firrtl.uint<3>
%1 = firrtl.int.mux4cell(%sel2, %c1_ui1, %c2_ui2, %c3_ui3, %c1) : (!firrtl.uint, !firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<3>, !firrtl.uint) -> !firrtl.uint
firrtl.connect %out2, %1: !firrtl.uint, !firrtl.uint
}
Expand Down

0 comments on commit ae819a7

Please sign in to comment.