Skip to content

Commit

Permalink
Merge pull request #1932 from o1-labs/feature/conditional-recursive-p…
Browse files Browse the repository at this point in the history
…roving

Conditional recursion from within ZkProgram
  • Loading branch information
mitschabaude authored Jan 8, 2025
2 parents f7d522e + dec0558 commit a5c15ad
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 22 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
### Added

- `ZkProgram` to support non-pure provable types as inputs and outputs https://github.com/o1-labs/o1js/pull/1828
- API for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931
- APIs for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931 https://github.com/o1-labs/o1js/pull/1932
- `let recursive = Experimental.Recursive(program);`
- `recursive.<methodName>(...args): Promise<PublicOutput>`
- `recursive.<methodName>.if(condition, ...args): Promise<PublicOutput>`
- This also works within the same program, as long as the return value is type-annotated
- Add `enforceTransactionLimits` parameter on Network https://github.com/o1-labs/o1js/issues/1910
- Method for optional types to assert none https://github.com/o1-labs/o1js/pull/1922
Expand Down
72 changes: 72 additions & 0 deletions src/examples/zkprogram/hash-chain.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/**
* This shows how to prove an arbitrarily long chain of hashes using ZkProgram, i.e.
* `hash^n(x) = y`.
*
* We implement this as a self-recursive ZkProgram, using `proveRecursivelyIf()`
*/
import {
assert,
Bool,
Experimental,
Field,
Poseidon,
Provable,
Struct,
ZkProgram,
} from 'o1js';

const HASHES_PER_PROOF = 30;

class HashChainSpec extends Struct({ x: Field, n: Field }) {}

const hashChain = ZkProgram({
name: 'hash-chain',
publicInput: HashChainSpec,
publicOutput: Field,

methods: {
chain: {
privateInputs: [],

async method({ x, n }: HashChainSpec) {
Provable.log('hashChain (start method)', n);
let y = x;
let k = Field(0);
let reachedN = Bool(false);

for (let i = 0; i < HASHES_PER_PROOF; i++) {
reachedN = k.equals(n);
y = Provable.if(reachedN, y, Poseidon.hash([y]));
k = Provable.if(reachedN, n, k.add(1));
}

// we have y = hash^k(x)
// now do z = hash^(n-k)(y) = hash^n(x) by calling this method recursively
// except if we have k = n, then ignore the output and use y
let z: Field = await hashChainRecursive.chain.if(reachedN.not(), {
x: y,
n: n.sub(k),
});
z = Provable.if(reachedN, y, z);
Provable.log('hashChain (start proving)', n);
return { publicOutput: z };
},
},
},
});
let hashChainRecursive = Experimental.Recursive(hashChain);

await hashChain.compile();

let n = 100;
let x = Field.random();

let { proof } = await hashChain.chain({ x, n: Field(n) });

assert(await hashChain.verify(proof), 'Proof invalid');

// check that the output is correct
let z = Array.from({ length: n }, () => 0).reduce((y) => Poseidon.hash([y]), x);
proof.publicOutput.assertEquals(z, 'Output is incorrect');

console.log('Finished hash chain proof');
97 changes: 78 additions & 19 deletions src/lib/proof-system/recursive.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { Tuple } from '../util/types.js';
import { Proof } from './proof.js';
import { mapObject, mapToObject, zip } from '../util/arrays.js';
import { Undefined, Void } from './zkprogram.js';
import { Bool } from '../provable/bool.js';

export { Recursive };

Expand All @@ -25,6 +26,7 @@ function Recursive<
...args: any
) => Promise<{ publicOutput: InferProvable<PublicOutputType> }>;
};
maxProofsVerified: () => Promise<0 | 1 | 2>;
} & {
[Key in keyof PrivateInputs]: (...args: any) => Promise<{
proof: Proof<
Expand All @@ -38,7 +40,13 @@ function Recursive<
InferProvable<PublicInputType>,
InferProvable<PublicOutputType>,
PrivateInputs[Key]
>;
> & {
if: ConditionalRecursiveProver<
InferProvable<PublicInputType>,
InferProvable<PublicOutputType>,
PrivateInputs[Key]
>;
};
} {
type PublicInput = InferProvable<PublicInputType>;
type PublicOutput = InferProvable<PublicOutputType>;
Expand All @@ -64,9 +72,15 @@ function Recursive<

let regularRecursiveProvers = mapToObject(methodKeys, (key) => {
return async function proveRecursively_(
conditionAndConfig: Bool | { condition: Bool; domainLog2?: number },
publicInput: PublicInput,
...args: TupleToInstances<PrivateInputs[MethodKey]>
) {
): Promise<PublicOutput> {
let condition =
conditionAndConfig instanceof Bool
? conditionAndConfig
: conditionAndConfig.condition;

// create the base proof in a witness block
let proof = await Provable.witnessAsync(SelfProof, async () => {
// move method args to constants
Expand All @@ -78,6 +92,20 @@ function Recursive<
Provable.toConstant(type, arg)
);

if (!condition.toBoolean()) {
let publicOutput: PublicOutput =
ProvableType.synthesize(publicOutputType);
let maxProofsVerified = await zkprogram.maxProofsVerified();
return SelfProof.dummy(
publicInput,
publicOutput,
maxProofsVerified,
conditionAndConfig instanceof Bool
? undefined
: conditionAndConfig.domainLog2
);
}

let prover = zkprogram[key];

if (hasPublicInput) {
Expand All @@ -96,32 +124,48 @@ function Recursive<

// declare and verify the proof, and return its public output
proof.declare();
proof.verify();
proof.verifyIf(condition);
return proof.publicOutput;
};
});

type RecursiveProver_<K extends MethodKey> = RecursiveProver<
PublicInput,
PublicOutput,
PrivateInputs[K]
>;
type RecursiveProvers = {
[K in MethodKey]: RecursiveProver_<K>;
};
let proveRecursively: RecursiveProvers = mapToObject(
methodKeys,
(key: MethodKey) => {
return mapObject(
regularRecursiveProvers,
(
prover
): RecursiveProver<PublicInput, PublicOutput, PrivateInputs[MethodKey]> & {
if: ConditionalRecursiveProver<
PublicInput,
PublicOutput,
PrivateInputs[MethodKey]
>;
} => {
if (!hasPublicInput) {
return ((...args: any) =>
regularRecursiveProvers[key](undefined as any, ...args)) as any;
return Object.assign(
((...args: any) =>
prover(new Bool(true), undefined as any, ...args)) as any,
{
if: (
condition: Bool | { condition: Bool; domainLog2?: number },
...args: any
) => prover(condition, undefined as any, ...args),
}
);
} else {
return regularRecursiveProvers[key] as any;
return Object.assign(
((pi: PublicInput, ...args: any) =>
prover(new Bool(true), pi, ...args)) as any,
{
if: (
condition: Bool | { condition: Bool; domainLog2?: number },
pi: PublicInput,
...args: any
) => prover(condition, pi, ...args),
}
);
}
}
);

return proveRecursively;
}

type RecursiveProver<
Expand All @@ -135,6 +179,21 @@ type RecursiveProver<
...args: TupleToInstances<Args>
) => Promise<PublicOutput>;

type ConditionalRecursiveProver<
PublicInput,
PublicOutput,
Args extends Tuple<ProvableType>
> = PublicInput extends undefined
? (
condition: Bool | { condition: Bool; domainLog2?: number },
...args: TupleToInstances<Args>
) => Promise<PublicOutput>
: (
condition: Bool | { condition: Bool; domainLog2?: number },
publicInput: PublicInput,
...args: TupleToInstances<Args>
) => Promise<PublicOutput>;

type TupleToInstances<T> = {
[I in keyof T]: InferProvable<T[I]>;
};
3 changes: 1 addition & 2 deletions src/lib/proof-system/zkprogram.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import {
unsetSrsCache,
} from '../../bindings/crypto/bindings/srs.js';
import {
ProvablePure,
ProvableType,
ProvableTypePure,
ToProvable,
Expand All @@ -55,7 +54,7 @@ import {
import { emptyWitness } from '../provable/types/util.js';
import { InferValue } from '../../bindings/lib/provable-generic.js';
import { DeclaredProof, ZkProgramContext } from './zkprogram-context.js';
import { mapObject, mapToObject, zip } from '../util/arrays.js';
import { mapObject, mapToObject } from '../util/arrays.js';

// public API
export {
Expand Down

0 comments on commit a5c15ad

Please sign in to comment.