From a8736e717e89be450aa1070ac0542e2b56b187bb Mon Sep 17 00:00:00 2001 From: Marius Hillenbrand Date: Wed, 26 Apr 2023 05:06:45 -0400 Subject: [PATCH] Revise folding for CBitExtractBitOp Simplify and add support for the base case of single-bit registers. While at it, add a verifier for CBitExtractBitOp. --- include/Dialect/OQ3/IR/OQ3CBitOps.td | 3 +++ lib/Dialect/OQ3/IR/OQ3Ops.cpp | 27 ++++++++++++++----- test/Dialect/OQ3/IR/fold-cbit-extractbit.mlir | 19 +++++++++++++ 3 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 test/Dialect/OQ3/IR/fold-cbit-extractbit.mlir diff --git a/include/Dialect/OQ3/IR/OQ3CBitOps.td b/include/Dialect/OQ3/IR/OQ3CBitOps.td index e6a91efae..0a80c2a64 100644 --- a/include/Dialect/OQ3/IR/OQ3CBitOps.td +++ b/include/Dialect/OQ3/IR/OQ3CBitOps.td @@ -204,6 +204,9 @@ def OQ3_CBitExtractBitOp : OQ3_Op<"cbit_extractbit", [NoSideEffect]> { }]; let hasFolder = 1; + + // TODO in LLVM 15 + this can become just let hasVerifier = 1; + let verifier = [{ return ::verify(*this);}]; } // ----- diff --git a/lib/Dialect/OQ3/IR/OQ3Ops.cpp b/lib/Dialect/OQ3/IR/OQ3Ops.cpp index 61c41405f..3a5765dca 100644 --- a/lib/Dialect/OQ3/IR/OQ3Ops.cpp +++ b/lib/Dialect/OQ3/IR/OQ3Ops.cpp @@ -68,6 +68,10 @@ verifyOQ3VariableOpSymbolUses(SymbolTableCollection &symbolTable, static llvm::Optional findDefiningBitInBitmap(mlir::Value val, mlir::IntegerAttr bitIndex) { + // for single-bit registers, CBitExtractBitOp is the identity. + if (val.getType().isInteger(1)) + return val; + mlir::Operation *op = val.getDefiningOp(); // follow chains of CBit_InsertBit operations and try to find one matching the @@ -79,12 +83,9 @@ findDefiningBitInBitmap(mlir::Value val, mlir::IntegerAttr bitIndex) { op = insertBitOp.operand().getDefiningOp(); } - // is the value defined by an i1 constant? then that would be the bit - if (auto constantOp = - mlir::dyn_cast_or_null(op)) { - if (constantOp.getType().isInteger(1)) - return constantOp.getResult(); - } + // did we identify an op that provides the single bit? + if (op && op->getResult(0).getType().isInteger(1)) + return op->getResult(0); return llvm::None; } @@ -99,6 +100,20 @@ CBitExtractBitOp::fold(::llvm::ArrayRef<::mlir::Attribute> operands) { return nullptr; } +static LogicalResult verify(CBitExtractBitOp op) { + + auto t = op.getOperand().getType(); + + if (auto cbitType = t.dyn_cast(); + cbitType && op.index().ult(cbitType.getWidth())) + return success(); + + if (t.isIntOrIndex() && op.index().ult(t.getIntOrFloatBitWidth())) + return success(); + + return op.emitOpError("index must be less than the width of the operand."); +} + LogicalResult CBitAssignBitOp::verifySymbolUses(SymbolTableCollection &symbolTable) { diff --git a/test/Dialect/OQ3/IR/fold-cbit-extractbit.mlir b/test/Dialect/OQ3/IR/fold-cbit-extractbit.mlir new file mode 100644 index 000000000..eac1115dc --- /dev/null +++ b/test/Dialect/OQ3/IR/fold-cbit-extractbit.mlir @@ -0,0 +1,19 @@ +// RUN: qss-opt %s --canonicalize | qss-opt | FileCheck %s --implicit-check-not cbit_extractbit +// Verify that all oq3.cbit_extractbit operations are eliminated + +// CHECK: func @single_bit(%[[ARG0:.*]]: i1) -> i1 { +func @single_bit(%bit: i1) -> i1 { + %2 = oq3.cbit_extractbit(%bit : i1) [0] : i1 + // CHECK: return %[[ARG0]] : i1 + return %2 : i1 +} + +// CHECK: func @two_bits(%[[ARG0:.*]]: !quir.cbit<2>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: i1) +func @two_bits(%cbit: !quir.cbit<2>, %bit1: i1, %bit2: i1) -> i1 { + %0 = oq3.cbit_insertbit(%cbit : !quir.cbit<2>)[0] = %bit1 : !quir.cbit<2> + %1 = oq3.cbit_insertbit(%cbit : !quir.cbit<2>)[1] = %bit2 : !quir.cbit<2> + + %2 = oq3.cbit_extractbit(%1 : !quir.cbit<2>) [1] : i1 + // CHECK: return %[[ARG2]] : i1 + return %2 : i1 +}