diff --git a/conan/qasm/conandata.yml b/conan/qasm/conandata.yml index a13e4e302..128270ff4 100644 --- a/conan/qasm/conandata.yml +++ b/conan/qasm/conandata.yml @@ -1,5 +1,5 @@ sources: - hash: "f6d695fd9f18462e65f6290d05ccb4ccb371b288" + hash: "ec7731bf645240a597cd9ebb2c395b114f155ed2" requirements: - "gmp/6.3.0" - "mpfr/4.1.0" diff --git a/conan/qasm/conanfile.py b/conan/qasm/conanfile.py index 3df4f1cc7..95532601a 100644 --- a/conan/qasm/conanfile.py +++ b/conan/qasm/conanfile.py @@ -17,7 +17,7 @@ class QasmConan(ConanFile): name = "qasm" - version = "0.3.2" + version = "0.3.3" url = "https://github.com/openqasm/qe-qasm.git" settings = "os", "compiler", "build_type", "arch" options = {"shared": [True, False], "examples": [True, False]} diff --git a/conandata.yml b/conandata.yml index 24d31df0a..4b85ea0ce 100644 --- a/conandata.yml +++ b/conandata.yml @@ -7,4 +7,4 @@ requirements: - pybind11/2.11.1 - clang-tools-extra/17.0.5-0@ - llvm/17.0.5-0@ - - qasm/0.3.2@qss/stable + - qasm/0.3.3@qss/stable diff --git a/include/Conversion/QUIRToPulse/QUIRToPulse.h b/include/Conversion/QUIRToPulse/QUIRToPulse.h index 8eb26ffb8..f80a66d1f 100644 --- a/include/Conversion/QUIRToPulse/QUIRToPulse.h +++ b/include/Conversion/QUIRToPulse/QUIRToPulse.h @@ -127,7 +127,7 @@ struct QUIRToPulsePass mlir::func::FuncOp &mainFunc); // map of the hashed location of quir angle/duration ops to their converted // pulse ops - std::unordered_map + std::unordered_map classicalQUIROpLocToConvertedPulseOpMap; // port name to Port_CreateOp map diff --git a/include/Dialect/QUIR/Transforms/ExtractCircuits.h b/include/Dialect/QUIR/Transforms/ExtractCircuits.h index d22feace0..935482a16 100644 --- a/include/Dialect/QUIR/Transforms/ExtractCircuits.h +++ b/include/Dialect/QUIR/Transforms/ExtractCircuits.h @@ -25,11 +25,11 @@ #include "Utils/SymbolCacheAnalysis.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/SmallVector.h" - +#include #include namespace mlir::quir { @@ -49,14 +49,14 @@ struct ExtractCircuitsPass OpBuilder circuitBuilder); OpBuilder startCircuit(mlir::Location location, OpBuilder topLevelBuilder); void endCircuit(mlir::Operation *firstOp, mlir::Operation *lastOp, - OpBuilder topLevelBuilder, OpBuilder circuitBuilder, - llvm::SmallVector &eraseList); - void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder, - llvm::SmallVector &eraseList); + OpBuilder topLevelBuilder, OpBuilder circuitBuilder); + void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder); + uint64_t circuitCount = 0; qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr}; mlir::quir::CircuitOp currentCircuitOp = nullptr; + mlir::IRMapping currentCircuitMapper; mlir::quir::CallCircuitOp newCallCircuitOp; llvm::SmallVector inputTypes; @@ -68,6 +68,8 @@ struct ExtractCircuitsPass std::unordered_map circuitOperands; llvm::SmallVector originalResults; + std::set eraseConstSet; + std::set eraseOpSet; }; // struct ExtractCircuitsPass } // namespace mlir::quir diff --git a/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h b/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h index 61222f601..0551782ef 100644 --- a/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h +++ b/include/Frontend/OpenQASM3/QUIRGenQASM3Visitor.h @@ -321,6 +321,9 @@ class QUIRGenQASM3Visitor : public BaseQASM3Visitor { mlir::Type getQUIRTypeFromDeclaration(const QASM::ASTDeclarationNode *); bool enableParametersWarningEmitted = false; + + /// Cached dummy value for error handling + mlir::Value voidValue; }; } // namespace qssc::frontend::openqasm3 diff --git a/include/Frontend/OpenQASM3/QUIRVariableBuilder.h b/include/Frontend/OpenQASM3/QUIRVariableBuilder.h index 49397e74c..f5085db00 100644 --- a/include/Frontend/OpenQASM3/QUIRVariableBuilder.h +++ b/include/Frontend/OpenQASM3/QUIRVariableBuilder.h @@ -68,7 +68,7 @@ class QUIRVariableBuilder { mlir::Value generateParameterLoad(mlir::Location location, llvm::StringRef variableName, - mlir::Value assignedValue); + double initialValue); /// Generate code for declaring an array (at the builder's current insertion /// point). diff --git a/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp b/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp index e6d7d33b4..c57889a53 100644 --- a/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp +++ b/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp @@ -152,7 +152,7 @@ void LoadPulseCalsPass::loadPulseCals(CallCircuitOp callCircuitOp, LLVM_DEBUG(llvm::dbgs() << "no pulse cal loading needed for " << op); assert((!op->hasTrait() and !op->hasTrait()) && - "unkown operation"); + "unknown operation"); } }); } diff --git a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp index 8b65fc517..3330bdc6e 100644 --- a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp +++ b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp @@ -100,8 +100,10 @@ void QUIRToPulsePass::runOnOperation() { moduleOp->walk([&](CallCircuitOp callCircOp) { if (isa(callCircOp->getParentOp())) return; + auto convertedPulseCallSequenceOp = convertCircuitToSequence(callCircOp, mainFunc, moduleOp); + if (!callCircOp->use_empty()) callCircOp->replaceAllUsesWith(convertedPulseCallSequenceOp); callCircOp->erase(); @@ -229,8 +231,9 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, auto *newDelayCyclesOp = builder.clone(*quirOp); newDelayCyclesOp->moveAfter(callCircuitOp); } else - assert(((isa(quirOp) or isa(quirOp) or - isa(quirOp))) && + assert(((isa(quirOp) || + isa(quirOp) || + isa(quirOp) || isa(quirOp))) && "quir op is not allowed in this pass."); }); @@ -251,6 +254,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, convertedPulseSequenceOp, convertedPulseSequenceOpArgs); convertedPulseCallSequenceOp->moveAfter(callCircuitOp); + return convertedPulseCallSequenceOp; } @@ -286,7 +290,7 @@ void QUIRToPulsePass::processCircuitArgs( } else if (argumentType.isa()) { auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp(); } else - llvm_unreachable("unkown circuit argument."); + llvm_unreachable("unknown circuit argument."); } } @@ -339,7 +343,7 @@ void QUIRToPulsePass::processPulseCalArgs( } else if (argumentType.isa()) { assert(argAttr[index].dyn_cast().getValue().str() == "angle" && - "unkown argument."); + "unknown argument."); assert(angleOperands.size() && "no angle operand found."); auto nextAngle = angleOperands.front(); LLVM_DEBUG(llvm::dbgs() << "angle argument "); @@ -350,7 +354,7 @@ void QUIRToPulsePass::processPulseCalArgs( } else if (argumentType.isa()) { assert(argAttr[index].dyn_cast().getValue().str() == "duration" && - "unkown argument."); + "unknown argument."); assert(durationOperands.size() && "no duration operand found."); auto nextDuration = durationOperands.front(); LLVM_DEBUG(llvm::dbgs() << "duration argument "); @@ -359,7 +363,7 @@ void QUIRToPulsePass::processPulseCalArgs( pulseCalSequenceArgs, builder); durationOperands.pop(); } else - llvm_unreachable("unkown argument type."); + llvm_unreachable("unknown argument type."); } } @@ -379,12 +383,13 @@ void QUIRToPulsePass::getQUIROpClassicalOperands( } for (auto operand : classicalOperands) - if (operand.getType().isa()) + if (operand.getType().isa() || + operand.getType().isa()) angleOperands.push(operand); else if (operand.getType().isa()) durationOperands.push(operand); else - llvm_unreachable("unkown operand."); + llvm_unreachable("unknown operand."); } void QUIRToPulsePass::processMixFrameOpArg( @@ -463,21 +468,38 @@ void QUIRToPulsePass::processAngleArg(Value nextAngleOperand, pulseCalSequenceArgs.push_back( convertedPulseSequenceOp .getArguments()[circuitArgToConvertedSequenceArgMap[circNum]]); - } else { - auto angleOp = nextAngleOperand.getDefiningOp(); - std::string const angleLocHash = - std::to_string(mlir::hash_value(angleOp->getLoc())); - if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) == + } else if (auto angleOp = + nextAngleOperand.getDefiningOp()) { + auto *op = angleOp.getOperation(); + if (classicalQUIROpLocToConvertedPulseOpMap.find(op) == classicalQUIROpLocToConvertedPulseOpMap.end()) { double const angleVal = angleOp.getAngleValueFromConstant().convertToDouble(); auto f64Angle = entryBuilder.create( angleOp.getLoc(), entryBuilder.getFloatAttr(entryBuilder.getF64Type(), llvm::APFloat(angleVal))); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle; + classicalQUIROpLocToConvertedPulseOpMap[op] = f64Angle; } - pulseCalSequenceArgs.push_back( - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]); + pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]); + } else if (auto paramOp = + nextAngleOperand.getDefiningOp()) { + auto *op = paramOp.getOperation(); + if (classicalQUIROpLocToConvertedPulseOpMap.find(op) == + classicalQUIROpLocToConvertedPulseOpMap.end()) { + + auto newParam = entryBuilder.create( + paramOp->getLoc(), entryBuilder.getF64Type(), + paramOp.getParameterName()); + if (paramOp->hasAttr("initialValue")) { + auto initAttr = paramOp->getAttr("initialValue").dyn_cast(); + if (initAttr) + newParam->setAttr("initialValue", initAttr); + } + + classicalQUIROpLocToConvertedPulseOpMap[op] = newParam; + } + + pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]); } } @@ -501,25 +523,23 @@ void QUIRToPulsePass::processDurationArg( TimeUnits::dt && "this pass only accepts durations with dt unit"); - if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) == + auto *op = durationOp.getOperation(); + if (classicalQUIROpLocToConvertedPulseOpMap.find(op) == classicalQUIROpLocToConvertedPulseOpMap.end()) { auto dur64 = entryBuilder.create( durationOp.getLoc(), entryBuilder.getIntegerAttr(entryBuilder.getI64Type(), uint64_t(durVal))); - classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = dur64; + classicalQUIROpLocToConvertedPulseOpMap[op] = dur64; } - pulseCalSequenceArgs.push_back( - classicalQUIROpLocToConvertedPulseOpMap[durLocHash]); + pulseCalSequenceArgs.push_back(classicalQUIROpLocToConvertedPulseOpMap[op]); } } mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, mlir::OpBuilder &builder) { assert(angleOp && "angle op is null"); - std::string const angleLocHash = - std::to_string(mlir::hash_value(angleOp->getLoc())); - if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) == + if (classicalQUIROpLocToConvertedPulseOpMap.find(angleOp) == classicalQUIROpLocToConvertedPulseOpMap.end()) { if (auto castOp = dyn_cast(angleOp)) { double const angleVal = @@ -528,12 +548,19 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, castOp->getLoc(), builder.getFloatAttr(builder.getF64Type(), llvm::APFloat(angleVal))); f64Angle->moveAfter(castOp); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = f64Angle; + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = f64Angle; } else if (auto castOp = dyn_cast(angleOp)) { - auto angleCastedOp = builder.create( - castOp->getLoc(), builder.getF64Type(), castOp.getRes()); - angleCastedOp->moveAfter(castOp); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp; + // Just convert to an f64 directly + auto newParam = builder.create( + angleOp->getLoc(), builder.getF64Type(), castOp.getParameterName()); + if (castOp->hasAttr("initialValue")) { + auto initAttr = castOp->getAttr("initialValue").dyn_cast(); + if (initAttr) + newParam->setAttr("initialValue", initAttr); + } + newParam->moveAfter(castOp); + + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = newParam; } else if (auto castOp = dyn_cast(angleOp)) { auto castOpArg = castOp.getArg(); if (auto paramCastOp = @@ -541,28 +568,26 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp, auto angleCastedOp = builder.create( paramCastOp->getLoc(), builder.getF64Type(), paramCastOp.getRes()); angleCastedOp->moveAfter(paramCastOp); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp; + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = angleCastedOp; } else if (auto constOp = dyn_cast(castOpArg.getDefiningOp())) { // if cast from float64 then use directly assert(constOp.getType() == builder.getF64Type() && "expected angle type to be float 64"); - classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = constOp; + classicalQUIROpLocToConvertedPulseOpMap[angleOp] = constOp; } else llvm_unreachable("castOp arg unknown"); } else llvm_unreachable("angleOp unknown"); } - return classicalQUIROpLocToConvertedPulseOpMap[angleLocHash]; + return classicalQUIROpLocToConvertedPulseOpMap[angleOp]; } mlir::Value QUIRToPulsePass::convertDurationToI64( mlir::quir::CallCircuitOp &callCircuitOp, Operation *durationOp, uint &cnt, mlir::OpBuilder &builder, mlir::func::FuncOp &mainFunc) { assert(durationOp && "duration op is null"); - std::string const durLocHash = - std::to_string(mlir::hash_value(durationOp->getLoc())); - if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) == + if (classicalQUIROpLocToConvertedPulseOpMap.find(durationOp) == classicalQUIROpLocToConvertedPulseOpMap.end()) { if (auto castOp = dyn_cast(durationOp)) { auto durVal = @@ -575,11 +600,11 @@ mlir::Value QUIRToPulsePass::convertDurationToI64( castOp->getLoc(), builder.getIntegerAttr(builder.getI64Type(), uint64_t(durVal))); I64Dur->moveAfter(castOp); - classicalQUIROpLocToConvertedPulseOpMap[durLocHash] = I64Dur; + classicalQUIROpLocToConvertedPulseOpMap[durationOp] = I64Dur; } else - llvm_unreachable("unkown duration op"); + llvm_unreachable("unknown duration op"); } - return classicalQUIROpLocToConvertedPulseOpMap[durLocHash]; + return classicalQUIROpLocToConvertedPulseOpMap[durationOp]; } mlir::pulse::Port_CreateOp diff --git a/lib/Dialect/Pulse/IR/PulseOps.cpp b/lib/Dialect/Pulse/IR/PulseOps.cpp index 6b2d46c22..3c9717037 100644 --- a/lib/Dialect/Pulse/IR/PulseOps.cpp +++ b/lib/Dialect/Pulse/IR/PulseOps.cpp @@ -17,6 +17,7 @@ #include "Dialect/Pulse/IR/PulseOps.h" #include "Dialect/Pulse/IR/PulseTraits.h" +#include "Dialect/QCS/IR/QCSOps.h" #include "Dialect/QUIR/IR/QUIROps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -356,8 +357,9 @@ LogicalResult verifyClassical_(SequenceOp op) { mlir::Operation *classicalOp = nullptr; WalkResult const result = op->walk([&](Operation *subOp) { if (isa(subOp) || isa(subOp) || - isa(subOp) || isa(subOp) || - isa(subOp) || isa(subOp) || + isa(subOp) || isa(subOp) || + isa(subOp) || isa(subOp) || + isa(subOp) || subOp->hasTrait() || subOp->hasTrait()) return WalkResult::advance(); diff --git a/lib/Dialect/Pulse/Transforms/Scheduling.cpp b/lib/Dialect/Pulse/Transforms/Scheduling.cpp index 8117ed369..547ba3524 100644 --- a/lib/Dialect/Pulse/Transforms/Scheduling.cpp +++ b/lib/Dialect/Pulse/Transforms/Scheduling.cpp @@ -112,6 +112,7 @@ void QuantumCircuitPulseSchedulingPass::scheduleAlap( opEnd = quantumCircuitSequenceOpBlock->rend(); opIt != opEnd; ++opIt) { auto &op = *opIt; + if (auto quantumGateCallSequenceOp = dyn_cast(op)) { // find quantum gate SequenceOp diff --git a/lib/Dialect/QUIR/IR/QUIROps.cpp b/lib/Dialect/QUIR/IR/QUIROps.cpp index 29bef7d33..efe3c9bfd 100644 --- a/lib/Dialect/QUIR/IR/QUIROps.cpp +++ b/lib/Dialect/QUIR/IR/QUIROps.cpp @@ -380,8 +380,9 @@ LogicalResult verifyClassical_(CircuitOp op) { mlir::Operation *classicalOp = nullptr; WalkResult const result = op->walk([&](Operation *subOp) { if (isa(subOp) || isa(subOp) || - isa(subOp) || isa(subOp) || - isa(subOp) || subOp->hasTrait() || + isa(subOp) || isa(subOp) || + isa(subOp) || isa(subOp) || + subOp->hasTrait() || subOp->hasTrait()) return WalkResult::advance(); classicalOp = subOp; diff --git a/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp b/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp index 3d9c02462..b0bf5df26 100644 --- a/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp +++ b/lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp @@ -31,12 +31,12 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" @@ -93,6 +93,7 @@ OpBuilder ExtractCircuitsPass::startCircuit(Location location, topLevelBuilder.getFunctionType( /*inputs=*/ArrayRef(), /*results=*/ArrayRef())); + currentCircuitMapper = IRMapping(); currentCircuitOp.addEntryBlock(); symbolCache->addCallee(currentCircuitOp); @@ -107,35 +108,51 @@ OpBuilder ExtractCircuitsPass::startCircuit(Location location, return circuitBuilder; } -void ExtractCircuitsPass::addToCircuit( - Operation *currentOp, OpBuilder circuitBuilder, - llvm::SmallVector &eraseList) { +void ExtractCircuitsPass::addToCircuit(Operation *currentOp, + OpBuilder circuitBuilder) { - IRMapping mapper; // add operands to circuit input list for (auto operand : currentOp->getOperands()) { auto *defOp = operand.getDefiningOp(); auto search = circuitOperands.find(defOp); uint argumentIndex = 0; + mlir::Value mappedValue; if (search == circuitOperands.end()) { - argumentIndex = inputValues.size(); - inputValues.push_back(operand); - inputTypes.push_back(operand.getType()); - circuitOperands[defOp] = argumentIndex; - currentCircuitOp.getBody().addArgument(operand.getType(), - currentOp->getLoc()); - if (isa(defOp)) { - auto id = defOp->getAttrOfType("id").getInt(); - phyiscalIds.push_back(id); - argToId[argumentIndex] = id; + // Check if we should embed in the circuit + auto constantLike = (isa(defOp) || + isa(defOp)); + if (constantLike) { + // Don't clone/map if we already have + if (currentCircuitMapper.contains(operand)) + continue; + auto *newDefOp = circuitBuilder.clone(*defOp, currentCircuitMapper); + mappedValue = newDefOp->getResult(0); + // May be used multiple times so we must remove all users + // before erasing. + eraseConstSet.insert(defOp); + } else { + // Otherwise we add to the circuit signature + argumentIndex = inputValues.size(); + inputValues.push_back(operand); + inputTypes.push_back(operand.getType()); + circuitOperands[defOp] = argumentIndex; + currentCircuitOp.getBody().addArgument(operand.getType(), + currentOp->getLoc()); + + if (isa(defOp)) { + auto id = defOp->getAttrOfType("id").getInt(); + phyiscalIds.push_back(id); + argToId[argumentIndex] = id; + } + mappedValue = currentCircuitOp.getArgument(argumentIndex); } } else { argumentIndex = search->second; + mappedValue = currentCircuitOp.getArgument(argumentIndex); } - - mapper.map(operand, currentCircuitOp.getArgument(argumentIndex)); + currentCircuitMapper.map(operand, mappedValue); } - auto *newOp = circuitBuilder.clone(*currentOp, mapper); + auto *newOp = circuitBuilder.clone(*currentOp, currentCircuitMapper); outputTypes.append(newOp->getResultTypes().begin(), newOp->getResultTypes().end()); @@ -143,12 +160,12 @@ void ExtractCircuitsPass::addToCircuit( originalResults.append(currentOp->getResults().begin(), currentOp->getResults().end()); - eraseList.push_back(currentOp); + eraseOpSet.insert(currentOp); } -void ExtractCircuitsPass::endCircuit( - Operation *firstOp, Operation *lastOp, OpBuilder topLevelBuilder, - OpBuilder circuitBuilder, llvm::SmallVector &eraseList) { +void ExtractCircuitsPass::endCircuit(Operation *firstOp, Operation *lastOp, + OpBuilder topLevelBuilder, + OpBuilder circuitBuilder) { LLVM_DEBUG(llvm::dbgs() << "Ending circuit " << currentCircuitOp.getSymName() << "\n"); @@ -189,16 +206,6 @@ void ExtractCircuitsPass::endCircuit( assert(originalResults[cnt].use_empty() && "usage expected to be empty"); } - // erase operations - while (!eraseList.empty()) { - auto *op = eraseList.back(); - eraseList.pop_back(); - assert(op->use_empty() && "operation usage expected to be empty"); - LLVM_DEBUG(llvm::dbgs() << "Erasing: "); - LLVM_DEBUG(op->dump()); - op->erase(); - } - currentCircuitOp = nullptr; } @@ -212,7 +219,6 @@ void ExtractCircuitsPass::processRegion(mlir::Region ®ion, void ExtractCircuitsPass::processBlock(mlir::Block &block, OpBuilder topLevelBuilder, OpBuilder circuitBuilder) { - llvm::SmallVector eraseList; Operation *firstQuantumOp = nullptr; Operation *lastQuantumOp = nullptr; @@ -244,7 +250,7 @@ void ExtractCircuitsPass::processBlock(mlir::Block &block, circuitBuilder = startCircuit(firstQuantumOp->getLoc(), topLevelBuilder); } - addToCircuit(¤tOp, circuitBuilder, eraseList); + addToCircuit(¤tOp, circuitBuilder); continue; } if (terminatesCircuit(currentOp)) { @@ -252,7 +258,7 @@ void ExtractCircuitsPass::processBlock(mlir::Block &block, // progress there is an in progress circuit to be ended. if (currentCircuitOp) { endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, - circuitBuilder, eraseList); + circuitBuilder); } // handle control flow by recursively calling processBlock for control @@ -262,10 +268,8 @@ void ExtractCircuitsPass::processBlock(mlir::Block &block, } } // End of block complete the circuit - if (currentCircuitOp) { - endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, circuitBuilder, - eraseList); - } + if (currentCircuitOp) + endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, circuitBuilder); } void ExtractCircuitsPass::runOnOperation() { @@ -284,6 +288,20 @@ void ExtractCircuitsPass::runOnOperation() { auto const builder = OpBuilder(mainFunc); processRegion(mainFunc.getRegion(), builder, builder); + + // erase operations + for (auto *op : eraseOpSet) { + LLVM_DEBUG(llvm::dbgs() << "Erasing: "); + LLVM_DEBUG(op->dump()); + op->erase(); + } + for (auto *op : eraseConstSet) { + assert(op->use_empty() && "operation usage expected to be empty"); + LLVM_DEBUG(llvm::dbgs() << "Erasing: "); + LLVM_DEBUG(op->dump()); + op->erase(); + } + } // runOnOperation llvm::StringRef ExtractCircuitsPass::getArgument() const { diff --git a/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp b/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp index bf0126ac6..f74304456 100644 --- a/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp +++ b/lib/Dialect/QUIR/Transforms/ReorderMeasurements.cpp @@ -21,6 +21,7 @@ #include "Dialect/QUIR/Transforms/ReorderMeasurements.h" #include "Dialect/OQ3/IR/OQ3Ops.h" +#include "Dialect/QCS/IR/QCSOps.h" #include "Dialect/QUIR/IR/QUIRInterfaces.h" #include "Dialect/QUIR/IR/QUIROps.h" #include "Dialect/QUIR/IR/QUIRTraits.h" @@ -85,11 +86,14 @@ bool mayMoveVariableLoadOp(MeasureOp measureOp, bool mayMoveCastOp(MeasureOp measureOp, oq3::CastOp castOp, MoveListVec &moveList) { bool moveCastOp = false; - auto variableLoadOp = - dyn_cast(castOp.getArg().getDefiningOp()); - if (variableLoadOp) + + auto *definingOp = castOp.getArg().getDefiningOp(); + if (auto variableLoadOp = dyn_cast(definingOp)) moveCastOp = mayMoveVariableLoadOp(measureOp, variableLoadOp, moveList); - auto castMeasureOp = dyn_cast(castOp.getArg().getDefiningOp()); + else if (isa(definingOp)) + moveCastOp = true; + + auto castMeasureOp = dyn_cast(definingOp); if (castMeasureOp) moveCastOp = ((castMeasureOp != measureOp) && (castMeasureOp->isBeforeInBlock(measureOp) || @@ -170,6 +174,13 @@ struct ReorderMeasureAndNonMeasurePat : public OpRewritePattern { mayMoveVariableLoadOp(measureOp, variableLoadOp, moveList); } + // if the defining op is a parameter load op we are are safe + // to move + if (auto parameterLoadOp = dyn_cast(defOp)) { + moveOps = true; + moveList.push_back(parameterLoadOp); + } + auto castOp = dyn_cast(defOp); if (castOp) moveOps = mayMoveCastOp(measureOp, castOp, moveList); diff --git a/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp b/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp index 4c8ad09d4..b19ce10f3 100644 --- a/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp +++ b/lib/Dialect/QUIR/Transforms/UnusedVariable.cpp @@ -23,74 +23,45 @@ #include "Dialect/OQ3/IR/OQ3Ops.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" -#include - using namespace mlir; using namespace quir; using namespace oq3; -namespace { -/// This pattern matches on variable declarations that are not marked 'output' -/// and are not followed by a use of the same variable, and removes them -struct UnusedVariablePat : public OpRewritePattern { - UnusedVariablePat(MLIRContext *context, mlir::SymbolUserMap &symbolUses) - : OpRewritePattern(context, /*benefit=*/1), - symbolUses(symbolUses) {} - mlir::SymbolUserMap &symbolUses; - LogicalResult - matchAndRewrite(DeclareVariableOp declOp, - mlir::PatternRewriter &rewriter) const override { +/// +/// \brief Entry point for the pass. +void UnusedVariablePass::runOnOperation() { + mlir::SymbolTableCollection symbolTable; + mlir::SymbolUserMap symbolUsers(symbolTable, getOperation()); + + getOperation()->walk([&](DeclareVariableOp declOp) { if (declOp.isOutputVariable()) - return failure(); + return mlir::WalkResult::advance(); // iterate through uses - for (auto *useOp : symbolUses.getUsers(declOp)) { + for (auto *useOp : symbolUsers.getUsers(declOp)) { if (auto useVariable = dyn_cast(useOp)) { if (!useVariable || !useVariable.use_empty()) - return failure(); + return mlir::WalkResult::advance(); } } // No uses found, so now we can erase all references (just stores) and the // declaration - for (auto *useOp : symbolUses.getUsers(declOp)) - rewriter.eraseOp(useOp); - - rewriter.eraseOp(declOp); - return success(); - } // matchAndRewrite - -}; // struct UnusedVariablePat -} // anonymous namespace - -/// -/// \brief Entry point for the pass. -void UnusedVariablePass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - mlir::GreedyRewriteConfig config; - mlir::SymbolTableCollection symbolTable; - mlir::SymbolUserMap symbolUsers(symbolTable, getOperation()); - - // use cheaper top-down traversal (in this case, bottom-up would not behave - // any differently) - config.useTopDownTraversal = true; - // Disable to improve performance - config.enableRegionSimplification = false; + for (auto *useOp : symbolUsers.getUsers(declOp)) + useOp->erase(); + ; - patterns.add(&getContext(), symbolUsers); + declOp->erase(); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) - signalPassFailure(); + return mlir::WalkResult::advance(); + }); } llvm::StringRef UnusedVariablePass::getArgument() const { diff --git a/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp b/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp index 1069ccaf5..506ab7036 100644 --- a/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp +++ b/lib/Frontend/OpenQASM3/QUIRGenQASM3Visitor.cpp @@ -839,11 +839,13 @@ ExpressionValueType QUIRGenQASM3Visitor::visit_(const ASTGateNode *node) { // must be a normal angle variable use if (!assign(pos, param->GetGateParamName())) { if (const auto *const ident = param->GetValueIdentifier()) { - pos = varHandler.generateVariableUse(getLocation(node), ident); - if (pos.getType() != builder.getType(64)) { - pos = circuitParentBuilder.create( - pos.getLoc(), builder.getType(64), pos); - } + + double initialValue = 0.0; + if (param->IsNumber()) + initialValue = param->AsDouble(); + + pos = varHandler.generateParameterLoad( + getLocation(node), ident->GetName(), initialValue); ssaOtherValues.push_back(pos); } else { reportError(node, mlir::DiagnosticSeverity::Error) @@ -1143,27 +1145,15 @@ void QUIRGenQASM3Visitor::visit(const ASTDeclarationNode *node) { case ASTTypeMPDecimal: case ASTTypeMPComplex: { switchCircuit(false, getLocation(node)); - auto variableType = varHandler.resolveQUIRVariableType(node); - auto valOrError = visitAndGetExpressionValue(node->GetExpression()); - varHandler.generateVariableDeclaration( - loc, idNode->GetName(), variableType, - node->GetModifierType() == QASM::ASTTypeInputModifier, - node->GetModifierType() == QASM::ASTTypeOutputModifier); - - if (!valOrError) { - assert(hasFailed && "visitAndGetExpressionValue returned error but did " - "not set state to failed."); - return; - } - auto val = valOrError.get(); + auto variableType = varHandler.resolveQUIRVariableType(node); // generate variable assignment so that they are reinitialized on every // shot. bool genVariableWithVal = true; - // parameter support currently limited to quir::AngleType + // parameter support currently limited to quir::AngleType/Float64Type if (node->GetModifierType() == QASM::ASTTypeInputModifier) { bool genParameter = true; if (!enableParameters) { @@ -1183,17 +1173,26 @@ void QUIRGenQASM3Visitor::visit(const ASTDeclarationNode *node) { genParameter = false; } - if (genParameter) { - auto load = - varHandler.generateParameterLoad(loc, idNode->GetName(), val); - varHandler.generateVariableAssignment(loc, idNode->GetName(), load); + if (genParameter) genVariableWithVal = false; - } } - if (genVariableWithVal) + if (genVariableWithVal) { + auto valOrError = visitAndGetExpressionValue(node->GetExpression()); + if (!valOrError) { + assert(hasFailed && "visitAndGetExpressionValue returned error but did " + "not set state to failed."); + return; + } + auto val = valOrError.get(); varHandler.generateVariableAssignment(loc, idNode->GetName(), val); + varHandler.generateVariableDeclaration( + loc, idNode->GetName(), variableType, + node->GetModifierType() == QASM::ASTTypeInputModifier, + node->GetModifierType() == QASM::ASTTypeOutputModifier); + } + return; } @@ -1442,7 +1441,7 @@ QUIRGenQASM3Visitor::handleAssign(const ASTBinaryOpNode *node) { "set state to failed."); return rightRefOrError; } - Value const rightRef = rightRefOrError.get(); + const Value rightRef = rightRefOrError.get(); return handleAssign(node, rightRef); } @@ -1553,6 +1552,7 @@ QUIRGenQASM3Visitor::visitAndGetExpressionValue(const ASTExpressionNode *node) { BaseQASM3Visitor::visit(node); if (expression) ssaOtherValues.push_back((expression.get())); + return std::move(expression); } @@ -2255,8 +2255,12 @@ QUIRGenQASM3Visitor::visit_(const ASTCastExpressionNode *node) { } mlir::Value QUIRGenQASM3Visitor::createVoidValue(mlir::Location location) { - return builder.create( - location, builder.getZeroAttr(builder.getI1Type())); + // Only create void value for error propagation reasons once + // to avoid adding many unused operations to the program. + if (!voidValue) + voidValue = builder.create( + location, builder.getZeroAttr(builder.getI1Type())); + return voidValue; } mlir::Value QUIRGenQASM3Visitor::createVoidValue(QASM::ASTBase const *node) { diff --git a/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp b/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp index faa63bb46..b4185fbe6 100644 --- a/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp +++ b/lib/Frontend/OpenQASM3/QUIRVariableBuilder.cpp @@ -23,7 +23,6 @@ #include "Dialect/OQ3/IR/OQ3Ops.h" #include "Dialect/QCS/IR/QCSOps.h" -#include "Dialect/QUIR/IR/QUIRAttributes.h" #include "Dialect/QUIR/IR/QUIROps.h" #include "Dialect/QUIR/IR/QUIRTypes.h" @@ -54,6 +53,11 @@ void QUIRVariableBuilder::generateVariableDeclaration( mlir::Location location, llvm::StringRef variableName, mlir::Type type, bool isInputVariable, bool isOutputVariable) { + // Input variables are not used as parameter loads replace them + // for performance reasons. + // TODO: Replace many parameters with array accesses. + if (isInputVariable) + return; // variables are symbols and thus need to be placed directly in a surrounding // Op that contains a symbol table. mlir::OpBuilder::InsertionGuard const g(builder); @@ -73,8 +77,6 @@ void QUIRVariableBuilder::generateVariableDeclaration( lastDeclaration[surroundingModuleOp] = declareOp; // save this to insert after - if (isInputVariable) - declareOp.setInputAttr(builder.getUnitAttr()); if (isOutputVariable) declareOp.setOutputAttr(builder.getUnitAttr()); variables.emplace(variableName.str(), type); @@ -120,49 +122,19 @@ void QUIRVariableBuilder::generateParameterDeclaration( mlir::Value QUIRVariableBuilder::generateParameterLoad(mlir::Location location, llvm::StringRef variableName, - mlir::Value assignedValue) { + double initialValue) { - if (auto constantOp = mlir::dyn_cast( - assignedValue.getDefiningOp())) { - auto op = getClassicalBuilder().create( - location, builder.getType(64), - variableName.str()); - - double initialValue = 0.0; - - auto constFloatAttr = constantOp.getValue().dyn_cast(); - if (constFloatAttr) { - initialValue = constFloatAttr.getValueAsDouble(); - } else { - auto constAngleAttr = - constantOp.getValue().dyn_cast(); - if (constAngleAttr) - initialValue = constAngleAttr.getValue().convertToDouble(); - } + auto op = getClassicalBuilder().create( + location, builder.getType(64), variableName.str()); + // Only store initial value if it is not zero for performance reasons. + if (initialValue != 0.0) { mlir::FloatAttr const floatAttr = getClassicalBuilder().getF64FloatAttr(initialValue); op->setAttr("initialValue", floatAttr); - return op; - } - - // if the source is a arith::ConstantOp cast to angle - if (auto constantOp = mlir::dyn_cast( - assignedValue.getDefiningOp())) { - auto loadOp = getClassicalBuilder().create( - location, constantOp.getType(), variableName.str()); - double initialValue = 0.0; - auto constAttr = constantOp.getValue().dyn_cast(); - if (constAttr) - initialValue = constAttr.getValueAsDouble(); - mlir::FloatAttr const floatAttr = - getClassicalBuilder().getF64FloatAttr(initialValue); - loadOp->setAttr("initialValue", floatAttr); - return loadOp; } - llvm_unreachable( - "Unsupported defining value operation for parameter variable"); + return op; } void QUIRVariableBuilder::generateArrayVariableDeclaration( diff --git a/releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml b/releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml new file mode 100644 index 000000000..f4f63319d --- /dev/null +++ b/releasenotes/notes/update-parameter-handling-cfa04a0bd7250401.yaml @@ -0,0 +1,11 @@ +--- +features: + - | + Handling of ``qcs.parameter_load`` operations has been modified to be more direct + with reads straight from the angle variables. This brings significant performance enhancements + as a large number of MLIR operations have been removed. The consequence is that if an OpenQASM 3 + input parameter value is written to this value will not be dynamically resolved. This could be + fixed in later versions of the compiler by using memref like semantics for parameters. +fixes: + - | + Significant performance enhancements for both constant and parameter gate angles. diff --git a/test/Dialect/QUIR/Transforms/extract-circuits.mlir b/test/Dialect/QUIR/Transforms/extract-circuits.mlir index 00f6a90bc..072f4d115 100644 --- a/test/Dialect/QUIR/Transforms/extract-circuits.mlir +++ b/test/Dialect/QUIR/Transforms/extract-circuits.mlir @@ -18,8 +18,8 @@ module { return } // CHECK: quir.circuit @circuit_0 - // CHECK: quir.delay %arg0, (%arg1) - // CHECK: %0:2 = quir.measure(%arg2, %arg3) + // CHECK: quir.delay %dur, (%arg0) + // CHECK: %0:2 = quir.measure(%arg1, %arg2) // CHECK: quir.return %0#0, %0#1 : i1, i1 // CHECK: quir.circuit @circuit_1 // CHECK: quir.call_gate @x(%arg0) @@ -47,7 +47,7 @@ module { %4:2 = quir.measure(%0, %2) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) // CHECK-NOT: quir.delay %dur_0, (%1) : !quir.duration
, (!quir.qubit<1>) -> () // CHECK-NOT: %4:2 = quir.measure(%0, %2) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) - // CHECK: %4:2 = quir.call_circuit @circuit_0(%dur_0, %1, %0, %2) + // CHECK: %4:2 = quir.call_circuit @circuit_0(%1, %0, %2) qcs.parallel_control_flow { // CHECK: qcs.parallel_control_flow scf.if %4#0 { diff --git a/test/Dialect/QUIR/Transforms/reorder-measurements.mlir b/test/Dialect/QUIR/Transforms/reorder-measurements.mlir index da18a489d..602a0ca88 100644 --- a/test/Dialect/QUIR/Transforms/reorder-measurements.mlir +++ b/test/Dialect/QUIR/Transforms/reorder-measurements.mlir @@ -28,10 +28,13 @@ func.func @three(%c : memref<1xi1>, %ind : index, %angle_0 : !quir.angle<64>) { quir.call_gate @rz(%q1, %angle_0) : (!quir.qubit<1>, !quir.angle<64>) -> () %res1 = quir.measure(%q1) : (!quir.qubit<1>) -> (i1) memref.store %res1, %c[%ind] : memref<1xi1> - quir.call_gate @rz(%q2, %angle_0) : (!quir.qubit<1>, !quir.angle<64>) -> () + %angle_1 = "qcs.parameter_load"() {parameter_name = "test"} : () -> !quir.angle<64> + quir.call_gate @rz(%q2, %angle_1) : (!quir.qubit<1>, !quir.angle<64>) -> () quir.call_gate @sx(%q2) : (!quir.qubit<1>) -> () - quir.call_gate @rz(%q2, %angle_0) : (!quir.qubit<1>, !quir.angle<64>) -> () + %angle_2 = quir.constant #quir.angle<3.0> : !quir.angle<64> + quir.call_gate @rz(%q2, %angle_2) : (!quir.qubit<1>, !quir.angle<64>) -> () %res2 = quir.measure(%q2) : (!quir.qubit<1>) -> (i1) +// CHECK: {{.*}} = quir.constant #quir.angle<3.000000e+00> : !quir.angle<64> // CHECK: [[Q00:%.*]] = quir.declare_qubit {id = 0 : i32} : !quir.qubit<1> // CHECK: [[Q01:%.*]] = quir.declare_qubit {id = 1 : i32} : !quir.qubit<1> // CHECK: [[Q02:%.*]] = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> @@ -41,7 +44,8 @@ func.func @three(%c : memref<1xi1>, %ind : index, %angle_0 : !quir.angle<64>) { // CHECK-NEXT: quir.call_gate @rz([[Q01]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK-NEXT: quir.call_gate @sx([[Q01]]) : (!quir.qubit<1>) -> () // CHECK-NEXT: quir.call_gate @rz([[Q01]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () -// CHECK-NEXT: quir.call_gate @rz([[Q02]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () +// CHECK-NEXT: [[ANGLE:%.*]] = qcs.parameter_load "test" : !quir.angle<64> +// CHECK-NEXT: quir.call_gate @rz([[Q02]], [[ANGLE]]) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK-NEXT: quir.call_gate @sx([[Q02]]) : (!quir.qubit<1>) -> () // CHECK-NEXT: quir.call_gate @rz([[Q02]], {{%.*}}) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK-NEXT: [[RES00:%.*]] = quir.measure([[Q00]]) : (!quir.qubit<1>) -> i1 diff --git a/test/Frontend/OpenQASM3/input-output-variables.qasm b/test/Frontend/OpenQASM3/input-output-variables.qasm index 6304b97a8..86f8485ad 100644 --- a/test/Frontend/OpenQASM3/input-output-variables.qasm +++ b/test/Frontend/OpenQASM3/input-output-variables.qasm @@ -1,6 +1,5 @@ OPENQASM 3.0; // RUN: qss-compiler -X=qasm --emit=ast-pretty %s | FileCheck %s --match-full-lines --check-prefix AST-PRETTY -// RUN: qss-compiler -X=qasm --emit=mlir %s --enable-parameters=false | FileCheck %s --match-full-lines --check-prefix MLIR // RUN: (! qss-compiler -X=qasm --emit=mlir --enable-parameters --enable-circuits-from-qasm %s 2>&1 ) | FileCheck %s --check-prefix CIRCUITS // @@ -28,12 +27,10 @@ input int basis; // CIRCUITS: error: Input parameter basis type error. Input parameters must be angle or float[64]. // AST-PRETTY: DeclarationNode(type=ASTTypeBitset, CBitNode(name=flags, bits=32), inputVariable) -// MLIR-DAG: oq3.declare_variable {input} @flags : !quir.cbit<32> input bit[32] flags; // CIRCUITS: error: Input parameter flags type error. Input parameters must be angle or float[64]. // AST-PRETTY: DeclarationNode(type=ASTTypeBitset, CBitNode(name=result, bits=1), outputVariable) -// MLIR-DAG: oq3.declare_variable {output} @result : !quir.cbit<1> output bit result; // TODO diff --git a/test/Frontend/OpenQASM3/input-parameters-if.qasm b/test/Frontend/OpenQASM3/input-parameters-if.qasm index 20eb19a82..18d77904e 100644 --- a/test/Frontend/OpenQASM3/input-parameters-if.qasm +++ b/test/Frontend/OpenQASM3/input-parameters-if.qasm @@ -25,7 +25,7 @@ bit result; gate x q { } gate rz(phi) q { } -input angle theta = 3.141; +input angle theta; x $2; rz(theta) $2; @@ -50,8 +50,7 @@ is_excited = measure $2; // CHECK: [[QUBIT2:%.*]] = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> // CHECK: [[QUBIT3:%.*]] = quir.declare_qubit {id = 3 : i32} : !quir.qubit<1> -// CHECK: [[PARAM:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} -// CHECK: oq3.variable_assign @theta : !quir.angle<64> = [[PARAM]] +// CHECK: [[PARAM:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> // CHECK: [[EXCITED:%.*]] = oq3.variable_load @is_excited : !quir.cbit<1> // CHECK: [[CONST:%[0-9a-z_]+]] = arith.constant 1 : i32 @@ -68,7 +67,7 @@ if (is_excited == 1) { // CHECK: [[COND1:%.*]] = arith.cmpi eq, [[OTHERCAST]], [[CONST]] : i32 // CHECK: scf.if [[COND1]] { if (other == 1){ -// CHECK: [[THETA:%.*]] = oq3.variable_load @theta : !quir.angle<64> +// CHECK: [[THETA:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> // CHECK: quir.call_circuit @circuit_2([[QUBIT2]], [[THETA]]) : (!quir.qubit<1>, !quir.angle<64>) -> () x $2; rz(theta) $2; diff --git a/test/Frontend/OpenQASM3/input-parameters-while.qasm b/test/Frontend/OpenQASM3/input-parameters-while.qasm index d3a73228f..317786f07 100644 --- a/test/Frontend/OpenQASM3/input-parameters-while.qasm +++ b/test/Frontend/OpenQASM3/input-parameters-while.qasm @@ -58,7 +58,6 @@ bit is_excited; // CHECK: func.func @main() -> i32 { // CHECK: scf.for %arg0 = %c0 to %c1000 step %c1 { -// CHECK: {{.*}} = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} // CHECK: [[QUBIT:%.*]] = quir.declare_qubit {id = 0 : i32} : !quir.qubit<1> // CHECK: scf.while : () -> () { // CHECK: [[N:%.*]] = oq3.variable_load @n : i32 @@ -76,7 +75,7 @@ while (n != 0) { // CHECK: scf.if [[COND2]] { if (is_excited) { - // CHECK: [[THETA:%.*]] = oq3.variable_load @theta : !quir.angle<64> + // CHECK: [[THETA:%.*]] = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} // CHECK: quir.call_circuit @circuit_2([[QUBIT]], [[THETA]]) : (!quir.qubit<1>, !quir.angle<64>) -> () // CHECK: } h $0; diff --git a/test/Frontend/OpenQASM3/input-parameters.qasm b/test/Frontend/OpenQASM3/input-parameters.qasm index f057a31a4..55879b660 100644 --- a/test/Frontend/OpenQASM3/input-parameters.qasm +++ b/test/Frontend/OpenQASM3/input-parameters.qasm @@ -1,5 +1,5 @@ OPENQASM 3; -// RUN: qss-compiler -X=qasm --emit=mlir --enable-parameters --enable-circuits-from-qasm %s | FileCheck %s --check-prefixes=CHECK,CHECK-XX +// RUN: qss-compiler -X=qasm --emit=mlir --enable-parameters --enable-circuits-from-qasm %s | FileCheck %s --check-prefixes=CHECK // // This code is part of Qiskit. @@ -66,16 +66,10 @@ c = measure $0; // CHECK: %1 = quir.declare_qubit {id = 2 : i32} : !quir.qubit<1> // CHECK: %2 = qcs.parameter_load "theta" : !quir.angle<64> {initialValue = 3.141000e+00 : f64} -// CHECK: oq3.variable_assign @theta : !quir.angle<64> = %2 -// CHECK: %3 = qcs.parameter_load "theta2" : f64 {initialValue = 1.560000e+00 : f64} -// CHECK: oq3.variable_assign @theta2 : f64 = %3 -// CHECK-XX: quir.reset %0 : !quir.qubit<1> -// CHECK-NOT: oq3.variable_assign @theta : !quir.angle<64> = %angle - -// CHECK: quir.call_circuit @circuit_0(%0, %4) : (!quir.qubit<1>, !quir.angle<64>) -> () -// CHECK: %6 = quir.call_circuit @circuit_1(%0) : (!quir.qubit<1>) -> i1 -// CHECK: oq3.cbit_assign_bit @b<1> [0] : i1 = %6 - -// CHECK: %7 = oq3.variable_load @theta2 : f64 -// CHECK: %8 = "oq3.cast"(%7) : (f64) -> !quir.angle<64> -// CHECK: quir.call_circuit @circuit_2(%0, %8) : (!quir.qubit<1>, !quir.angle<64>) -> () + +// CHECK: quir.call_circuit @circuit_0(%0, %2) : (!quir.qubit<1>, !quir.angle<64>) -> () +// CHECK: %4 = quir.call_circuit @circuit_1(%0) : (!quir.qubit<1>) -> i1 +// CHECK: oq3.cbit_assign_bit @b<1> [0] : i1 = %4 + +// CHECK: %5 = qcs.parameter_load "theta2" : !quir.angle<64> {initialValue = 1.560000e+00 : f64} +// CHECK: quir.call_circuit @circuit_2(%0, %5) : (!quir.qubit<1>, !quir.angle<64>) -> () diff --git a/test/unittest/quir-dialect.cpp b/test/unittest/quir-dialect.cpp index c3110bcb0..563fcd387 100644 --- a/test/unittest/quir-dialect.cpp +++ b/test/unittest/quir-dialect.cpp @@ -36,13 +36,13 @@ namespace { class QUIRDialect : public ::testing::Test { protected: mlir::MLIRContext ctx; - mlir::UnknownLoc unkownLoc; + mlir::UnknownLoc unknownLoc; mlir::ModuleOp rootModule; mlir::OpBuilder builder; QUIRDialect() - : unkownLoc(mlir::UnknownLoc::get(&ctx)), - rootModule(mlir::ModuleOp::create(unkownLoc)), builder(rootModule) { + : unknownLoc(mlir::UnknownLoc::get(&ctx)), + rootModule(mlir::ModuleOp::create(unknownLoc)), builder(rootModule) { mlir::DialectRegistry registry; registry.insert(); ctx.appendDialectRegistry(registry); @@ -55,10 +55,10 @@ class QUIRDialect : public ::testing::Test { TEST_F(QUIRDialect, CPTPOpTrait) { auto declareQubitOp = builder.create( - unkownLoc, builder.getType(1), + unknownLoc, builder.getType(1), builder.getIntegerAttr(builder.getI32Type(), 0)); auto reset = builder.create( - unkownLoc, mlir::ValueRange{declareQubitOp.getResult()}); + unknownLoc, mlir::ValueRange{declareQubitOp.getResult()}); EXPECT_FALSE(declareQubitOp->hasTrait()); EXPECT_FALSE(declareQubitOp->hasTrait()); @@ -73,10 +73,10 @@ TEST_F(QUIRDialect, CPTPOpTrait) { TEST_F(QUIRDialect, UnitaryOpTrait) { auto declareQubitOp = builder.create( - unkownLoc, builder.getType(1), + unknownLoc, builder.getType(1), builder.getIntegerAttr(builder.getI32Type(), 0)); auto barrier = builder.create( - unkownLoc, mlir::ValueRange{declareQubitOp.getResult()}); + unknownLoc, mlir::ValueRange{declareQubitOp.getResult()}); EXPECT_TRUE(barrier->hasTrait()); EXPECT_FALSE(barrier->hasTrait()); @@ -88,11 +88,11 @@ TEST_F(QUIRDialect, UnitaryOpTrait) { TEST_F(QUIRDialect, MeasureSideEffects) { auto qubitDecl = builder.create( - unkownLoc, builder.getType(1), + unknownLoc, builder.getType(1), builder.getIntegerAttr(builder.getI32Type(), 0)); auto measureOp = builder.create( - unkownLoc, builder.getI1Type(), qubitDecl.getRes()); + unknownLoc, builder.getI1Type(), qubitDecl.getRes()); EXPECT_TRUE(measureOp);