Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: ServerCircuitProver return values #9391

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions yarn-project/bb-prover/src/prover/bb_prover.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ import {
type KernelCircuitPublicInputs,
type MergeRollupInputs,
NESTED_RECURSIVE_PROOF_LENGTH,
type ParityPublicInputs,
type PrivateBaseRollupInputs,
type PrivateKernelEmptyInputData,
PrivateKernelEmptyInputs,
Proof,
type PublicBaseRollupInputs,
RECURSIVE_PROOF_LENGTH,
RecursiveProof,
RootParityInput,
type RootParityInputs,
type RootRollupInputs,
type RootRollupPublicInputs,
Expand All @@ -45,7 +45,6 @@ import { createDebugLogger } from '@aztec/foundation/log';
import { BufferReader } from '@aztec/foundation/serialize';
import { Timer } from '@aztec/foundation/timer';
import {
ProtocolCircuitVkIndexes,
ProtocolCircuitVks,
ServerCircuitArtifacts,
type ServerProtocolArtifact,
Expand All @@ -69,7 +68,6 @@ import {
convertRootParityOutputsFromWitnessMap,
convertRootRollupInputsToWitnessMap,
convertRootRollupOutputsFromWitnessMap,
getVKSiblingPath,
} from '@aztec/noir-protocol-circuits-types';
import { NativeACVMSimulator } from '@aztec/simulator';
import { Attributes, type TelemetryClient, trackSpan } from '@aztec/telemetry-client';
Expand Down Expand Up @@ -147,7 +145,9 @@ export class BBNativeRollupProver implements ServerCircuitProver {
* @returns The public inputs of the parity circuit.
*/
@trackSpan('BBNativeRollupProver.getBaseParityProof', { [Attributes.PROTOCOL_CIRCUIT_NAME]: 'base-parity' })
public async getBaseParityProof(inputs: BaseParityInputs): Promise<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>> {
public async getBaseParityProof(
inputs: BaseParityInputs,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
const { circuitOutput, proof } = await this.createRecursiveProof(
inputs,
'BaseParityArtifact',
Expand All @@ -157,15 +157,9 @@ export class BBNativeRollupProver implements ServerCircuitProver {
);

const verificationKey = await this.getVerificationKeyDataForCircuit('BaseParityArtifact');

await this.verifyProof('BaseParityArtifact', proof.binaryProof);

return new RootParityInput(
proof,
verificationKey.keyAsFields,
getVKSiblingPath(ProtocolCircuitVkIndexes.BaseParityArtifact),
circuitOutput,
);
return makePublicInputsAndRecursiveProof(circuitOutput, proof, verificationKey);
}

/**
Expand All @@ -176,7 +170,7 @@ export class BBNativeRollupProver implements ServerCircuitProver {
@trackSpan('BBNativeRollupProver.getRootParityProof', { [Attributes.PROTOCOL_CIRCUIT_NAME]: 'root-parity' })
public async getRootParityProof(
inputs: RootParityInputs,
): Promise<RootParityInput<typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
const { circuitOutput, proof } = await this.createRecursiveProof(
inputs,
'RootParityArtifact',
Expand All @@ -186,15 +180,9 @@ export class BBNativeRollupProver implements ServerCircuitProver {
);

const verificationKey = await this.getVerificationKeyDataForCircuit('RootParityArtifact');

await this.verifyProof('RootParityArtifact', proof.binaryProof);

return new RootParityInput(
proof,
verificationKey.keyAsFields,
getVKSiblingPath(ProtocolCircuitVkIndexes.RootParityArtifact),
circuitOutput,
);
return makePublicInputsAndRecursiveProof(circuitOutput, proof, verificationKey);
}

/**
Expand All @@ -207,7 +195,7 @@ export class BBNativeRollupProver implements ServerCircuitProver {
}))
public async getAvmProof(
inputs: AvmCircuitInputs,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof AVM_PROOF_LENGTH_IN_FIELDS>>> {
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>> {
const proofAndVk = await this.createAvmProof(inputs);
await this.verifyAvmProof(proofAndVk.proof.binaryProof, proofAndVk.verificationKey);
return proofAndVk;
Expand Down Expand Up @@ -574,7 +562,7 @@ export class BBNativeRollupProver implements ServerCircuitProver {

private async createAvmProof(
input: AvmCircuitInputs,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof AVM_PROOF_LENGTH_IN_FIELDS>>> {
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>> {
const operation = async (bbWorkingDirectory: string) => {
const provingResult = await this.generateAvmProofWithBB(input, bbWorkingDirectory);

Expand Down Expand Up @@ -610,9 +598,7 @@ export class BBNativeRollupProver implements ServerCircuitProver {
return await this.runInDirectory(operation);
}

public async getTubeProof(
input: TubeInputs,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof TUBE_PROOF_LENGTH>>> {
public async getTubeProof(input: TubeInputs): Promise<ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>> {
// this probably is gonna need to call client ivc
const operation = async (bbWorkingDirectory: string) => {
logger.debug(`createTubeProof: ${bbWorkingDirectory}`);
Expand Down
35 changes: 9 additions & 26 deletions yarn-project/bb-prover/src/test/test_circuit_prover.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ import {
type KernelCircuitPublicInputs,
type MergeRollupInputs,
NESTED_RECURSIVE_PROOF_LENGTH,
type ParityPublicInputs,
type PrivateBaseRollupInputs,
type PrivateKernelEmptyInputData,
PrivateKernelEmptyInputs,
type Proof,
type PublicBaseRollupInputs,
RECURSIVE_PROOF_LENGTH,
type RecursiveProof,
RootParityInput,
type RootParityInputs,
type RootRollupInputs,
type RootRollupPublicInputs,
Expand All @@ -40,7 +39,6 @@ import { createDebugLogger } from '@aztec/foundation/log';
import { sleep } from '@aztec/foundation/sleep';
import { Timer } from '@aztec/foundation/timer';
import {
ProtocolCircuitVkIndexes,
ProtocolCircuitVks,
type ServerProtocolArtifact,
SimulatedServerCircuitArtifacts,
Expand All @@ -64,7 +62,6 @@ import {
convertSimulatedPrivateKernelEmptyOutputsFromWitnessMap,
convertSimulatedPublicBaseRollupInputsToWitnessMap,
convertSimulatedPublicBaseRollupOutputsFromWitnessMap,
getVKSiblingPath,
} from '@aztec/noir-protocol-circuits-types';
import { type SimulationProvider, WASMSimulator, emitCircuitSimulationStats } from '@aztec/simulator';
import { type TelemetryClient, trackSpan } from '@aztec/telemetry-client';
Expand Down Expand Up @@ -126,21 +123,16 @@ export class TestCircuitProver implements ServerCircuitProver {
* @returns The public inputs of the parity circuit.
*/
@trackSpan('TestCircuitProver.getBaseParityProof')
public async getBaseParityProof(inputs: BaseParityInputs): Promise<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>> {
const result = await this.simulate(
public async getBaseParityProof(
inputs: BaseParityInputs,
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>> {
return await this.simulate(
inputs,
'BaseParityArtifact',
RECURSIVE_PROOF_LENGTH,
convertBaseParityInputsToWitnessMap,
convertBaseParityOutputsFromWitnessMap,
);

return new RootParityInput(
result.proof,
result.verificationKey.keyAsFields,
getVKSiblingPath(ProtocolCircuitVkIndexes['BaseParityArtifact']),
result.inputs,
);
}

/**
Expand All @@ -151,26 +143,17 @@ export class TestCircuitProver implements ServerCircuitProver {
@trackSpan('TestCircuitProver.getRootParityProof')
public async getRootParityProof(
inputs: RootParityInputs,
): Promise<RootParityInput<typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
const result = await this.simulate(
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof NESTED_RECURSIVE_PROOF_LENGTH>> {
return await this.simulate(
inputs,
'RootParityArtifact',
NESTED_RECURSIVE_PROOF_LENGTH,
convertRootParityInputsToWitnessMap,
convertRootParityOutputsFromWitnessMap,
);

return new RootParityInput(
result.proof,
result.verificationKey.keyAsFields,
getVKSiblingPath(ProtocolCircuitVkIndexes['RootParityArtifact']),
result.inputs,
);
}

public async getTubeProof(
_tubeInput: TubeInputs,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof TUBE_PROOF_LENGTH>>> {
public async getTubeProof(_tubeInput: TubeInputs): Promise<ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>> {
await this.delay();
return makeProofAndVerificationKey(makeEmptyRecursiveProof(TUBE_PROOF_LENGTH), VerificationKeyData.makeFakeHonk());
}
Expand Down Expand Up @@ -293,7 +276,7 @@ export class TestCircuitProver implements ServerCircuitProver {

public async getAvmProof(
_inputs: AvmCircuitInputs,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof AVM_PROOF_LENGTH_IN_FIELDS>>> {
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>> {
// We can't simulate the AVM because we don't have enough context to do so (e.g., DBs).
// We just return an empty proof and VK data.
this.logger.debug('Skipping AVM simulation in TestCircuitProver.');
Expand Down
34 changes: 20 additions & 14 deletions yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import {
KernelCircuitPublicInputs,
MergeRollupInputs,
NESTED_RECURSIVE_PROOF_LENGTH,
ParityPublicInputs,
PrivateBaseRollupInputs,
PrivateKernelEmptyInputData,
PublicBaseRollupInputs,
RECURSIVE_PROOF_LENGTH,
RecursiveProof,
RootParityInput,
RootParityInputs,
RootRollupInputs,
RootRollupPublicInputs,
Expand All @@ -29,21 +29,24 @@ import { z } from 'zod';

import { type CircuitName } from '../stats/index.js';

export type ProofAndVerificationKey<P> = { proof: P; verificationKey: VerificationKeyData };
export type ProofAndVerificationKey<N extends number> = {
proof: RecursiveProof<N>;
verificationKey: VerificationKeyData;
};

function schemaForRecursiveProofAndVerificationKey<N extends number>(
proofLength: N,
): ZodFor<ProofAndVerificationKey<RecursiveProof<N>>> {
): ZodFor<ProofAndVerificationKey<N>> {
return z.object({
proof: RecursiveProof.schemaFor(proofLength),
verificationKey: VerificationKeyData.schema,
}) as ZodFor<ProofAndVerificationKey<RecursiveProof<N>>>;
});
}

export function makeProofAndVerificationKey<P>(
proof: P,
export function makeProofAndVerificationKey<N extends number>(
proof: RecursiveProof<N>,
verificationKey: VerificationKeyData,
): ProofAndVerificationKey<P> {
): ProofAndVerificationKey<N> {
return { proof, verificationKey };
}

Expand All @@ -55,8 +58,8 @@ export type PublicInputsAndRecursiveProof<T, N extends number = typeof NESTED_RE

function schemaForPublicInputsAndRecursiveProof<T extends object>(
inputs: ZodFor<T>,
proofSize = NESTED_RECURSIVE_PROOF_LENGTH,
): ZodFor<PublicInputsAndRecursiveProof<T>> {
const proofSize = NESTED_RECURSIVE_PROOF_LENGTH;
return z.object({
inputs,
proof: RecursiveProof.schemaFor(proofSize),
Expand Down Expand Up @@ -155,17 +158,20 @@ export const ProvingJobSchema = z.object({ id: JobIdSchema, request: ProvingRequ

type ProvingRequestResultsMap = {
[ProvingRequestType.PRIVATE_KERNEL_EMPTY]: PublicInputsAndRecursiveProof<KernelCircuitPublicInputs>;
[ProvingRequestType.PUBLIC_VM]: ProofAndVerificationKey<RecursiveProof<typeof AVM_PROOF_LENGTH_IN_FIELDS>>;
[ProvingRequestType.PUBLIC_VM]: ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>;
[ProvingRequestType.PRIVATE_BASE_ROLLUP]: PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs>;
[ProvingRequestType.PUBLIC_BASE_ROLLUP]: PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs>;
[ProvingRequestType.MERGE_ROLLUP]: PublicInputsAndRecursiveProof<BaseOrMergeRollupPublicInputs>;
[ProvingRequestType.EMPTY_BLOCK_ROOT_ROLLUP]: PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs>;
[ProvingRequestType.BLOCK_ROOT_ROLLUP]: PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs>;
[ProvingRequestType.BLOCK_MERGE_ROLLUP]: PublicInputsAndRecursiveProof<BlockRootOrBlockMergePublicInputs>;
[ProvingRequestType.ROOT_ROLLUP]: PublicInputsAndRecursiveProof<RootRollupPublicInputs>;
[ProvingRequestType.BASE_PARITY]: RootParityInput<typeof RECURSIVE_PROOF_LENGTH>;
[ProvingRequestType.ROOT_PARITY]: RootParityInput<typeof NESTED_RECURSIVE_PROOF_LENGTH>;
[ProvingRequestType.TUBE_PROOF]: ProofAndVerificationKey<RecursiveProof<typeof TUBE_PROOF_LENGTH>>;
[ProvingRequestType.BASE_PARITY]: PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>;
[ProvingRequestType.ROOT_PARITY]: PublicInputsAndRecursiveProof<
ParityPublicInputs,
typeof NESTED_RECURSIVE_PROOF_LENGTH
>;
[ProvingRequestType.TUBE_PROOF]: ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>;
};

export type ProvingRequestResultFor<T extends ProvingRequestType> = { type: T; result: ProvingRequestResultsMap[T] };
Expand Down Expand Up @@ -220,11 +226,11 @@ export const ProvingRequestResultSchema = z.discriminatedUnion('type', [
}),
z.object({
type: z.literal(ProvingRequestType.BASE_PARITY),
result: RootParityInput.schemaFor(RECURSIVE_PROOF_LENGTH),
result: schemaForPublicInputsAndRecursiveProof(ParityPublicInputs.schema, RECURSIVE_PROOF_LENGTH),
}),
z.object({
type: z.literal(ProvingRequestType.ROOT_PARITY),
result: RootParityInput.schemaFor(NESTED_RECURSIVE_PROOF_LENGTH),
result: schemaForPublicInputsAndRecursiveProof(ParityPublicInputs.schema, NESTED_RECURSIVE_PROOF_LENGTH),
}),
z.object({
type: z.literal(ProvingRequestType.TUBE_PROOF),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import {
type KernelCircuitPublicInputs,
type MergeRollupInputs,
type NESTED_RECURSIVE_PROOF_LENGTH,
type ParityPublicInputs,
type PrivateBaseRollupInputs,
type PrivateKernelEmptyInputData,
type PublicBaseRollupInputs,
type RECURSIVE_PROOF_LENGTH,
type RecursiveProof,
type RootParityInput,
type RootParityInputs,
type RootRollupInputs,
type RootRollupPublicInputs,
type TUBE_PROOF_LENGTH,
type TubeInputs,
} from '@aztec/circuits.js';

Expand All @@ -37,7 +37,7 @@ export interface ServerCircuitProver {
inputs: BaseParityInputs,
signal?: AbortSignal,
epochNumber?: number,
): Promise<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>>;
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>;

/**
* Creates a proof for the given input.
Expand All @@ -47,7 +47,7 @@ export interface ServerCircuitProver {
inputs: RootParityInputs,
signal?: AbortSignal,
epochNumber?: number,
): Promise<RootParityInput<typeof NESTED_RECURSIVE_PROOF_LENGTH>>;
): Promise<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof NESTED_RECURSIVE_PROOF_LENGTH>>;

/**
* Creates a proof for the given input.
Expand All @@ -73,7 +73,7 @@ export interface ServerCircuitProver {
tubeInput: TubeInputs,
signal?: AbortSignal,
epochNumber?: number,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof RECURSIVE_PROOF_LENGTH>>>;
): Promise<ProofAndVerificationKey<typeof TUBE_PROOF_LENGTH>>;

/**
* Creates a proof for the given input.
Expand Down Expand Up @@ -139,7 +139,7 @@ export interface ServerCircuitProver {
inputs: AvmCircuitInputs,
signal?: AbortSignal,
epochNumber?: number,
): Promise<ProofAndVerificationKey<RecursiveProof<typeof AVM_PROOF_LENGTH_IN_FIELDS>>>;
): Promise<ProofAndVerificationKey<typeof AVM_PROOF_LENGTH_IN_FIELDS>>;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Fr } from '@aztec/foundation/fields';
import { hexSchemaFor } from '@aztec/foundation/schemas';
import { BufferReader, serializeToBuffer } from '@aztec/foundation/serialize';
import { type FieldsOf } from '@aztec/foundation/types';

Expand Down Expand Up @@ -32,6 +33,11 @@ export class ParityPublicInputs {
return this.toBuffer().toString('hex');
}

/** Returns a hex representation for JSON serialization. */
toJSON() {
return this.toString();
}

/**
* Creates a new ParityPublicInputs instance from the given fields.
* @param fields - The fields to create the instance from.
Expand Down Expand Up @@ -68,4 +74,8 @@ export class ParityPublicInputs {
static fromString(str: string) {
return ParityPublicInputs.fromBuffer(Buffer.from(str, 'hex'));
}

static get schema() {
return hexSchemaFor(ParityPublicInputs);
}
}
Loading
Loading