diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dd4c5c94..0c0dc9b8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added - `ZkProgram` to support non-pure provable types as inputs and outputs https://github.com/o1-labs/o1js/pull/1828 +- API for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931 + - `let recursive = Experimental.Recursive(program);` + - `recursive.(...args): Promise` + - This also works within the same program, as long as the return value is type-annotated - Add `enforceTransactionLimits` parameter on Network https://github.com/o1-labs/o1js/issues/1910 - Method for optional types to assert none https://github.com/o1-labs/o1js/pull/1922 - Increased maximum supported amount of methods in a `SmartContract` or `ZkProgram` to 30. https://github.com/o1-labs/o1js/pull/1918 @@ -34,6 +38,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [2.1.0](https://github.com/o1-labs/o1js/compare/b04520d...e1bac02) - 2024-11-13 +### Added + - Support secp256r1 in elliptic curve and ECDSA gadgets https://github.com/o1-labs/o1js/pull/1885 ### Fixed diff --git a/run-ci-tests.sh b/run-ci-tests.sh index c4aa4a133..a75277d8b 100755 --- a/run-ci-tests.sh +++ b/run-ci-tests.sh @@ -9,6 +9,7 @@ case $TEST_TYPE in ./run src/examples/zkapps/reducer/reducer-composite.ts --bundle ./run src/examples/zkapps/composability.ts --bundle ./run src/tests/fake-proof.ts + ./run src/tests/inductive-proofs-internal.ts --bundle ./run tests/vk-regression/diverse-zk-program-run.ts --bundle ;; diff --git a/src/bindings b/src/bindings index a5a0883f0..e05efb999 160000 --- a/src/bindings +++ b/src/bindings @@ -1 +1 @@ -Subproject commit a5a0883f033afd4c9fd4bf96a5e0b8ebd92a87c9 +Subproject commit e05efb9999dc83107127d6758904e2512eb8582d diff --git a/src/index.ts b/src/index.ts index 1d27d3a86..a2dc54d68 100644 --- a/src/index.ts +++ b/src/index.ts @@ -148,6 +148,7 @@ import * as OffchainState_ from './lib/mina/actions/offchain-state.js'; import * as BatchReducer_ from './lib/mina/actions/batch-reducer.js'; import { Actionable } from './lib/mina/actions/offchain-state-serialization.js'; import { InferProvable } from './lib/provable/types/struct.js'; +import { Recursive as Recursive_ } from './lib/proof-system/recursive.js'; export { Experimental }; const Experimental_ = { @@ -162,6 +163,8 @@ const Experimental_ = { namespace Experimental { export let memoizeWitness = Experimental_.memoizeWitness; + export let Recursive = Recursive_; + // indexed merkle map export let IndexedMerkleMap = Experimental_.IndexedMerkleMap; export type IndexedMerkleMap = IndexedMerkleMapBase; diff --git a/src/lib/mina/account-update.ts b/src/lib/mina/account-update.ts index 8bfcd71d6..6f145e81c 100644 --- a/src/lib/mina/account-update.ts +++ b/src/lib/mina/account-update.ts @@ -655,7 +655,6 @@ type LazyProof = { kind: 'lazy-proof'; methodName: string; args: any[]; - previousProofs: Pickles.Proof[]; ZkappClass: typeof SmartContract; memoized: { fields: Field[]; aux: any[] }[]; blindingValue: Field; @@ -2116,14 +2115,7 @@ async function addProof( async function createZkappProof( prover: Pickles.Prover, - { - methodName, - args, - previousProofs, - ZkappClass, - memoized, - blindingValue, - }: LazyProof, + { methodName, args, ZkappClass, memoized, blindingValue }: LazyProof, { transaction, accountUpdate, index }: ZkappProverData ): Promise> { let publicInput = accountUpdate.toPublicInput(transaction); @@ -2141,7 +2133,7 @@ async function createZkappProof( blindingValue, }); try { - return await prover(publicInputFields, MlArray.to(previousProofs)); + return await prover(publicInputFields); } catch (err) { console.error(`Error when proving ${ZkappClass.name}.${methodName}()`); throw err; @@ -2151,7 +2143,7 @@ async function createZkappProof( } ); - let maxProofsVerified = ZkappClass._maxProofsVerified!; + let maxProofsVerified = await ZkappClass.getMaxProofsVerified(); const Proof = ZkappClass.Proof(); return new Proof({ publicInput, diff --git a/src/lib/mina/zkapp.ts b/src/lib/mina/zkapp.ts index 8720f3355..094b308a3 100644 --- a/src/lib/mina/zkapp.ts +++ b/src/lib/mina/zkapp.ts @@ -44,13 +44,13 @@ import { import { analyzeMethod, compileProgram, + computeMaxProofsVerified, Empty, - getPreviousProofsForProver, MethodInterface, sortMethodArguments, VerificationKey, } from '../proof-system/zkprogram.js'; -import { Proof } from '../proof-system/proof.js'; +import { Proof, ProofClass } from '../proof-system/proof.js'; import { PublicKey } from '../provable/crypto/signature.js'; import { InternalStateType, @@ -154,11 +154,6 @@ function method( // FIXME: overriding a method implies pushing a separate method entry here, yielding two entries with the same name // this should only be changed once we no longer share the _methods array with the parent class (otherwise a subclass declaration messes up the parent class) ZkappClass._methods.push(methodEntry); - ZkappClass._maxProofsVerified ??= 0; - ZkappClass._maxProofsVerified = Math.max( - ZkappClass._maxProofsVerified, - methodEntry.numberOfProofs - ) as 0 | 1 | 2; let func = descriptor.value as AsyncFunction; descriptor.value = wrapMethod(func, ZkappClass, internalMethodEntry); } @@ -341,8 +336,6 @@ function wrapMethod( { methodName: methodIntf.methodName, args: clonedArgs, - // proofs actually don't have to be cloned - previousProofs: getPreviousProofsForProver(actualArgs), ZkappClass, memoized, blindingValue, @@ -433,7 +426,6 @@ function wrapMethod( { methodName: methodIntf.methodName, args: constantArgs, - previousProofs: getPreviousProofsForProver(constantArgs), ZkappClass, memoized, blindingValue: constantBlindingValue, @@ -593,10 +585,10 @@ class SmartContract extends SmartContractBase { rows: number; digest: string; gates: Gate[]; + proofs: ProofClass[]; } >; // keyed by method name static _provers?: Pickles.Prover[]; - static _maxProofsVerified?: 0 | 1 | 2; static _verificationKey?: { data: string; hash: Field }; /** @@ -644,6 +636,7 @@ class SmartContract extends SmartContractBase { forceRecompile = false, } = {}) { let methodIntfs = this._methods ?? []; + let methodKeys = methodIntfs.map(({ methodName }) => methodName); let methods = methodIntfs.map(({ methodName }) => { return async ( publicInput: unknown, @@ -657,13 +650,15 @@ class SmartContract extends SmartContractBase { }); // run methods once to get information that we need already at compile time let methodsMeta = await this.analyzeMethods(); - let gates = methodIntfs.map((intf) => methodsMeta[intf.methodName].gates); + let gates = methodKeys.map((k) => methodsMeta[k].gates); + let proofs = methodKeys.map((k) => methodsMeta[k].proofs); let { verificationKey, provers, verify } = await compileProgram({ publicInputType: ZkappPublicInput, publicOutputType: Empty, methodIntfs, methods, gates, + proofs, proofSystemTag: this, cache, forceRecompile, @@ -689,6 +684,17 @@ class SmartContract extends SmartContractBase { return hash.toBigInt().toString(16); } + /** + * The maximum number of proofs that are verified by any of the zkApp methods. + * This is an internal parameter needed by the proof system. + */ + static async getMaxProofsVerified() { + let methodData = await this.analyzeMethods(); + return computeMaxProofsVerified( + Object.values(methodData).map((d) => d.proofs.length) + ); + } + /** * Deploys a {@link SmartContract}. * @@ -1189,7 +1195,7 @@ super.init(); try { for (let methodIntf of methodIntfs) { let accountUpdate: AccountUpdate; - let { rows, digest, gates, summary } = await analyzeMethod( + let { rows, digest, gates, summary, proofs } = await analyzeMethod( ZkappPublicInput, methodIntf, async (publicInput, publicKey, tokenId, ...args) => { @@ -1207,6 +1213,7 @@ super.init(); rows, digest, gates, + proofs, }; if (printSummary) console.log(methodIntf.methodName, summary()); } diff --git a/src/lib/proof-system/proof-system.unit-test.ts b/src/lib/proof-system/proof-system.unit-test.ts index c6b570f71..67dfca967 100644 --- a/src/lib/proof-system/proof-system.unit-test.ts +++ b/src/lib/proof-system/proof-system.unit-test.ts @@ -54,7 +54,6 @@ it('pickles rule creation', async () => { expect(methodIntf).toEqual({ methodName: 'main', args: [EmptyProof, Bool], - numberOfProofs: 1, }); // store compiled tag @@ -67,7 +66,8 @@ it('pickles rule creation', async () => { main as AnyFunction, { name: 'mock' }, methodIntf, - [] + [], + [EmptyProof] ); await equivalentAsync( @@ -133,7 +133,6 @@ it('pickles rule creation: nested proof', async () => { expect(methodIntf).toEqual({ methodName: 'main', args: [NestedProof2], - numberOfProofs: 2, }); // store compiled tag @@ -146,7 +145,8 @@ it('pickles rule creation: nested proof', async () => { main as AnyFunction, { name: 'mock' }, methodIntf, - [] + [], + [EmptyProof, EmptyProof] ); let dummy = await EmptyProof.dummy(Field(0), undefined, 0); diff --git a/src/lib/proof-system/proof.ts b/src/lib/proof-system/proof.ts index 345155cb7..c1d6af333 100644 --- a/src/lib/proof-system/proof.ts +++ b/src/lib/proof-system/proof.ts @@ -16,15 +16,18 @@ import type { Provable } from '../provable/provable.js'; import { assert } from '../util/assert.js'; import { Unconstrained } from '../provable/types/unconstrained.js'; import { ProvableType } from '../provable/types/provable-intf.js'; +import { ZkProgramContext } from './zkprogram-context.js'; // public API -export { ProofBase, Proof, DynamicProof }; +export { ProofBase, Proof, DynamicProof, ProofClass }; // internal API export { dummyProof, extractProofs, extractProofTypes, type ProofValue }; type MaxProofs = 0 | 1 | 2; +type ProofClass = Subclass; + class ProofBase { static publicInputType: FlexibleProvable = undefined as any; static publicOutputType: FlexibleProvable = undefined as any; @@ -40,6 +43,27 @@ class ProofBase { maxProofsVerified: 0 | 1 | 2; shouldVerify = Bool(false); + /** + * To verify a recursive proof inside a ZkProgram method, it has to be "declared" as part of + * the method. This is done by calling `declare()` on the proof. + * + * Note: `declare()` is a low-level method that most users will not have to call directly. + * For proofs that are inputs to the ZkProgram, it is done automatically. + * + * You can think of declaring a proof as a similar step as witnessing a variable, which introduces + * that variable to the circuit. Declaring a proof will tell Pickles to add the additional constraints + * for recursive proof verification. + * + * Similar to `Provable.witness()`, `declare()` is a no-op when run outside ZkProgram compilation or proving. + * It returns `false` in that case, and `true` if the proof was actually declared. + */ + declare() { + if (!ZkProgramContext.has()) return false; + const ProofClass = this.constructor as Subclass; + ZkProgramContext.declareProof({ ProofClass, proofInstance: this }); + return true; + } + toJSON(): JsonProof { let fields = this.publicFields(); return { diff --git a/src/lib/proof-system/recursive.ts b/src/lib/proof-system/recursive.ts new file mode 100644 index 000000000..f7f0a4a3f --- /dev/null +++ b/src/lib/proof-system/recursive.ts @@ -0,0 +1,140 @@ +import { InferProvable } from '../provable/types/struct.js'; +import { Provable } from '../provable/provable.js'; +import { ProvableType } from '../provable/types/provable-intf.js'; +import { Tuple } from '../util/types.js'; +import { Proof } from './proof.js'; +import { mapObject, mapToObject, zip } from '../util/arrays.js'; +import { Undefined, Void } from './zkprogram.js'; + +export { Recursive }; + +function Recursive< + PublicInputType extends Provable, + PublicOutputType extends Provable, + PrivateInputs extends { + [Key in string]: Tuple; + } +>( + zkprogram: { + name: string; + publicInputType: PublicInputType; + publicOutputType: PublicOutputType; + privateInputTypes: PrivateInputs; + rawMethods: { + [Key in keyof PrivateInputs]: ( + ...args: any + ) => Promise<{ publicOutput: InferProvable }>; + }; + } & { + [Key in keyof PrivateInputs]: (...args: any) => Promise<{ + proof: Proof< + InferProvable, + InferProvable + >; + }>; + } +): { + [Key in keyof PrivateInputs]: RecursiveProver< + InferProvable, + InferProvable, + PrivateInputs[Key] + >; +} { + type PublicInput = InferProvable; + type PublicOutput = InferProvable; + type MethodKey = keyof PrivateInputs; + + let { + publicInputType, + publicOutputType, + privateInputTypes: privateInputs, + rawMethods: methods, + } = zkprogram; + + let hasPublicInput = + publicInputType !== Undefined && publicInputType !== Void; + + class SelfProof extends Proof { + static publicInputType = publicInputType; + static publicOutputType = publicOutputType; + static tag = () => zkprogram; + } + + let methodKeys: MethodKey[] = Object.keys(methods); + + let regularRecursiveProvers = mapToObject(methodKeys, (key) => { + return async function proveRecursively_( + publicInput: PublicInput, + ...args: TupleToInstances + ) { + // create the base proof in a witness block + let proof = await Provable.witnessAsync(SelfProof, async () => { + // move method args to constants + let constInput = Provable.toConstant( + publicInputType, + publicInput + ); + let constArgs = zip(args, privateInputs[key]).map(([arg, type]) => + Provable.toConstant(type, arg) + ); + + let prover = zkprogram[key]; + + if (hasPublicInput) { + let { proof } = await prover(constInput, ...constArgs); + return proof; + } else { + let { proof } = await prover(...constArgs); + return proof; + } + }); + + // assert that the witnessed proof has the correct public input (which will be used by Pickles as part of verification) + if (hasPublicInput) { + Provable.assertEqual(publicInputType, proof.publicInput, publicInput); + } + + // declare and verify the proof, and return its public output + proof.declare(); + proof.verify(); + return proof.publicOutput; + }; + }); + + type RecursiveProver_ = RecursiveProver< + PublicInput, + PublicOutput, + PrivateInputs[K] + >; + type RecursiveProvers = { + [K in MethodKey]: RecursiveProver_; + }; + let proveRecursively: RecursiveProvers = mapToObject( + methodKeys, + (key: MethodKey) => { + if (!hasPublicInput) { + return ((...args: any) => + regularRecursiveProvers[key](undefined as any, ...args)) as any; + } else { + return regularRecursiveProvers[key] as any; + } + } + ); + + return proveRecursively; +} + +type RecursiveProver< + PublicInput, + PublicOutput, + Args extends Tuple +> = PublicInput extends undefined + ? (...args: TupleToInstances) => Promise + : ( + publicInput: PublicInput, + ...args: TupleToInstances + ) => Promise; + +type TupleToInstances = { + [I in keyof T]: InferProvable; +}; diff --git a/src/lib/proof-system/workers.ts b/src/lib/proof-system/workers.ts index 9076ad755..cff25a58b 100644 --- a/src/lib/proof-system/workers.ts +++ b/src/lib/proof-system/workers.ts @@ -1,4 +1,4 @@ -export { workers, setNumberOfWorkers }; +export { workers, setNumberOfWorkers, WithThreadPool }; const workers = { numWorkers: undefined as number | undefined, @@ -15,3 +15,68 @@ const workers = { const setNumberOfWorkers = (numWorkers: number) => { workers.numWorkers = numWorkers; }; + +type ThreadPoolState = + | { type: 'none' } + | { type: 'initializing'; initPromise: Promise } + | { type: 'running' } + | { type: 'exiting'; exitPromise: Promise }; + +function WithThreadPool({ + initThreadPool, + exitThreadPool, +}: { + initThreadPool: () => Promise; + exitThreadPool: () => Promise; +}) { + // state machine to enable calling multiple functions that need a thread pool at once + let state: ThreadPoolState = { type: 'none' }; + let isNeededBy = 0; + + return async function withThreadPool(run: () => Promise): Promise { + isNeededBy++; + // none, exiting -> initializing + switch (state.type) { + case 'none': { + let initPromise = initThreadPool(); + state = { type: 'initializing', initPromise }; + break; + } + case 'initializing': + case 'running': + break; + case 'exiting': { + let initPromise = state.exitPromise.then(initThreadPool); + state = { type: 'initializing', initPromise }; + break; + } + } + // initializing -> running + if (state.type === 'initializing') await state.initPromise; + state = { type: 'running' }; + + let result: T; + try { + result = await run(); + } finally { + // running -> exiting IF we don't need to run longer + isNeededBy--; + + if (state.type !== 'running') { + throw Error('bug in ThreadPool state machine'); + } + + if (isNeededBy < 1) { + let exitPromise = exitThreadPool(); + state = { type: 'exiting', exitPromise }; + + // exiting -> none IF we didn't move exiting -> initializing + await exitPromise; + if (state.type === 'exiting') { + state = { type: 'none' }; + } + } + } + return result; + }; +} diff --git a/src/lib/proof-system/zkprogram-context.ts b/src/lib/proof-system/zkprogram-context.ts new file mode 100644 index 000000000..e67787695 --- /dev/null +++ b/src/lib/proof-system/zkprogram-context.ts @@ -0,0 +1,29 @@ +import { Context } from '../util/global-context.js'; +import type { Subclass } from '../util/types.js'; +import type { ProofBase } from './proof.js'; + +export { ZkProgramContext, DeclaredProof }; + +type DeclaredProof = { + ProofClass: Subclass>; + proofInstance: ProofBase; +}; +type ZkProgramContext = { + proofs: DeclaredProof[]; +}; +let context = Context.create(); + +const ZkProgramContext = { + enter() { + return context.enter({ proofs: [] }); + }, + leave: context.leave, + has: context.has, + + declareProof(proof: DeclaredProof) { + context.get().proofs.push(proof); + }, + getDeclaredProofs() { + return context.get().proofs; + }, +}; diff --git a/src/lib/proof-system/zkprogram.ts b/src/lib/proof-system/zkprogram.ts index 29cd5d663..9629f56cb 100644 --- a/src/lib/proof-system/zkprogram.ts +++ b/src/lib/proof-system/zkprogram.ts @@ -45,6 +45,7 @@ import { extractProofTypes, Proof, ProofBase, + ProofClass, ProofValue, } from './proof.js'; import { @@ -53,6 +54,8 @@ import { } from './feature-flags.js'; import { emptyWitness } from '../provable/types/util.js'; import { InferValue } from '../../bindings/lib/provable-generic.js'; +import { DeclaredProof, ZkProgramContext } from './zkprogram-context.js'; +import { mapObject, mapToObject, zip } from '../util/arrays.js'; // public API export { @@ -70,13 +73,16 @@ export { export { CompiledTag, sortMethodArguments, - getPreviousProofsForProver, MethodInterface, picklesRuleFromFunction, compileProgram, analyzeMethod, Prover, dummyBase64Proof, + computeMaxProofsVerified, + RegularProver, + TupleToInstances, + PrivateInput, }; type Undefined = undefined; @@ -199,9 +205,9 @@ function ZkProgram< // derived types for convenience MethodSignatures extends Config['methods'] = Config['methods'], PrivateInputs extends { - [I in keyof MethodSignatures]: MethodSignatures[I]['privateInputs']; + [I in keyof Config['methods']]: Config['methods'][I]['privateInputs']; } = { - [I in keyof MethodSignatures]: MethodSignatures[I]['privateInputs']; + [I in keyof Config['methods']]: Config['methods'][I]['privateInputs']; }, AuxiliaryOutputs extends { [I in keyof MethodSignatures]: Get; @@ -260,6 +266,8 @@ function ZkProgram< let publicInputType: Provable = ProvableType.get( config.publicInput ?? Undefined ); + let hasPublicInput = + publicInputType !== Undefined && publicInputType !== Void; let publicOutputType: Provable = ProvableType.get( config.publicOutput ?? Void ); @@ -274,19 +282,28 @@ function ZkProgram< static tag = () => selfTag; } + type MethodKey = keyof Config['methods']; // TODO remove sort()! Object.keys() has a deterministic order - let methodKeys: (keyof Methods & string)[] = Object.keys(methods).sort(); // need to have methods in (any) fixed order + let methodKeys: MethodKey[] = Object.keys(methods).sort(); // need to have methods in (any) fixed order let methodIntfs = methodKeys.map((key) => sortMethodArguments( 'program', - key, + key as string, methods[key].privateInputs, ProvableType.get(methods[key].auxiliaryOutput) ?? Undefined, SelfProof ) ); let methodFunctions = methodKeys.map((key) => methods[key].method); - let maxProofsVerified = getMaxProofsVerified(methodIntfs); + let maxProofsVerified: undefined | 0 | 1 | 2 = undefined; + + async function getMaxProofsVerified() { + if (maxProofsVerified !== undefined) return maxProofsVerified; + let methodsMeta = await analyzeMethods(); + let proofs = methodKeys.map((k) => methodsMeta[k].proofs.length); + maxProofsVerified = computeMaxProofsVerified(proofs); + return maxProofsVerified; + } async function analyzeMethods() { let methodsMeta: Record< @@ -309,6 +326,7 @@ function ZkProgram< let compileOutput: | { provers: Pickles.Prover[]; + maxProofsVerified: 0 | 1 | 2; verify: ( statement: Pickles.Statement, proof: Pickles.Proof @@ -321,13 +339,15 @@ function ZkProgram< async function compile({ cache = Cache.FileSystemDefault, forceRecompile = false, - proofsEnabled = undefined, + proofsEnabled = undefined as boolean | undefined, } = {}) { doProving = proofsEnabled ?? doProving; if (doProving) { let methodsMeta = await analyzeMethods(); let gates = methodKeys.map((k) => methodsMeta[k].gates); + let proofs = methodKeys.map((k) => methodsMeta[k].proofs); + maxProofsVerified = computeMaxProofsVerified(proofs.map((p) => p.length)); let { provers, verify, verificationKey } = await compileProgram({ publicInputType, @@ -335,6 +355,7 @@ function ZkProgram< methodIntfs, methods: methodFunctions, gates, + proofs, proofSystemTag: selfTag, cache, forceRecompile, @@ -342,7 +363,7 @@ function ZkProgram< state: programState, }); - compileOutput = { provers, verify }; + compileOutput = { provers, verify, maxProofsVerified }; return { verificationKey }; } else { return { @@ -351,61 +372,59 @@ function ZkProgram< } } - function toProver( + // for each of the methods, create a prover function. + // in the first step, these are "regular" in that they always expect the public input as the first argument, + // which is easier to use internally. + type RegularProver_ = RegularProver< + PublicInput, + PublicOutput, + PrivateInputs[K], + InferProvableOrUndefined + >; + + function toRegularProver( key: K, i: number - ): [ - K, - Prover< - PublicInput, - PublicOutput, - PrivateInputs[K], - InferProvableOrUndefined - > - ] { - async function prove_( - publicInput: PublicInput, - ...args: TupleToInstances - ): Promise<{ - proof: Proof; - auxiliaryOutput: any; - }> { - class ProgramProof extends Proof { - static publicInputType = publicInputType; - static publicOutputType = publicOutputType; - static tag = () => selfTag; - } - + ): RegularProver_ { + return async function prove_(publicInput, ...args) { if (!doProving) { - let previousProofs = MlArray.to(getPreviousProofsForProver(args)); - - let { publicOutput, auxiliaryOutput } = - (await (methods[key].method as any)(publicInput, previousProofs)) ?? - {}; - - let proof = await ProgramProof.dummy( - publicInput, - publicOutput, - maxProofsVerified - ); - return { proof, auxiliaryOutput }; + // we step into a ZkProgramContext here to match the context nesting + // that would happen if proofs were enabled -- otherwise, proofs declared + // in an inner program could be counted to the outer program + let id = ZkProgramContext.enter(); + try { + let { publicOutput, auxiliaryOutput } = + (hasPublicInput + ? await (methods[key].method as any)(publicInput, ...args) + : await (methods[key].method as any)(...args)) ?? {}; + + let proof = await SelfProof.dummy( + publicInput, + publicOutput, + await getMaxProofsVerified() + ); + return { proof, auxiliaryOutput }; + } finally { + ZkProgramContext.leave(id); + } } - let picklesProver = compileOutput?.provers?.[i]; - if (picklesProver === undefined) { + if (compileOutput === undefined) { throw Error( - `Cannot prove execution of program.${key}(), no prover found. ` + + `Cannot prove execution of program.${String( + key + )}(), no prover found. ` + `Try calling \`await program.compile()\` first, this will cache provers in the background.\nIf you compiled your zkProgram with proofs disabled (\`proofsEnabled = false\`), you have to compile it with proofs enabled first.` ); } + let picklesProver = compileOutput.provers[i]; + let maxProofsVerified = compileOutput.maxProofsVerified; let { publicInputFields, publicInputAux } = toFieldAndAuxConsts( publicInputType, publicInput ); - let previousProofs = MlArray.to(getPreviousProofsForProver(args)); - let id = snarkContext.enter({ witnesses: args, inProver: true, @@ -414,7 +433,7 @@ function ZkProgram< let result: UnwrapPromise>; try { - result = await picklesProver(publicInputFields, previousProofs); + result = await picklesProver(publicInputFields); } finally { snarkContext.leave(id); } @@ -445,7 +464,7 @@ function ZkProgram< programState.reset('__nonPureOutput__'); return { - proof: new ProgramProof({ + proof: new SelfProof({ publicInput, publicOutput, proof, @@ -453,33 +472,28 @@ function ZkProgram< }), auxiliaryOutput, }; - } - - let prove: Prover< - PublicInput, - PublicOutput, - PrivateInputs[K], - InferProvableOrUndefined - >; - if ( - (publicInputType as any) === Undefined || - (publicInputType as any) === Void - ) { - prove = ((...args: any) => prove_(undefined as any, ...args)) as any; - } else { - prove = prove_ as any; - } - return [key, prove]; + }; } - - let provers = Object.fromEntries(methodKeys.map(toProver)) as { - [I in keyof Config['methods']]: Prover< - PublicInput, - PublicOutput, - PrivateInputs[I], - InferProvableOrUndefined - >; + let regularProvers = mapToObject(methodKeys, toRegularProver); + + // wrap "regular" provers to remove an `undefined` public input argument, + // this matches how the method itself was defined in the case of no public input + type Prover_ = Prover< + PublicInput, + PublicOutput, + PrivateInputs[K], + InferProvableOrUndefined + >; + type Provers = { + [K in MethodKey]: Prover_; }; + let provers: Provers = mapObject(regularProvers, (prover): Prover_ => { + if (publicInputType === Undefined || publicInputType === Void) { + return ((...args: any) => prover(undefined as any, ...args)) as any; + } else { + return prover as any; + } + }); function verify(proof: Proof) { if (!doProving) { @@ -527,6 +541,7 @@ function ZkProgram< rawMethods: Object.fromEntries( methodKeys.map((key) => [key, methods[key].method]) ) as any, + proofsEnabled: doProving, setProofsEnabled(proofsEnabled: boolean) { doProving = proofsEnabled; }, @@ -539,7 +554,7 @@ function ZkProgram< get: () => doProving, }); - return program as any; + return program; } type ZkProgram< @@ -607,7 +622,8 @@ function sortMethodArguments( ); }); - // extract proofs to count them and for sanity checks + // extract input proofs to count them and for sanity checks + // WARNING: this doesn't include internally declared proofs! let proofs = args.flatMap(extractProofTypes); let numberOfProofs = proofs.length; @@ -628,7 +644,7 @@ function sortMethodArguments( `Suggestion: You can merge more than two proofs by merging two at a time in a binary tree.` ); } - return { methodName, args, numberOfProofs, auxiliaryType }; + return { methodName, args, auxiliaryType }; } function isProvable(type: unknown): type is ProvableType { @@ -648,14 +664,9 @@ function isDynamicProof( return typeof type === 'function' && type.prototype instanceof DynamicProof; } -function getPreviousProofsForProver(methodArgs: any[]) { - return methodArgs.flatMap(extractProofs).map((proof) => proof.proof); -} - type MethodInterface = { methodName: string; args: ProvableType[]; - numberOfProofs: number; returnType?: Provable; auxiliaryType?: Provable; }; @@ -669,6 +680,7 @@ async function compileProgram({ methodIntfs, methods, gates, + proofs, proofSystemTag, cache, forceRecompile, @@ -680,6 +692,7 @@ async function compileProgram({ methodIntfs: MethodInterface[]; methods: ((...args: any) => unknown)[]; gates: Gate[][]; + proofs: ProofClass[][]; proofSystemTag: { name: string }; cache: Cache; forceRecompile: boolean; @@ -700,12 +713,13 @@ If you are using a SmartContract, make sure you are using the @method decorator. proofSystemTag, methodEntry, gates[i], + proofs[i], state ) ); - let maxProofs = getMaxProofsVerified(methodIntfs); - overrideWrapDomain ??= maxProofsToWrapDomain[maxProofs]; + let maxProofs = computeMaxProofsVerified(proofs.map((p) => p.length)); + overrideWrapDomain ??= maxProofsToWrapDomain[maxProofs]; let picklesCache: Pickles.Cache = [ 0, function read_(mlHeader) { @@ -761,12 +775,9 @@ If you are using a SmartContract, make sure you are using the @method decorator. // wrap provers let wrappedProvers = provers.map( (prover): Pickles.Prover => - async function picklesProver( - publicInput: MlFieldConstArray, - previousProofs: MlArray - ) { + async function picklesProver(publicInput: MlFieldConstArray) { return prettifyStacktracePromise( - withThreadPool(() => prover(publicInput, previousProofs)) + withThreadPool(() => prover(publicInput)) ); } ); @@ -787,19 +798,34 @@ If you are using a SmartContract, make sure you are using the @method decorator. }; } -function analyzeMethod( +async function analyzeMethod( publicInputType: Provable, methodIntf: MethodInterface, method: (...args: any) => unknown ) { - return Provable.constraintSystem(() => { - let args = methodIntf.args.map(emptyWitness); - let publicInput = emptyWitness(publicInputType); - // note: returning the method result here makes this handle async methods - if (publicInputType === Undefined || publicInputType === Void) - return method(...args); - return method(publicInput, ...args); - }); + let result: Awaited>; + let proofs: ProofClass[]; + let id = ZkProgramContext.enter(); + try { + result = await Provable.constraintSystem(() => { + let args = methodIntf.args.map(emptyWitness); + args.forEach((value) => + extractProofs(value).forEach((proof) => proof.declare()) + ); + + let publicInput = emptyWitness(publicInputType); + // note: returning the method result here makes this handle async methods + if (publicInputType === Undefined || publicInputType === Void) + return method(...args); + return method(publicInput, ...args); + }); + proofs = ZkProgramContext.getDeclaredProofs().map( + ({ ProofClass }) => ProofClass + ); + } finally { + ZkProgramContext.leave(id); + } + return { ...result, proofs }; } function inCircuitVkHash(inCircuitVk: unknown): Field { @@ -822,6 +848,7 @@ function picklesRuleFromFunction( proofSystemTag: { name: string }, { methodName, args, auxiliaryType }: MethodInterface, gates: Gate[], + verifiedProofs: ProofClass[], state?: ReturnType ): Pickles.Rule { async function main( @@ -833,42 +860,40 @@ function picklesRuleFromFunction( auxInputData, } = snarkContext.get(); assert(!(inProver && argsWithoutPublicInput === undefined)); + + // witness private inputs and declare input proofs + let id = ZkProgramContext.enter(); let finalArgs = []; - let proofs: { - Proof: Subclass>; - proof: ProofBase; - }[] = []; - let previousStatements: Pickles.Statement[] = []; for (let i = 0; i < args.length; i++) { - let type = args[i]; try { + let type = args[i]; let value = Provable.witness(type, () => { return argsWithoutPublicInput?.[i] ?? ProvableType.synthesize(type); }); finalArgs[i] = value; - for (let proof of extractProofs(value)) { - let Proof = proof.constructor as Subclass>; - proofs.push({ Proof, proof }); - let fields = proof.publicFields(); - let input = MlFieldArray.to(fields.input); - let output = MlFieldArray.to(fields.output); - previousStatements.push(MlPair(input, output)); - } + extractProofs(value).forEach((proof) => proof.declare()); } catch (e: any) { + ZkProgramContext.leave(id); e.message = `Error when witnessing in ${methodName}, argument ${i}: ${e.message}`; throw e; } } - let result: { - publicOutput?: any; - auxiliaryOutput?: any; - }; - if (publicInputType === Undefined || publicInputType === Void) { - result = (await func(...finalArgs)) as any; - } else { - let input = fromFieldVars(publicInputType, publicInput, auxInputData); - result = (await func(input, ...finalArgs)) as any; + + // run the user circuit + let result: { publicOutput?: any; auxiliaryOutput?: any }; + let proofs: DeclaredProof[]; + + try { + if (publicInputType === Undefined || publicInputType === Void) { + result = (await func(...finalArgs)) as any; + } else { + let input = fromFieldVars(publicInputType, publicInput, auxInputData); + result = (await func(input, ...finalArgs)) as any; + } + proofs = ZkProgramContext.getDeclaredProofs(); + } finally { + ZkProgramContext.leave(id); } if (result?.publicOutput) { @@ -877,13 +902,30 @@ function picklesRuleFromFunction( state?.setNonPureOutput(nonPureOutput); } - proofs.forEach(({ Proof, proof }) => { - if (!(proof instanceof DynamicProof)) return; + // now all proofs are declared - check that we got as many as during compile time + assert( + proofs.length === verifiedProofs.length, + `Expected ${verifiedProofs.length} proofs, but got ${proofs.length}` + ); + + // extract proof statements for Pickles + let previousStatements = proofs.map( + ({ proofInstance }): Pickles.Statement => { + let fields = proofInstance.publicFields(); + let input = MlFieldArray.to(fields.input); + let output = MlFieldArray.to(fields.output); + return MlPair(input, output); + } + ); + + // handle dynamic proofs + proofs.forEach(({ ProofClass, proofInstance }) => { + if (!(proofInstance instanceof DynamicProof)) return; // Initialize side-loaded verification key - const tag = Proof.tag(); + const tag = ProofClass.tag(); const computedTag = SideloadedTag.get(tag.name); - const vk = proof.usedVerificationKey; + const vk = proofInstance.usedVerificationKey; if (vk === undefined) { throw new Error( @@ -932,20 +974,20 @@ function picklesRuleFromFunction( return { publicOutput: MlFieldArray.to(publicOutput), previousStatements: MlArray.to(previousStatements), + previousProofs: MlArray.to(proofs.map((p) => p.proofInstance.proof)), shouldVerify: MlArray.to( - proofs.map((proof) => proof.proof.shouldVerify.toField().value) + proofs.map((proof) => proof.proofInstance.shouldVerify.toField().value) ), }; } - let proofs: Subclass[] = args.flatMap(extractProofTypes); - if (proofs.length > 2) { + if (verifiedProofs.length > 2) { throw Error( `${proofSystemTag.name}.${methodName}() has more than two proof arguments, which is not supported.\n` + `Suggestion: You can merge more than two proofs by merging two at a time in a binary tree.` ); } - let proofsToVerify = proofs.map((Proof) => { + let proofsToVerify = verifiedProofs.map((Proof) => { let tag = Proof.tag(); if (tag === proofSystemTag) return { isSelf: true as const }; else if (isDynamicProof(Proof)) { @@ -986,11 +1028,11 @@ function picklesRuleFromFunction( }; } -function getMaxProofsVerified(methodIntfs: MethodInterface[]) { - return methodIntfs.reduce( - (acc, { numberOfProofs }) => Math.max(acc, numberOfProofs), - 0 - ) as any as 0 | 1 | 2; +function computeMaxProofsVerified(proofs: number[]) { + return proofs.reduce((acc: number, n) => { + assert(n <= 2, 'Too many proofs'); + return Math.max(acc, n); + }, 0) as 0 | 1 | 2; } function fromFieldVars( @@ -1134,6 +1176,19 @@ type Method< >; }; +type RegularProver< + PublicInput, + PublicOutput, + Args extends Tuple, + AuxiliaryOutput +> = ( + publicInput: PublicInput, + ...args: TupleToInstances +) => Promise<{ + proof: Proof; + auxiliaryOutput: AuxiliaryOutput; +}>; + type Prover< PublicInput, PublicOutput, diff --git a/src/lib/util/arrays.ts b/src/lib/util/arrays.ts index b8d11660f..b62440ab2 100644 --- a/src/lib/util/arrays.ts +++ b/src/lib/util/arrays.ts @@ -1,6 +1,6 @@ import { assert } from './errors.js'; -export { chunk, chunkString, zip, pad }; +export { chunk, chunkString, zip, pad, mapObject, mapToObject }; function chunk(array: T[], size: number): T[][] { assert( @@ -31,3 +31,24 @@ function pad(array: T[], size: number, value: T): T[] { ); return array.concat(Array.from({ length: size - array.length }, () => value)); } + +function mapObject< + T extends Record, + F extends (value: T[K], key: K) => any +>(t: T, fn: F) { + let s = {} as { [K in keyof T]: ReturnType }; + for (let key in t) { + s[key] = fn(t[key], key); + } + return s; +} +function mapToObject< + Key extends string | number | symbol, + F extends (key: K, i: number) => any +>(keys: Key[], fn: F) { + let s = {} as { [K in Key]: ReturnType }; + keys.forEach((key, i) => { + s[key] = fn(key, i); + }); + return s; +} diff --git a/src/snarky.d.ts b/src/snarky.d.ts index 9b91a2de6..feac69d9d 100644 --- a/src/snarky.d.ts +++ b/src/snarky.d.ts @@ -628,6 +628,7 @@ declare namespace Pickles { main: (publicInput: MlArray) => Promise<{ publicOutput: MlArray; previousStatements: MlArray>; + previousProofs: MlArray; shouldVerify: MlArray; }>; /** @@ -655,8 +656,7 @@ declare namespace Pickles { ]; type Prover = ( - publicInput: MlArray, - previousProofs: MlArray + publicInput: MlArray ) => Promise<[_: 0, publicOutput: MlArray, proof: Proof]>; } diff --git a/src/tests/fake-proof.ts b/src/tests/fake-proof.ts index 46f76121a..55f0433b7 100644 --- a/src/tests/fake-proof.ts +++ b/src/tests/fake-proof.ts @@ -9,17 +9,21 @@ import { Struct, Field, Proof, + Unconstrained, + Provable, } from 'o1js'; import assert from 'assert'; const RealProgram = ZkProgram({ name: 'real', + publicOutput: UInt64, methods: { make: { privateInputs: [UInt64], async method(value: UInt64) { let expected = UInt64.from(34); value.assertEquals(expected); + return { publicOutput: value.add(1) }; }, }, }, @@ -27,13 +31,19 @@ const RealProgram = ZkProgram({ const FakeProgram = ZkProgram({ name: 'fake', + publicOutput: UInt64, methods: { - make: { privateInputs: [UInt64], async method(_: UInt64) {} }, + make: { + privateInputs: [UInt64], + async method(_: UInt64) { + return { publicOutput: UInt64.zero }; + }, + }, }, }); class RealProof extends ZkProgram.Proof(RealProgram) {} -const Nested = Struct({ inner: RealProof }); +class Nested extends Struct({ inner: RealProof }) {} const RecursiveProgram = ZkProgram({ name: 'recursive', @@ -46,11 +56,29 @@ const RecursiveProgram = ZkProgram({ }, verifyNested: { privateInputs: [Field, Nested], - async method(_unrelated, { inner }) { - inner satisfies Proof; + async method(_unrelated, { inner }: Nested) { + inner satisfies Proof; inner.verify(); }, }, + verifyInternal: { + privateInputs: [Unconstrained | undefined>], + async method( + fakeProof: Unconstrained | undefined> + ) { + // witness either fake proof from input, or real proof + let proof = await Provable.witnessAsync(RealProof, async () => { + let maybeFakeProof = fakeProof.get(); + if (maybeFakeProof !== undefined) return maybeFakeProof; + + let { proof } = await RealProgram.make(UInt64.from(34)); + return proof; + }); + + proof.declare(); + proof.verify(); + }, + }, }, }); @@ -71,7 +99,7 @@ let { verificationKey: programVk } = await RecursiveProgram.compile(); // proof that should be rejected const { proof: fakeProof } = await FakeProgram.make(UInt64.from(99999)); -const dummyProof = await RealProof.dummy(undefined, undefined, 0); +const dummyProof = await RealProof.dummy(undefined, UInt64.zero, 0); for (let proof of [fakeProof, dummyProof]) { // zkprogram rejects proof @@ -115,11 +143,10 @@ for (let proof of [fakeProof, dummyProof]) { }, 'recursive program rejects fake proof (nested)'); } +// zkprogram accepts proof (nested) const { proof: recursiveProofNested } = await RecursiveProgram.verifyNested( Field(0), - { - inner: realProof, - } + { inner: realProof } ); assert( await verify(recursiveProofNested, programVk), @@ -127,3 +154,23 @@ assert( ); console.log('fake proof test passed for nested proofs 🎉'); + +// same test for internal proofs + +for (let proof of [fakeProof, dummyProof]) { + // zkprogram rejects proof (internal) + await assert.rejects(async () => { + await RecursiveProgram.verifyInternal(Unconstrained.from(proof)); + }, 'recursive program rejects fake proof (internal)'); +} + +// zkprogram accepts proof (internal) +const { proof: internalProof } = await RecursiveProgram.verifyInternal( + Unconstrained.from(undefined) +); +assert( + await verify(internalProof, programVk), + 'recursive program accepts internal proof' +); + +console.log('fake proof test passed for internal proofs 🎉'); diff --git a/src/tests/inductive-proofs-internal.ts b/src/tests/inductive-proofs-internal.ts new file mode 100644 index 000000000..0b1941f01 --- /dev/null +++ b/src/tests/inductive-proofs-internal.ts @@ -0,0 +1,92 @@ +import { Field, ZkProgram, assert, Provable, Proof, Experimental } from 'o1js'; +import { tic, toc } from '../examples/utils/tic-toc.js'; + +let log: string[] = []; + +function pushLog(s: string) { + Provable.asProver(() => { + console.log(s); + log.push(s); + }); +} + +let mergeProgram = ZkProgram({ + name: 'recursive-2', + publicOutput: Field, + + methods: { + baseCase: { + privateInputs: [Field], + + async method(x: Field) { + pushLog('baseCase'); + x = x.add(7); + return { publicOutput: x }; + }, + }, + + mergeOne: { + privateInputs: [], + + async method() { + pushLog('mergeOne'); + let z = Provable.witness(Field, () => 0); + let x: Field = await mergeProgramRecursive.baseCase(z); + return { publicOutput: x.add(1) }; + }, + }, + + mergeTwo: { + privateInputs: [], + + async method() { + pushLog('mergeTwo'); + let z = Provable.witness(Field, () => 0); + let x: Field = await mergeProgramRecursive.baseCase(z); + let y: Field = await mergeProgramRecursive.mergeOne(); + return { publicOutput: x.add(y) }; + }, + }, + }, +}); +let mergeProgramRecursive = Experimental.Recursive(mergeProgram); + +let Wrapper = ZkProgram({ + name: 'wraps-recursive-2', + + methods: { + wrap: { + privateInputs: [ZkProgram.Proof(mergeProgram)], + + async method(proof: Proof) { + proof.verify(); + let x = proof.publicOutput; + x.assertLessThan(30); + }, + }, + }, +}); + +tic('compiling'); +await mergeProgram.compile(); +await Wrapper.compile(); +toc(); + +tic('executing 4 proofs'); +let { proof } = await mergeProgram.mergeTwo(); +toc(); + +assert(await mergeProgram.verify(proof), 'Proof is not valid'); + +proof.publicOutput.assertEquals(15); + +assert(log.length === 4, 'log.length === 4'); +assert(log[0] === 'mergeTwo', 'log[0] === "mergeTwo"'); +assert(log[1] === 'baseCase', 'log[1] === "baseCase"'); +assert(log[2] === 'mergeOne', 'log[2] === "mergeOne"'); +assert(log[3] === 'baseCase', 'log[3] === "baseCase"'); + +tic('execute wrapper proof'); +let { proof: wrapperProof } = await Wrapper.wrap(proof); +toc(); +assert(await Wrapper.verify(wrapperProof), 'Wrapper proof is not valid');