diff --git a/yarn-project/prover-client/src/mocks/test_context.ts b/yarn-project/prover-client/src/mocks/test_context.ts
index 222af08400d..dbb9c8e5033 100644
--- a/yarn-project/prover-client/src/mocks/test_context.ts
+++ b/yarn-project/prover-client/src/mocks/test_context.ts
@@ -78,18 +78,22 @@ export class TestContext {
     const publicKernel = new RealPublicKernelCircuitSimulator(new WASMSimulator());
     const telemetry = new NoopTelemetryClient();

-    let actualDb: MerkleTreeWriteOperations;
+    // Separated dbs for public processor and prover - see public_processor for context
+    let publicDb: MerkleTreeWriteOperations;
+    let proverDb: MerkleTreeWriteOperations;

     if (worldState === 'native') {
       const ws = await NativeWorldStateService.tmp();
-      actualDb = await ws.fork();
+      publicDb = await ws.fork();
+      proverDb = await ws.fork();
     } else {
       const ws = await MerkleTrees.new(openTmpStore(), telemetry);
-      actualDb = await ws.getLatest();
+      publicDb = await ws.getLatest();
+      proverDb = await ws.getLatest();
     }

     const processor = PublicProcessor.create(
-      actualDb,
+      publicDb,
       publicExecutor,
       publicKernel,
       globalVariables,
@@ -122,7 +126,7 @@ export class TestContext {
     }

     const queue = new MemoryProvingQueue(telemetry);
-    const orchestrator = new ProvingOrchestrator(actualDb, queue, telemetry);
+    const orchestrator = new ProvingOrchestrator(proverDb, queue, telemetry);
     const agent = new ProverAgent(localProver, proverCount);

     queue.start();
@@ -134,7 +138,7 @@
       processor,
       simulationProvider,
       globalVariables,
-      actualDb,
+      proverDb,
       localProver,
       agent,
       orchestrator,
diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts
index b3954917628..e805a15dd3b 100644
--- a/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts
+++ b/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts
@@ -1,4 +1,4 @@
-import { mockTx } from '@aztec/circuit-types';
+import { EmptyTxValidator, mockTx } from '@aztec/circuit-types';
 import { times } from '@aztec/foundation/collection';
 import { createDebugLogger } from '@aztec/foundation/log';
 import { getVKTreeRoot } from '@aztec/noir-protocol-circuits-types';
@@ -43,10 +43,19 @@ describe('prover/orchestrator/public-functions', () => {
       context.orchestrator.startNewEpoch(1, 1);
       await context.orchestrator.startNewBlock(numTransactions, context.globalVariables, []);

-      const [processed, failed] = await context.processPublicFunctions(txs, numTransactions, context.epochProver);
+      const [processed, failed] = await context.processPublicFunctions(
+        txs,
+        numTransactions,
+        undefined,
+        new EmptyTxValidator(),
+      );
       expect(processed.length).toBe(numTransactions);
       expect(failed.length).toBe(0);

+      for (const tx of processed) {
+        await context.orchestrator.addNewTx(tx);
+      }
+
       const block = await context.orchestrator.setBlockCompleted();
       await context.orchestrator.finaliseEpoch();

diff --git a/yarn-project/prover-node/src/job/epoch-proving-job.ts b/yarn-project/prover-node/src/job/epoch-proving-job.ts
index 6257f0be369..753a19717f9 100644
--- a/yarn-project/prover-node/src/job/epoch-proving-job.ts
+++ b/yarn-project/prover-node/src/job/epoch-proving-job.ts
@@ -4,12 +4,21 @@ import {
   type L1ToL2MessageSource,
   type L2Block,
   type L2BlockSource,
+  MerkleTreeId,
   type MerkleTreeWriteOperations,
   type ProcessedTx,
   type ProverCoordination,
   type Tx,
   type TxHash,
 } from '@aztec/circuit-types';
+import {
+  KernelCircuitPublicInputs,
+  MAX_TOTAL_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
+  NULLIFIER_SUBTREE_HEIGHT,
+  PUBLIC_DATA_SUBTREE_HEIGHT,
+  PublicDataTreeLeaf,
+} from '@aztec/circuits.js';
+import { padArrayEnd } from '@aztec/foundation/collection';
 import { createDebugLogger } from '@aztec/foundation/log';
 import { promiseWithResolvers } from '@aztec/foundation/promise';
 import { Timer } from '@aztec/foundation/timer';
@@ -110,6 +119,11 @@
         uuid: this.uuid,
       });

+      if (txCount > txs.length) {
+        // If this block has a padding tx, ensure that the public processor's db has its state
+        await this.addPaddingTxState();
+      }
+
       // Mark block as completed and update archive tree
       await this.prover.setBlockCompleted(block.header);
       previousHeader = block.header;
@@ -177,6 +191,28 @@

     return processedTxs;
   }
+
+  private async addPaddingTxState() {
+    const emptyKernelOutput = KernelCircuitPublicInputs.empty();
+    await this.db.appendLeaves(MerkleTreeId.NOTE_HASH_TREE, emptyKernelOutput.end.noteHashes);
+    await this.db.batchInsert(
+      MerkleTreeId.NULLIFIER_TREE,
+      emptyKernelOutput.end.nullifiers.map(n => n.toBuffer()),
+      NULLIFIER_SUBTREE_HEIGHT,
+    );
+    const allPublicDataWrites = padArrayEnd(
+      emptyKernelOutput.end.publicDataUpdateRequests.map(
+        ({ leafSlot, newValue }) => new PublicDataTreeLeaf(leafSlot, newValue),
+      ),
+      PublicDataTreeLeaf.empty(),
+      MAX_TOTAL_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
+    );
+    await this.db.batchInsert(
+      MerkleTreeId.PUBLIC_DATA_TREE,
+      allPublicDataWrites.map(x => x.toBuffer()),
+      PUBLIC_DATA_SUBTREE_HEIGHT,
+    );
+  }
 }

 export type EpochProvingJobState =
diff --git a/yarn-project/prover-node/src/prover-node.test.ts b/yarn-project/prover-node/src/prover-node.test.ts
index bca781310f4..3b1f88a1215 100644
--- a/yarn-project/prover-node/src/prover-node.test.ts
+++ b/yarn-project/prover-node/src/prover-node.test.ts
@@ -378,13 +378,14 @@ describe('prover-node', () => {
     protected override doCreateEpochProvingJob(
       epochNumber: bigint,
       _blocks: L2Block[],
-      db: MerkleTreeWriteOperations,
+      publicDb: MerkleTreeWriteOperations,
+      _proverDb: MerkleTreeWriteOperations,
       _publicProcessorFactory: PublicProcessorFactory,
       cleanUp: (job: EpochProvingJob) => Promise<void>,
     ): EpochProvingJob {
       const job = mock<EpochProvingJob>({ getState: () => 'processing', run: () => Promise.resolve() });
       job.getId.mockReturnValue(jobs.length.toString());
-      jobs.push({ epochNumber, job, cleanUp, db });
+      jobs.push({ epochNumber, job, cleanUp, db: publicDb });
       return job;
     }

diff --git a/yarn-project/prover-node/src/prover-node.ts b/yarn-project/prover-node/src/prover-node.ts
index 3582d357f28..9669a1c3983 100644
--- a/yarn-project/prover-node/src/prover-node.ts
+++ b/yarn-project/prover-node/src/prover-node.ts
@@ -232,7 +232,10 @@ export class ProverNode implements ClaimsMonitorHandler, EpochMonitorHandler {
     // Fast forward world state to right before the target block and get a fork
     this.log.verbose(`Creating proving job for epoch ${epochNumber} for block range ${fromBlock} to ${toBlock}`);
     await this.worldState.syncImmediate(fromBlock - 1);
-    const db = await this.worldState.fork(fromBlock - 1);
+    // NB: separated the dbs as both a block builder and public processor need to track and update tree state
+    // see public_processor.ts for context
+    const publicDb = await this.worldState.fork(fromBlock - 1);
+    const proverDb = await this.worldState.fork(fromBlock - 1);

     // Create a processor using the forked world state
     const publicProcessorFactory = new PublicProcessorFactory(
@@ -242,11 +245,12 @@
     );

     const cleanUp = async () => {
-      await db.close();
+      await publicDb.close();
+      await proverDb.close();
       this.jobs.delete(job.getId());
     };

-    const job = this.doCreateEpochProvingJob(epochNumber, blocks, db, publicProcessorFactory, cleanUp);
+    const job = this.doCreateEpochProvingJob(epochNumber, blocks, publicDb, proverDb, publicProcessorFactory, cleanUp);
     this.jobs.set(job.getId(), job);
     return job;
   }
@@ -255,15 +259,16 @@
   protected doCreateEpochProvingJob(
     epochNumber: bigint,
     blocks: L2Block[],
-    db: MerkleTreeWriteOperations,
+    publicDb: MerkleTreeWriteOperations,
+    proverDb: MerkleTreeWriteOperations,
     publicProcessorFactory: PublicProcessorFactory,
     cleanUp: () => Promise<void>,
   ) {
     return new EpochProvingJob(
-      db,
+      publicDb,
       epochNumber,
       blocks,
-      this.prover.createEpochProver(db),
+      this.prover.createEpochProver(proverDb),
       publicProcessorFactory,
       this.publisher,
       this.l2BlockSource,
diff --git a/yarn-project/sequencer-client/src/sequencer/sequencer.ts b/yarn-project/sequencer-client/src/sequencer/sequencer.ts
index 21ae7179c6b..1ef6a270955 100644
--- a/yarn-project/sequencer-client/src/sequencer/sequencer.ts
+++ b/yarn-project/sequencer-client/src/sequencer/sequencer.ts
@@ -430,16 +430,23 @@ export class Sequencer {
     const numRealTxs = validTxs.length;
     const blockSize = Math.max(2, numRealTxs);

-    const fork = await this.worldState.fork();
+    // NB: separating the dbs because both should update the state
+    const publicProcessorFork = await this.worldState.fork();
+    const orchestratorFork = await this.worldState.fork();
     try {
       // We create a fresh processor each time to reset any cached state (eg storage writes)
-      const processor = this.publicProcessorFactory.create(fork, historicalHeader, newGlobalVariables);
+      const processor = this.publicProcessorFactory.create(publicProcessorFork, historicalHeader, newGlobalVariables);
       const blockBuildingTimer = new Timer();
-      const blockBuilder = this.blockBuilderFactory.create(fork);
+      const blockBuilder = this.blockBuilderFactory.create(orchestratorFork);
       await blockBuilder.startNewBlock(blockSize, newGlobalVariables, l1ToL2Messages);

       const [publicProcessorDuration, [processedTxs, failedTxs]] = await elapsed(() =>
-        processor.process(validTxs, blockSize, blockBuilder, this.txValidatorFactory.validatorForProcessedTxs(fork)),
+        processor.process(
+          validTxs,
+          blockSize,
+          blockBuilder,
+          this.txValidatorFactory.validatorForProcessedTxs(publicProcessorFork),
+        ),
       );
       if (failedTxs.length > 0) {
         const failedTxData = failedTxs.map(fail => fail.tx);
@@ -510,7 +517,8 @@
         throw err;
       }
     } finally {
-      await fork.close();
+      await publicProcessorFork.close();
+      await orchestratorFork.close();
     }
   }

diff --git a/yarn-project/simulator/src/public/public_processor.ts b/yarn-project/simulator/src/public/public_processor.ts
index 3f58c4d43db..4e0cda3b947 100644
--- a/yarn-project/simulator/src/public/public_processor.ts
+++ b/yarn-project/simulator/src/public/public_processor.ts
@@ -1,5 +1,6 @@
 import {
   type FailedTx,
+  MerkleTreeId,
   type MerkleTreeWriteOperations,
   NestedProcessReturnValues,
   type ProcessedTx,
@@ -15,10 +16,14 @@ import {
   type GlobalVariables,
   type Header,
   MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
+  MAX_TOTAL_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
+  NULLIFIER_SUBTREE_HEIGHT,
   PROTOCOL_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
+  PUBLIC_DATA_SUBTREE_HEIGHT,
+  PublicDataTreeLeaf,
   PublicDataUpdateRequest,
 } from '@aztec/circuits.js';
-import { times } from '@aztec/foundation/collection';
+import { padArrayEnd, times } from '@aztec/foundation/collection';
 import { createDebugLogger } from '@aztec/foundation/log';
 import { Timer } from '@aztec/foundation/timer';
 import { ProtocolContractAddress } from '@aztec/protocol-contracts';
@@ -183,6 +188,42 @@
         if (processedTxHandler) {
           await processedTxHandler.addNewTx(processedTx);
         }
+        // Update the state so that the next tx in the loop has the correct .startState
+        // NB: before this change, all .startStates were actually incorrect, but the issue was never caught because we either:
+        // a) had only 1 tx with public calls per block, so this loop had len 1
+        // b) always had a txHandler with the same db passed to it as this.db, which updated the db in buildBaseRollupHints in this loop
+        // To see how this ^ happens, move back to one shared db in test_context and run orchestrator_multi_public_functions.test.ts
+        // The below is taken from buildBaseRollupHints:
+        await this.db.appendLeaves(MerkleTreeId.NOTE_HASH_TREE, processedTx.data.end.noteHashes);
+        try {
+          await this.db.batchInsert(
+            MerkleTreeId.NULLIFIER_TREE,
+            processedTx.data.end.nullifiers.map(n => n.toBuffer()),
+            NULLIFIER_SUBTREE_HEIGHT,
+          );
+        } catch (error) {
+          if (txValidator) {
+            // Ideally the validator has already caught this above, but just in case:
+            throw new Error(`Transaction ${processedTx.hash} invalid after processing public functions`);
+          } else {
+            // We have no validator and assume this call should blindly process txs with duplicates being caught later
+            this.log.warn(`Detected duplicate nullifier after public processing for: ${processedTx.hash}.`);
+          }
+        }
+
+        const allPublicDataUpdateRequests = padArrayEnd(
+          processedTx.finalPublicDataUpdateRequests,
+          PublicDataUpdateRequest.empty(),
+          MAX_TOTAL_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
+        );
+        const allPublicDataWrites = allPublicDataUpdateRequests.map(
+          ({ leafSlot, newValue }) => new PublicDataTreeLeaf(leafSlot, newValue),
+        );
+        await this.db.batchInsert(
+          MerkleTreeId.PUBLIC_DATA_TREE,
+          allPublicDataWrites.map(x => x.toBuffer()),
+          PUBLIC_DATA_SUBTREE_HEIGHT,
+        );
         result.push(processedTx);
         returns = returns.concat(returnValues ?? []);
       } catch (err: any) {