Skip to content

Commit

Permalink
Revise folding for CBitExtractBitOp (#99)
Browse files Browse the repository at this point in the history
Simplify and add support for the base case of single-bit registers.
While at it, add a verifier for CBitExtractBitOp.
  • Loading branch information
mhillenbrand authored May 2, 2023
1 parent f932310 commit 90f688b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
3 changes: 3 additions & 0 deletions include/Dialect/OQ3/IR/OQ3CBitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def OQ3_CBitExtractBitOp : OQ3_Op<"cbit_extractbit", [NoSideEffect]> {
}];

let hasFolder = 1;

// TODO in LLVM 15 + this can become just let hasVerifier = 1;
let verifier = [{ return ::verify(*this);}];
}

// -----
Expand Down
27 changes: 21 additions & 6 deletions lib/Dialect/OQ3/IR/OQ3Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ verifyOQ3VariableOpSymbolUses(SymbolTableCollection &symbolTable,
static llvm::Optional<mlir::Value>
findDefiningBitInBitmap(mlir::Value val, mlir::IntegerAttr bitIndex) {

// for single-bit registers, CBitExtractBitOp is the identity.
if (val.getType().isInteger(1))
return val;

mlir::Operation *op = val.getDefiningOp();

// follow chains of CBit_InsertBit operations and try to find one matching the
Expand All @@ -79,12 +83,9 @@ findDefiningBitInBitmap(mlir::Value val, mlir::IntegerAttr bitIndex) {
op = insertBitOp.operand().getDefiningOp();
}

// is the value defined by an i1 constant? then that would be the bit
if (auto constantOp =
mlir::dyn_cast_or_null<mlir::arith::ConstantIntOp>(op)) {
if (constantOp.getType().isInteger(1))
return constantOp.getResult();
}
// did we identify an op that provides the single bit?
if (op && op->getResult(0).getType().isInteger(1))
return op->getResult(0);

return llvm::None;
}
Expand All @@ -99,6 +100,20 @@ CBitExtractBitOp::fold(::llvm::ArrayRef<::mlir::Attribute> operands) {
return nullptr;
}

static LogicalResult verify(CBitExtractBitOp op) {

auto t = op.getOperand().getType();

if (auto cbitType = t.dyn_cast<mlir::quir::CBitType>();
cbitType && op.index().ult(cbitType.getWidth()))
return success();

if (t.isIntOrIndex() && op.index().ult(t.getIntOrFloatBitWidth()))
return success();

return op.emitOpError("index must be less than the width of the operand.");
}

LogicalResult
CBitAssignBitOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

Expand Down
19 changes: 19 additions & 0 deletions test/Dialect/OQ3/IR/fold-cbit-extractbit.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: qss-opt %s --canonicalize | qss-opt | FileCheck %s --implicit-check-not cbit_extractbit
// Verify that all oq3.cbit_extractbit operations are eliminated

// CHECK: func @single_bit(%[[ARG0:.*]]: i1) -> i1 {
func @single_bit(%bit: i1) -> i1 {
%2 = oq3.cbit_extractbit(%bit : i1) [0] : i1
// CHECK: return %[[ARG0]] : i1
return %2 : i1
}

// CHECK: func @two_bits(%[[ARG0:.*]]: !quir.cbit<2>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: i1)
func @two_bits(%cbit: !quir.cbit<2>, %bit1: i1, %bit2: i1) -> i1 {
%0 = oq3.cbit_insertbit(%cbit : !quir.cbit<2>)[0] = %bit1 : !quir.cbit<2>
%1 = oq3.cbit_insertbit(%cbit : !quir.cbit<2>)[1] = %bit2 : !quir.cbit<2>

%2 = oq3.cbit_extractbit(%1 : !quir.cbit<2>) [1] : i1
// CHECK: return %[[ARG2]] : i1
return %2 : i1
}

0 comments on commit 90f688b

Please sign in to comment.