From ca8e11cea88db7ec0030f894ac169128ac3e732a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:54:49 -0700 Subject: [PATCH] Release 0.3.0 (#86) * feat: upgrade overall `halo2-base` API to support future multi-threaded assignments using our basic gate * WIP: currently `gates::flex_gate` is updated and passes basic test * BUG: `GateInstructions::idx_to_indicator` was missing a constraint to check that the indicator witness was equal to 1 when non-zero. * Previously the constraint ensured that `ind[i] = 0` when `idx != i` however `ind[idx]` could be anything!!! * update: working benches for `mul` and `inner_product` * feat: add `test_multithread_gates` * BUG: `get_last_bit` did not do an `assert_bit` check on the answer * this function was not used anywhere * fix: `builder::assign_*` was not handling cases where two gates overlap and there is a break point in that overlap * we need to copy a cell between columns to fix * feat: update `gates::range` to working tests and new API * In keygen mode, the `CircuitBuilder` will clone the `ThreadBuilder` instead of `take`ing it because the same circuit is used for both vk gen and pk gen. This could lead to more memory usage for pk gen. * fix: change `AssignedValue` type to `KeccakAssignedValue` for compatibility after halo2-base update * Initial version 0.3.0 of halo2-ecc (#12) * add multi-thread witness assignment support for `variable_base_msm` and `fixed_base_msm` * batch size 100 MSM witness generation went from 500ms -> 100ms * Sync with updates in `halo2_proofs_axiom` * `assign_advice` not longer returns `Result` so no more `unwrap` * Fix: assert uses of size hint in release mode (#13) * remove `size_hint` in `inner_product_simple` * change other uses of `size_hint` to follow with `assert_eq!` instead of `debug_assert_eq!` * Fix: bit decomposition edge cases (#14) * fix: change `debug_assert` in `decompose_u64_digits_limbs` to restrict `bit_len < 64` and `decompose_biguint` to `64 <= bit_len < 128` * add more comprehensive tests for above two functions * Initial checkpoint for halo2-ecc v0.3.0 (#15) * chore: clippy --fix * Feat/add readme (#4) * feat: add README * feat: re-enable `secp256k1` module with updated tests * chore: fix result println * chore: update Cargo halo2_proofs_axiom to axiom/dev branch * compatibility update with `halo2_proofs_axiom` Co-authored-by: Matthias Seitz * Fix: make `GateThreadBuilder` compatible with external usage (#16) * chore: expose gate_builder.unknown * feat: `GateThreadBuilder::assign_all` takes assigned_{advices,constants} as input instead of new hashmap, in case we want to constrain equalities for cells not belonging to this builder * chore: update halo2-pse tag * fix: `GateThreadBuilder::assign_all` now returns `HashMap`s of assigned cells for external equality constraints (e.g., instance cells, `AssignedCells` from chips not using halo2-lib). * fix: `assign_all` was not assigning constants as desired: it was assigning a new constant per context. This leads to confusion and possible undesired consequences down the line. * Fix: under-constrained `idx_to_indicator` (#17) *fix(BUG): `GateChip::idx_to_indicator` still had soundness bug where at index `idx` the value could be 0 or 1 (instead of only 1) * feat: add some function documentation * test(idx_to_indicator): add comprehensive tests * both positive and negative tests * Fix: soundness error in `FpChip::assert_eq` due to typo (#18) * chore: update halo2-ecc version to 0.3.0 * fix(BUG): `FpChip::assert_equal` had `a` instead of `b` typo * feat: add tests for `FpChip::assert_eq` * positive and negative tests * Remove redundant code and prevent race conditions (#19) * feat: move `GateCircuitBuilder::synthesize` to `sub_synthesize` function which also returns the assigned advices. * reduces code duplication between `GateCircuitBuilder::synthesize` and `RangeCircuitBuilder::synthesize` and also makes it easier to assign public instances elsewhere (e.g., snark-verifier) * feat: remove `Mutex` to prevent non-deterministism * In variable and fixed base `msm_par` functions, remove use of `Mutex` because even the `Mutex` is not thread- safe in the sense that: if you let `Mutex` decide order that `GateThreadBuilder` is unlocked, you may still add Contexts to the builder in a non-deterministic order. * fix: `fixed_base::msm_par` loading new zeros * In `msm_par` each parallelized context was loading a new zero via `ctx.load_zero()` * This led to using more cells than the non-parallelized version * In `fixed_base_msm_in`, the if statement depending on `rayon::current_number_threads` leads to inconsistent proving keys between different machines. This has been removed and now uses a fixed number `25`. * chore: use `info!` instead of `println` for params * Allow `assign_all` also if `witness_gen_only = true` * Fix: `inner_product_left_last` size hint (#25) * Add documentation for halo2-base (#27) * adds draft documentation for range.rs * draft docs for lib.rs, utiils.rs, builder.rs * fix: add suggested doc edits for range.rs * docs: add draft documentation for flex_gate.rs * fix: range.rs doc capitalization error * fix: suggested edits for utils.rs docs * fix: resolve comments for range.rs docs * fix: resolve comments on flex_gate.rs docs * fix: resolve comments for lib.rs, util.rs docs * fix: resolve comments for builder.rs docs * chore: use `info!` instead of `println` for params * Allow `assign_all` also if `witness_gen_only = true` * Fix: `inner_product_left_last` size hint (#25) * docs: minor fixes --------- Co-authored-by: PatStiles * Smart Range Builder (#29) * feat: smart `RangeCircuitBuilder` Allow `RangeCircuitBuilder` to not create lookup table if it detects that there's nothing to look up. * feat: add `RangeWithInstanceCircuitBuilder` * Moved from `snark-verifier-sdk` * Also made this circuit builder smart so it doesn't load lookup table if not necessary * In particular this can also be used as a `GateWithInstanceCircuitBuilder` * chore: derive Eq for CircuitBuilderStage * fix: RangeConfig should not unwrap LOOKUP_BITS * fix: `div_mod_var` when `a_num_bits <= b_num_bits` (#31) * Feat: extend halo2 base test coverage (#35) * feat: add flex_gate_test.rs and pos add() test * feat: add pos sub() test * feat: add pos neg() test * feat: add pos mul() test * feat: add pos mul_add() test * feat: add pos mul_not() test * feat: add pos assert_bit * feat: add pos div_unsafe() test * feat: add pos assert_is_const test * feat: add pos inner_product() test * feat: add pos inner_product_left_last() test * feat: add pos inner_product_with_sums test * feat: add pos sum_products_with_coeff_and_var test * feat: add pos and() test * feat: add pos not() test * feat: add pos select() test * feat: add pos or_and() test * feat: add pos bits_to_indicator() test * feat: add pos idx_to_indicator() test * feat: add pos select_by_indicator() test * feat: add pos select_from_idx() test * feat: add pos is_zero() test * feat: add pos is_equal() test * feat: add pos num_to_bits() test * feat: add pos lagrange_eval() test * feat: add pos get_field_element() test * feat: add pos range_check() tests * feat: add pos check_less_than() test * feat: add pos check_less_than_safe() test * feat: add pos check_big_less_than_safe() test * feat: add pos is_less_than() test * feat: add pos is_less_than_safe() test * feat: add pos is_big_less_than_safe() test * feat: add pos div_mod() test * feat: add pos get_last_bit() test * feat: add pos div_mod_var() test * fix: pass slices into test functions not arrays * feat: Add pos property tests for flex_gate * feat: Add positive property tests for flex_gate * feat: add pos property tests for range_check.rs * feat: add neg pranking test for idx_to_indicator * fix: change div_mod_var test values * feat(refactor): refactor property tests * fix: fix neg test, assert_const, assert_bit * fix: failing prop tests * feat: expand negative testing is_less_than_failing * fix: Circuit overflow errors on neg tests * fix: prop_test_mul_not * fix: everything but get_last_bit & lagrange * fix: clippy * fix: set LOOKUP_BITS in range tests, make range check neg test more robust * fix: neg_prop_tests cannot prank inputs Inputs have many copy constraints; pranking initial input will cause all copy constraints to fail * fix: test_is_big_less_than_safe, 240 bits max * Didn't want to change current `is_less_than` implementation, which in order to optimize lookups for smaller bits, only works when inputs have at most `(F::CAPACITY // lookup_bits - 1) * lookup_bits` bits * fix: inline doc for lagrange_and_eval * Remove proptest for lagrange_and_eval and leave as todo * tests: add readme about serial execution --------- Co-authored-by: Jonathan Wang * fix(ecdsa): allow u1*G == u2*PK case (#36) NOTE: current ecdsa requires `r, s` to be given as proper CRT integers TODO: newtypes to guard this assumption * fix: `log2_ceil(0)` should return `0` (#37) * Guard `ScalarField` byte representations to always be little-endian (#38) fix: guard `ScalarField` to be little-endian * fix: get_last_bit two errors (#39) 2 embarassing errors: * Witness gen for last bit was wrong (used xor instead of &) * `ctx.get` was called after `range_check` so it was getting the wrong cell * Add documentation for all debug_asserts (#40) feat: add documentation for all debug_asserts * fix: `FieldChip::divide` renamed `divide_unsafe` (#41) Add `divide` that checks denomintor is nonzero. Add documentation in cases where `divide_unsafe` is used. * Use new types to validate input assumptions (#43) * feat: add new types `ProperUint` and `ProperCrtUint` To guard around assumptions about big integer representations * fix: remove unused `FixedAssignedCRTInteger` * feat: use new types for bigint and field chips New types now guard for different assumptions on non-native bigint arithmetic. Distinguish between: - Overflow CRT integers - Proper BigUint with native part derived from limbs - Field elements where inequality < modulus is checked Also add type to help guard for inequality check in ec_add_unequal_strict Rust traits did not play so nicely with references, so I had to switch many functions to move inputs instead of borrow by reference. However to avoid writing `clone` everywhere, we allow conversion `From` reference to the new type via cloning. * feat: use `ProperUint` for `big_less_than` * feat(ecc): add fns for assign private witness points that constrain point to lie on curve * fix: unnecessary lifetimes * chore: remove clones * Better handling of EC point at infinity (#44) * feat: allow `msm_par` to return identity point * feat: handle point at infinity `multi_scalar_multiply` and `multi_exp_par` now handle point at infinity completely Add docs for `ec_add_unequal, ec_sub_unequal, ec_double_and_add_unequal` to specify point at infinity leads to undefined behavior * feat: use strict ec ops more often (#45) * `msm` implementations now always use `ec_{add,sub}_unequal` in strict mode for safety * Add docs to `scalar_multiply` and a flag to specify when it's safe to turn off some strict assumptions * feat: add `parallelize_in` helper function (#46) Multi-threading of witness generation is tricky because one has to ensure the circuit column assignment order stays deterministic. To ensure good developer experience / avoiding pitfalls, we provide a new helper function for this. Co-authored-by: Jonathan Wang * fix: minor code quality fixes (#47) * feat: `fixed_base::msm_par` handles identity point (#48) We still require fixed base points to be non-identity, but now handle the case when scalars may be zero or the final MSM value is identity point. * chore: add assert for query_cell_at_pos (#50) * feat: add Github CI running tests (#51) * fix: ignore code block for doctest (#52) * feat: add docs and assert with non-empty array checks (#53) * Release 0.3.0 ecdsa tests (#54) * More ecdsa tests * Update mod.rs * Update tests.rs * Update ecdsa.rs * Update ecdsa.rs * Update ecdsa.rs * chore: sync with release-0.3.0 and update CI Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> * chore: fix CI cannot multi-thread tests involving lookups due to environment variables * fix: `prop_test_is_less_than_safe` (#58) This test doesn't run any prover so the input must satisfy range check assumption. More serious coverage is provided by `prop_test_neg_is_less_than_safe` * Add halo2-base readme (#66) * feat: add halo2-base readme * fix: readme formatting * fix: readme edits * fix: grammer * fix: use relative links and formatting * fix: formatting * feat: add RangeCircuitBuilder description * feat: rewording and small edits --------- Co-authored-by: PatStiles * fix: change all `1` to `1u64` to prevent unexpected overflow (#72) * [Fix] Panic when dealing with identity point (#71) * More ecdsa tests * Update mod.rs * Update tests.rs * Update ecdsa.rs * Update ecdsa.rs * Update ecdsa.rs * msm tests * Update mod.rs * Update msm_sum_infinity.rs * fix: ec_sub_strict was panicing when output is identity * affects the MSM functions: right now if the answer is identity, there will be a panic due to divide by 0 instead of just returning 0 * there could be a more optimal solution, but due to the traits for EccChip, we just generate a random point solely to avoid divide by 0 in the case of identity point * Fix/fb msm zero (#77) * fix: fixed_base scalar multiply for [-1]P * feat: use `multi_scalar_multiply` instead of `scalar_multiply` * to reduce code maintanence / redundancy * fix: add back scalar_multiply using any_point * feat: remove flag from variable base `scalar_multiply` * feat: add scalar multiply tests for secp256k1 * fix: variable scalar_multiply last select * Fix/msm tests output identity (#75) * fixed base msm tests for output infinity * fixed base msm tests for output infinity --------- Co-authored-by: yulliakot * feat: add tests and update CI --------- Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> Co-authored-by: yulliakot --------- Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> * [Fix] scalar multiply completeness (#82) * fix: replace `scalar_multiply` with passthrough to MSM for now * feat(msm): use strict mode always * Previously did not use strict because we make assumptions about the curve `C`. Since this was not documented and is easy to miss, we use strict mode always. * docs: add assumptions to ec_sub_strict (#84) * fix: readme from previous merge * chore: cleanup CI for merge into main * chore: fix readme --------- Co-authored-by: Jonathan Wang Co-authored-by: Matthias Seitz Co-authored-by: PatStiles Co-authored-by: PatStiles <33334338+PatStiles@users.noreply.github.com> Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> --- .github/workflows/ci.yml | 50 + CHANGELOG.md | 4 + Cargo.toml | 2 +- README.md | 24 +- halo2-base/Cargo.toml | 17 +- halo2-base/README.md | 590 +++++++ halo2-base/benches/inner_product.rs | 103 +- halo2-base/benches/mul.rs | 112 +- halo2-base/examples/inner_product.rs | 95 + .../gates/tests/prop_test.txt | 11 + halo2-base/src/gates/builder.rs | 796 +++++++++ halo2-base/src/gates/builder/parallelize.rs | 38 + halo2-base/src/gates/flex_gate.rs | 1531 ++++++++++------- halo2-base/src/gates/mod.rs | 869 +--------- halo2-base/src/gates/range.rs | 691 +++++--- halo2-base/src/gates/tests.rs | 463 ----- halo2-base/src/gates/tests/README.md | 9 + halo2-base/src/gates/tests/flex_gate_tests.rs | 266 +++ halo2-base/src/gates/tests/general.rs | 170 ++ .../src/gates/tests/idx_to_indicator.rs | 119 ++ halo2-base/src/gates/tests/mod.rs | 73 + halo2-base/src/gates/tests/neg_prop_tests.rs | 398 +++++ halo2-base/src/gates/tests/pos_prop_tests.rs | 326 ++++ .../src/gates/tests/range_gate_tests.rs | 155 ++ .../src/gates/tests/test_ground_truths.rs | 190 ++ halo2-base/src/lib.rs | 766 ++++----- halo2-base/src/utils.rs | 354 +++- halo2-ecc/Cargo.toml | 5 +- halo2-ecc/benches/fixed_base_msm.rs | 244 +-- halo2-ecc/benches/fp_mul.rs | 197 +-- halo2-ecc/benches/msm.rs | 340 ++-- .../bn254}/bench_ec_add.config | 0 .../bn254}/bench_fixed_msm.config | 0 .../configs/bn254/bench_fixed_msm.t.config | 5 + .../bn254}/bench_msm.config | 1 + halo2-ecc/configs/bn254/bench_msm.t.config | 5 + .../bn254}/bench_pairing.config | 0 .../configs/bn254/bench_pairing.t.config | 5 + .../bn254}/ec_add_circuit.config | 0 .../bn254}/fixed_msm_circuit.config | 0 halo2-ecc/configs/bn254/msm_circuit.config | 1 + .../bn254}/pairing_circuit.config | 0 .../secp256k1}/bench_ecdsa.config | 0 .../secp256k1}/ecdsa_circuit.config | 0 halo2-ecc/src/bigint/add_no_carry.rs | 47 +- halo2-ecc/src/bigint/big_is_equal.rs | 64 +- halo2-ecc/src/bigint/big_is_zero.rs | 63 +- halo2-ecc/src/bigint/big_less_than.rs | 16 +- halo2-ecc/src/bigint/carry_mod.rs | 230 +-- .../src/bigint/check_carry_mod_to_zero.rs | 140 +- halo2-ecc/src/bigint/check_carry_to_zero.rs | 85 +- halo2-ecc/src/bigint/mod.rs | 313 ++-- halo2-ecc/src/bigint/mul_no_carry.rs | 58 +- halo2-ecc/src/bigint/negative.rs | 14 +- .../src/bigint/scalar_mul_and_add_no_carry.rs | 65 +- halo2-ecc/src/bigint/scalar_mul_no_carry.rs | 43 +- halo2-ecc/src/bigint/select.rs | 63 +- halo2-ecc/src/bigint/select_by_indicator.rs | 68 +- halo2-ecc/src/bigint/sub.rs | 82 +- halo2-ecc/src/bigint/sub_no_carry.rs | 42 +- .../src/bn254/configs/msm_circuit.config | 1 - halo2-ecc/src/bn254/final_exp.rs | 227 ++- halo2-ecc/src/bn254/mod.rs | 17 +- halo2-ecc/src/bn254/pairing.rs | 368 ++-- halo2-ecc/src/bn254/tests/ec_add.rs | 318 +--- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 410 ++--- halo2-ecc/src/bn254/tests/mod.rs | 62 +- halo2-ecc/src/bn254/tests/msm.rs | 453 ++--- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 183 ++ .../tests/msm_sum_infinity_fixed_base.rs | 183 ++ halo2-ecc/src/bn254/tests/pairing.rs | 353 ++-- halo2-ecc/src/ecc/ecdsa.rs | 111 +- halo2-ecc/src/ecc/fixed_base.rs | 284 ++- halo2-ecc/src/ecc/fixed_base_pippenger.rs | 28 +- halo2-ecc/src/ecc/mod.rs | 1019 +++++++---- halo2-ecc/src/ecc/pippenger.rs | 296 +++- halo2-ecc/src/ecc/tests.rs | 191 +- halo2-ecc/src/fields/fp.rs | 510 +++--- halo2-ecc/src/fields/fp12.rs | 483 ++---- halo2-ecc/src/fields/fp2.rs | 429 +---- halo2-ecc/src/fields/mod.rs | 377 ++-- halo2-ecc/src/fields/tests.rs | 267 --- halo2-ecc/src/fields/tests/fp/assert_eq.rs | 82 + halo2-ecc/src/fields/tests/fp/mod.rs | 72 + halo2-ecc/src/fields/tests/fp12/mod.rs | 73 + halo2-ecc/src/fields/tests/mod.rs | 2 + halo2-ecc/src/fields/vector.rs | 495 ++++++ halo2-ecc/src/lib.rs | 1 + halo2-ecc/src/secp256k1/mod.rs | 12 +- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 388 ++--- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 191 ++ halo2-ecc/src/secp256k1/tests/mod.rs | 161 ++ .../zkevm-keccak/src/keccak_packed_multi.rs | 40 +- .../src/keccak_packed_multi/tests.rs | 3 + hashes/zkevm-keccak/src/util.rs | 8 +- .../src/util/constraint_builder.rs | 2 +- hashes/zkevm-keccak/src/util/eth_types.rs | 4 +- 97 files changed, 10378 insertions(+), 8144 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 CHANGELOG.md create mode 100644 halo2-base/README.md create mode 100644 halo2-base/examples/inner_product.rs create mode 100644 halo2-base/proptest-regressions/gates/tests/prop_test.txt create mode 100644 halo2-base/src/gates/builder.rs create mode 100644 halo2-base/src/gates/builder/parallelize.rs delete mode 100644 halo2-base/src/gates/tests.rs create mode 100644 halo2-base/src/gates/tests/README.md create mode 100644 halo2-base/src/gates/tests/flex_gate_tests.rs create mode 100644 halo2-base/src/gates/tests/general.rs create mode 100644 halo2-base/src/gates/tests/idx_to_indicator.rs create mode 100644 halo2-base/src/gates/tests/mod.rs create mode 100644 halo2-base/src/gates/tests/neg_prop_tests.rs create mode 100644 halo2-base/src/gates/tests/pos_prop_tests.rs create mode 100644 halo2-base/src/gates/tests/range_gate_tests.rs create mode 100644 halo2-base/src/gates/tests/test_ground_truths.rs rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_ec_add.config (100%) rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_fixed_msm.config (100%) create mode 100644 halo2-ecc/configs/bn254/bench_fixed_msm.t.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_msm.config (92%) create mode 100644 halo2-ecc/configs/bn254/bench_msm.t.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_pairing.config (100%) create mode 100644 halo2-ecc/configs/bn254/bench_pairing.t.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/ec_add_circuit.config (100%) rename halo2-ecc/{src/bn254/configs => configs/bn254}/fixed_msm_circuit.config (100%) create mode 100644 halo2-ecc/configs/bn254/msm_circuit.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/pairing_circuit.config (100%) rename halo2-ecc/{src/secp256k1/configs => configs/secp256k1}/bench_ecdsa.config (100%) rename halo2-ecc/{src/secp256k1/configs => configs/secp256k1}/ecdsa_circuit.config (100%) delete mode 100644 halo2-ecc/src/bn254/configs/msm_circuit.config create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity.rs create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs delete mode 100644 halo2-ecc/src/fields/tests.rs create mode 100644 halo2-ecc/src/fields/tests/fp/assert_eq.rs create mode 100644 halo2-ecc/src/fields/tests/fp/mod.rs create mode 100644 halo2-ecc/src/fields/tests/fp12/mod.rs create mode 100644 halo2-ecc/src/fields/tests/mod.rs create mode 100644 halo2-ecc/src/fields/vector.rs create mode 100644 halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..08c34c40 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,50 @@ +name: Tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Build + run: cargo build --verbose + - name: Run halo2-base tests + run: | + cd halo2-base + cargo test -- --test-threads=1 + cd .. + - name: Run halo2-ecc tests MockProver + run: | + cd halo2-ecc + cargo test -- --test-threads=1 test_fp + cargo test -- test_ecc + cargo test -- test_secp + cargo test -- test_ecdsa + cargo test -- test_ec_add + cargo test -- test_fixed + cargo test -- test_msm + cargo test -- test_fb + cargo test -- test_pairing + cd .. + - name: Run halo2-ecc tests real prover + run: | + cd halo2-ecc + cargo test --release -- test_fp_assert_eq + cargo test --release -- --nocapture bench_secp256k1_ecdsa + cargo test --release -- --nocapture bench_ec_add + mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config + cargo test --release -- --nocapture bench_fixed_base_msm + mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config + cargo test --release -- --nocapture bench_msm + mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config + cargo test --release -- --nocapture bench_pairing + cd .. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..ab67d01e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,4 @@ +# v0.3.0 + +- Remove `PlonkPlus` strategy for `GateInstructions` to reduce code complexity. + - Because this strategy involved 1 selector AND 1 fixed column per advice column, it seems hard to justify it will lead to better peformance for the prover or verifier. diff --git a/Cargo.toml b/Cargo.toml index 4f01110c..9d8d2d5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ debug-assertions = false lto = "fat" # `codegen-units = 1` can lead to WORSE performance - always bench to find best profile for your machine! # codegen-units = 1 -panic = "abort" +panic = "unwind" incremental = false # For performance profiling diff --git a/README.md b/README.md index a8d3a98f..ff9ee93e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # halo2-lib -This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. To discuss or collaborate, join our community on [Telegram](https://t.me/halo2lib). +This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. To discuss or collaborate, join our community on [Telegram](https://t.me/halo2lib). ## Getting Started @@ -278,14 +278,14 @@ cargo test --release --no-default-features --features "halo2-axiom, mimalloc" -- ## Projects built with `halo2-lib` -* [Axiom](https://github.com/axiom-crypto/axiom-eth) -- Prove facts about Ethereum on-chain data via aggregate block header, account, and storage proofs. -* [Proof of Email](https://github.com/zkemail/) -- Prove facts about emails with the same trust assumption as the email domain. - * [halo2-regex](https://github.com/zkemail/halo2-regex) - * [halo2-zk-email](https://github.com/zkemail/halo2-zk-email) - * [halo2-base64](https://github.com/zkemail/halo2-base64) - * [halo2-rsa](https://github.com/zkemail/halo2-rsa/tree/feat/new_bigint) -* [halo2-fri-gadget](https://github.com/maxgillett/halo2-fri-gadget) -- FRI verifier in halo2. -* [eth-voice-recovery](https://github.com/SoraSuegami/voice_recovery_circuit) -* [zkevm tx-circuit](https://github.com/scroll-tech/zkevm-circuits/tree/develop/zkevm-circuits/src/tx_circuit) -* [webauthn-halo2](https://github.com/zkwebauthn/webauthn-halo2) -- Proving and verifying WebAuthn with halo2. -* [Fixed Point Arithmetic](https://github.com/DCMMC/halo2-scaffold/tree/main/src/gadget) -- Fixed point arithmetic library in halo2. +- [Axiom](https://github.com/axiom-crypto/axiom-eth) -- Prove facts about Ethereum on-chain data via aggregate block header, account, and storage proofs. +- [Proof of Email](https://github.com/zkemail/) -- Prove facts about emails with the same trust assumption as the email domain. + - [halo2-regex](https://github.com/zkemail/halo2-regex) + - [halo2-zk-email](https://github.com/zkemail/halo2-zk-email) + - [halo2-base64](https://github.com/zkemail/halo2-base64) + - [halo2-rsa](https://github.com/zkemail/halo2-rsa/tree/feat/new_bigint) +- [halo2-fri-gadget](https://github.com/maxgillett/halo2-fri-gadget) -- FRI verifier in halo2. +- [eth-voice-recovery](https://github.com/SoraSuegami/voice_recovery_circuit) +- [zkevm tx-circuit](https://github.com/scroll-tech/zkevm-circuits/tree/develop/zkevm-circuits/src/tx_circuit) +- [webauthn-halo2](https://github.com/zkwebauthn/webauthn-halo2) -- Proving and verifying WebAuthn with halo2. +- [Fixed Point Arithmetic](https://github.com/DCMMC/halo2-scaffold/tree/main/src/gadget) -- Fixed point arithmetic library in halo2. diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 0046f2e0..33799495 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-base" -version = "0.2.2" +version = "0.3.0" edition = "2021" [dependencies] @@ -11,22 +11,32 @@ num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" ff = "0.12" +rayon = "1.6.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +log = "0.4" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", tag = "v2023_01_17", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/dev", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_01_20", optional = true } +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } # plotting circuit layout plotters = { version = "0.3.0", optional = true } tabbycat = { version = "0.1", features = ["attributes"], optional = true } +# test-utils +rand = { version = "0.8", optional = true } + [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } rand = "0.8" pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" +rayon = "1.6.1" +test-case = "3.1.0" +proptest = "1.1.0" # memory allocation [target.'cfg(not(target_env = "msvc"))'.dependencies] @@ -41,6 +51,7 @@ halo2-pse = ["halo2_proofs"] halo2-axiom = ["halo2_proofs_axiom"] display = [] profile = ["halo2_proofs_axiom?/profile"] +test-utils = ["dep:rand"] [[bench]] name = "mul" diff --git a/halo2-base/README.md b/halo2-base/README.md new file mode 100644 index 00000000..6b078ab9 --- /dev/null +++ b/halo2-base/README.md @@ -0,0 +1,590 @@ +# Halo2-base + +Halo2-base provides a streamlined frontend for interacting with the Halo2 API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit configuration and parellel proving and witness generation. + +Programmed circuit constraints are stored in `GateThreadBuilder` as a `Vec` of `Context`'s. Each `Context` can be interpreted as a "virtual column" which tracks witness values and constraints but does not assign them as cells within the Halo2 backend. Conceptually, one can think that at circuit generation time, the virtual columns are all concatenated into a **single** virtual column. This virtual column is then re-distributed into the minimal number of true `Column`s (aka Plonkish arithmetization columns) to fit within a user-specified number of rows. These true columns are then assigned into the Plonkish arithemization using the vanilla Halo2 backend. This has several benefits: + +- The user only needs to specify the desired number of rows. The rest of the circuit configuration process is done automatically because the optimal number of columns in the circuit can be calculated from the total number of cells in the `Context`s. This eliminates the need to manually assign circuit parameters at circuit creation time. +- In addition, this simplifies the process of testing the performance of different circuit configurations (different Plonkish arithmetization shapes) in the Halo2 backend, since the same virtual columns in the `Context` can be re-distributed into different Plonkish arithmetization tables. + +A user can also parallelize witness generation by specifying a function and a `Vec` of inputs to perform in parallel using `parallelize_in()` which creates a separate `Context` for each input that performs the specified function. These "virtual columns" are then computed in parallel during witness generation and combined back into a single column "virtual column" before cell assignment in the Halo2 backend. + +All assigned values in a circuit are assigned in the Halo2 backend by calling `synthesize()` in `GateCircuitBuilder` (or [`RangeCircuitBuilder`](#rangecircuitbuilder)) which in turn invokes `assign_all()` (or `assign_threads_in` if only doing witness generation) in `GateThreadBuilder` to assign the witness values tracked in a `Context` to their respective `Column` in the circuit within the Halo2 backend. + +Halo2-base also provides pre-built [Chips](https://zcash.github.io/halo2/concepts/chips.html) for common arithmetic operations in `GateChip` and range check arguments in `RangeChip`. Our `Chip` implementations differ slightly from ZCash's `Chip` implementations. In Zcash, the `Chip` struct stores knowledge about the `Config` and custom gates used. In halo2-base a `Chip` stores only functions while the interaction with the circuit's `Config` is hidden and done in `GateCircuitBuilder`. + +The structure of halo2-base is outlined as follows: + +- `builder.rs`: Contains `GateThreadBuilder`, `GateCircuitBuilder`, and `RangeCircuitBuilder` which implement the logic to provide different arithmetization configurations with different performance tradeoffs in the Halo2 backend. +- `lib.rs`: Defines the `QuantumCell`, `ContextCell`, `AssignedValue`, and `Context` types which track assigned values within a circuit across multiple columns and provide a streamlined interface to assign witness values directly to the advice column. +- `utils.rs`: Contains `BigPrimeField` and `ScalerField` traits which represent field elements within Halo2 and provides methods to decompose field elements into `u64` limbs and convert between field elements and `BigUint`. +- `flex_gate.rs`: Contains the implementation of `GateChip` and the `GateInstructions` trait which provide functions for basic arithmetic operations within Halo2. +- `range.rs:`: Implements `RangeChip` and the `RangeInstructions` trait which provide functions for performing range check and other lookup argument operations. + +This readme compliments the in-line documentation of halo2-base, providing an overview of `builder.rs` and `lib.rs`. + +
+ +## [**Context**](src/lib.rs) + +`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. + +During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. + +For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. + +```rust ignore +pub struct Context { + + witness_gen_only: bool, + + pub context_id: usize, + + pub advice: Vec>, + + pub cells_to_lookup: Vec>, + + pub zero_cell: Option>, + + pub selector: Vec, + + pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, + + pub constant_equality_constraints: Vec<(F, ContextCell)>, +} +``` + +`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. + +A `Context` holds all equality and constant constraints as a `Vec` of `ContextCell` tuples representing the positions of the two cells to constrain. `advice` and`selector` store the respective column values of the `Context`'s which may represent the entire advice and selector column or a sub-section of the advice and selector column during parellel witness generation. `cells_to_lookup` tracks `AssignedValue`'s of cells to be looked up in a global lookup table, specifically for range checks, shared among all `Context`'s'. + +### [**ContextCell**](./src/lib.rs): + +`ContextCell` is a pointer to a specific cell within a `Context` identified by the Context's `context_id` and the cell's relative `offset` from the first cell of the advice column of the `Context`. + +```rust ignore +#[derive(Clone, Copy, Debug)] +pub struct ContextCell { + /// Identifier of the [Context] that this cell belongs to. + pub context_id: usize, + /// Relative offset of the cell within this [Context] advice column. + pub offset: usize, +} +``` + +### [**AssignedValue**](./src/lib.rs): + +`AssignedValue` represents a specific `Assigned` value assigned to a specific cell within a `Context` of a circuit referenced by a `ContextCell`. + +```rust ignore +pub struct AssignedValue { + pub value: Assigned, + + pub cell: Option, +} +``` + +### [**Assigned**](./src/plonk/assigned.rs) + +`Assigned` is a wrapper enum for values assigned to a cell within a circuit which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. + +```rust ignore +pub enum Assigned { + /// The field element zero. + Zero, + /// A value that does not require inversion to evaluate. + Trivial(F), + /// A value stored as a fraction to enable batch inversion. + Rational(F, F), +} +``` + +
+ +## [**QuantumCell**](./src/lib.rs) + +`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in Halo2-base. Without `QuantumCell` assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. + +```rust ignore +pub enum QuantumCell { + + Existing(AssignedValue), + + Witness(F), + + WitnessFraction(Assigned), + + Constant(F), +} +``` + +QuantumCell contains the following enum variants. + +- **Existing**: + Assigns a value to the advice column that exists within the advice column. The value is an existing value from some previous part of your computation already in the advice column in the form of an `AssignedValue`. When you add an existing cell into the table a new cell will be assigned into the advice column with value equal to the existing value. An equality constraint will then be added between the new cell and the "existing" cell so the Verifier has a guarantee that these two cells are always equal. + + ```rust ignore + QuantumCell::Existing(acell) => { + self.advice.push(acell.value); + + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + } + } + ``` + +- **Witness**: + Assigns an entirely new witness value into the advice column, such as a private input. When `assign_cell()` is called the value is wrapped in as an `Assigned::Trivial()` which marks it for exclusion from batch inversion. + ```rust ignore + QuantumCell::Witness(val) => { + self.advice.push(Assigned::Trivial(val)); + } + ``` +- **WitnessFraction**: + Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion. + ```rust ignore + QuantumCell::WitnessFraction(val) => { + self.advice.push(val); + } + ``` +- **Constant**: + A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret "Fixed" column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. + +```rust ignore +QuantumCell::Constant(c) => { + self.advice.push(Assigned::Trivial(c)); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.constant_equality_constraints.push((c, new_cell)); + } +} +``` + +
+ +## [**GateThreadBuilder**](./src/gates/builder.rs) & [**GateCircuitBuilder**](./src/gates/builder.rs) + +`GateThreadBuilder` tracks the cell assignments of a circuit as an array of `Vec` of `Context`' where `threads[i]` contains all `Context`'s for phase `i`. Each array element corresponds to a distinct challenge phase of Halo2's proving system, each of which has its own unique set of rows and columns. + +```rust ignore +#[derive(Clone, Debug, Default)] +pub struct GateThreadBuilder { + /// Threads for each challenge phase + pub threads: [Vec>; MAX_PHASE], + /// Max number of threads + thread_count: usize, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + use_unknown: bool, +} +``` + +Once a `GateThreadBuilder` is created, gates may be assigned to a `Context` (or in the case of parallel witness generation multiple `Context`'s) within `threads`. Once the circuit is written `config()` is called to pre-compute the circuits size and set the circuit's environment variables. + +[**config()**](./src/gates/builder.rs) + +```rust ignore +pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let total_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) + .collect::>(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_lookup_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .collect::>(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { + threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) + })) + .len(); + let num_fixed = (total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { + strategy: GateStrategy::Vertical, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + }; + #[cfg(feature = "display")] + { + for phase in 0..MAX_PHASE { + if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { + println!( + "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", + phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], + ); + } + } + println!("Total {total_fixed} fixed cells"); + println!("Auto-calculated config params:\n {params:#?}"); + } + std::env::set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); + params +} +``` + +For circuit creation a `GateCircuitBuilder` is created by passing the `GateThreadBuilder` as an argument to `GateCircuitBuilder`'s `keygen`,`mock`, or `prover` functions. `GateCircuitBuilder` acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's`Circuit` Trait and calling into `GateThreadBuilder` `assign_all()` and `assign_threads_in()` functions to perform circuit assignment. + +**Note for developers:** We encourage you to always use [`RangeCircuitBuilder`](#rangecircuitbuilder) instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. + +```rust ignore +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +#[derive(Clone, Debug)] +pub struct GateCircuitBuilder { + /// The Thread Builder for the circuit + pub builder: RefCell>, + /// Break points for threads within the circuit + pub break_points: RefCell, +} + +impl Circuit for GateCircuitBuilder { + type Config = FlexGateConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the circuit without withnesses filled in. + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config]. + fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase: _, + num_fixed, + k, + } = serde_json::from_str(&std::env::var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + self.sub_synthesize(&config, &[], &[], &mut layouter); + Ok(()) + } +} +``` + +During circuit creation `synthesize()` is invoked which passes into `sub_synthesize()` a `FlexGateConfig` containing the actual circuits columns and a mutable reference to a `Layouter` from the Halo2 API which facilitates the final assignment of cells within a `Region` of a circuit in Halo2's backend. + +`GateCircuitBuilder` contains a list of breakpoints for each thread across all phases in and `GateThreadBuilder` itself. Both are wrapped in a `RefCell` allowing them to be borrowed mutably so the function performing circuit creation can take ownership of the `builder` and `break_points` can be recorded during circuit creation for later use. + +[**sub_synthesize()**](./src/gates/builder.rs) + +```rust ignore + pub fn sub_synthesize( + &self, + gate: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + layouter: &mut impl Layouter, + ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { + let mut first_pass = SKIP_FIRST_PASS; + let mut assigned_advices = HashMap::new(); + layouter + .assign_region( + || "GateCircuitBuilder generated circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize + // If we are not performing witness generation only, we can skip the first pass and assign threads directly + if !self.builder.borrow().witness_gen_only { + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = self.builder.borrow().clone(); + for threads in builder.threads.iter().skip(1) { + assert!( + threads.is_empty(), + "GateCircuitBuilder only supports FirstPhase for now" + ); + } + let assignments = builder.assign_all( + gate, + lookup_advice, + q_lookup, + &mut region, + Default::default(), + ); + *self.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + } else { + // If we are only generating witness, we can skip the first pass and assign threads directly + let builder = self.builder.take(); + let break_points = self.break_points.take(); + for (phase, (threads, break_points)) in builder + .threads + .into_iter() + .zip(break_points.into_iter()) + .enumerate() + .take(1) + { + assign_threads_in( + phase, + threads, + gate, + lookup_advice.get(phase).unwrap_or(&vec![]), + &mut region, + break_points, + ); + } + } + Ok(()) + }, + ) + .unwrap(); + assigned_advices + } +``` + +Within `sub_synthesize()` `layouter`'s `assign_region()` function is invoked which yields a mutable reference to `Region`. `region` is used to assign cells within a contiguous region of the circuit represented in Halo2's proving system. + +If `witness_gen_only` is not set within the `builder` (for keygen, and mock proving) `sub_synthesize` takes ownership of the `builder`, and calls `assign_all()` to assign all cells within this context to a circuit in Halo2's backend. The resulting column breakpoints are recorded in `GateCircuitBuilder`'s `break_points` field. + +`assign_all()` iterates over each `Context` within a `phase` and assigns the values and constraints of the advice, selector, fixed, and lookup columns to the circuit using `region`. + +Breakpoints for the advice column are assigned sequentially. If, the `row_offset` of the cell value being currently assigned exceeds the maximum amount of rows allowed in a column a new column is created. + +It should be noted this process is only compatible with the first phase of Halo2's proving system as retrieving witness challenges in later phases requires more specialized witness generation during synthesis. Therefore, `assign_all()` must assert all elements in `threads` are unassigned excluding the first phase. + +[**assign_all()**](./src/gates/builder.rs) + +```rust ignore +pub fn assign_all( + &self, + config: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + region: &mut Region, + KeygenAssignments { + mut assigned_advices, + mut assigned_constants, + mut break_points + }: KeygenAssignments, + ) -> KeygenAssignments { + ... + for (phase, threads) in self.threads.iter().enumerate() { + let mut break_point = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + let mut basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = *region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); + ... + +``` + +In the case a breakpoint falls on the overlap between two gates (such as chained addition of two cells) the cells the breakpoint falls on must be copied to the next column and a new equality constraint enforced between the value of the cell in the old column and the copied cell in the new column. This prevents the circuit from being undersconstratined and preserves the equality constraint from the overlapping gates. + +```rust ignore +if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_point.push(row_offset); + row_offset = 0; + gate_index += 1; + +// when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + + #[cfg(feature = "halo2-axiom")] + { + let ncell = region.assign_advice(column, row_offset, value); + region.constrain_equal(ncell.cell(), &cell); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let ncell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + region.constrain_equal(ncell, cell).unwrap(); + } +} + +``` + +If `witness_gen_only` is set, only witness generation is performed, and no copy constraints or selector values are considered. + +Witness generation can be parallelized by a user by calling `parallelize_in()` and specifying a function and a `Vec` of inputs to perform in parallel. `parallelize_in()` creates a separate `Context` for each input that performs the specified function and appends them to the `Vec` of `Context`'s of a particular phase. + +[**assign_threads_in()**](./src/gates/builder.rs) + +```rust ignore +pub fn assign_threads_in( + phase: usize, + threads: Vec>, + config: &FlexGateConfig, + lookup_advice: &[Column], + region: &mut Region, + break_points: ThreadBreakPoints, +) { + if config.basic_gates[phase].is_empty() { + assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); + return; + } + + let mut break_points = break_points.into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = config.basic_gates[phase][gate_index].value; + let mut row_offset = 0; + + let mut lookup_offset = 0; + let mut lookup_advice = lookup_advice.iter(); + let mut lookup_column = lookup_advice.next(); + for ctx in threads { + // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns + if lookup_column.is_some() { + for advice in ctx.cells_to_lookup { + if lookup_offset >= config.max_rows { + lookup_offset = 0; + lookup_column = lookup_advice.next(); + } + // Assign the lookup advice values to the lookup_column + let value = advice.value; + let lookup_column = *lookup_column.unwrap(); + #[cfg(feature = "halo2-axiom")] + region.assign_advice(lookup_column, lookup_offset, Value::known(value)); + #[cfg(not(feature = "halo2-axiom"))] + region + .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) + .unwrap(); + + lookup_offset += 1; + } + } + // Assign advice values to the advice columns in each [Context] + for advice in ctx.advice { + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = config.basic_gates[phase][gate_index].value; + + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + } + + row_offset += 1; + } + } + +``` + +`sub_synthesize` iterates over all phases and calls `assign_threads_in()` for that phase. `assign_threads_in()` iterates over all `Context`s within that phase and assigns all lookup and advice values in the `Context`, creating a new advice column at every pre-computed "breakpoint" by incrementing `gate_index` and assigning `column` to a new `Column` found at `config.basic_gates[phase][gate_index].value`. + +## [**RangeCircuitBuilder**](./src/gates/builder.rs) + +`RangeCircuitBuilder` is a wrapper struct around `GateCircuitBuilder`. Like `GateCircuitBuilder` it acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's `Circuit` Trait. + +```rust ignore +#[derive(Clone, Debug)] +pub struct RangeCircuitBuilder(pub GateCircuitBuilder); + +impl Circuit for RangeCircuitBuilder { + type Config = RangeConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + let strategy = match strategy { + GateStrategy::Vertical => RangeStrategy::Vertical, + }; + let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); + RangeConfig::configure( + meta, + strategy, + &num_advice_per_phase, + &num_lookup_advice_per_phase, + num_fixed, + lookup_bits, + k, + ) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !config.q_lookup.iter().all(|q| q.is_none()) + { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); + Ok(()) + } +} +``` + +`RangeCircuitBuilder` differs from `GateCircuitBuilder` in that it contains a `RangeConfig` instead of a `FlexGateConfig` as its `Config`. `RangeConfig` contains a `lookup` table needed to declare lookup arguments within Halo2's backend. When creating a circuit that uses lookup tables `GateThreadBuilder` must be wrapped with `RangeCircuitBuilder` instead of `GateCircuitBuilder` otherwise circuit synthesis will fail as a lookup table is not present within the Halo2 backend. + +**Note:** We encourage you to always use `RangeCircuitBuilder` instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index e5fec21c..9454faa3 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,9 +1,7 @@ #![allow(unused_imports)] #![allow(unused_variables)] -use halo2_base::gates::{ - flex_gate::{FlexGateConfig, GateStrategy}, - GateInstructions, -}; +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; use halo2_base::halo2_proofs::{ arithmetic::Field, circuit::*, @@ -16,7 +14,12 @@ use halo2_base::halo2_proofs::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use halo2_base::{Context, ContextParams, QuantumCell::Witness, SKIP_FIRST_PASS}; +use halo2_base::utils::ScalarField; +use halo2_base::{ + Context, + QuantumCell::{Existing, Witness}, + SKIP_FIRST_PASS, +}; use itertools::Itertools; use rand::rngs::OsRng; use std::marker::PhantomData; @@ -28,82 +31,50 @@ use pprof::criterion::{Output, PProfProfiler}; // Thanks to the example provided by @jebbow in his article // https://www.jibbow.com/posts/criterion-flamegraphs/ -#[derive(Clone, Default)] -struct MyCircuit { - _marker: PhantomData, -} - -const NUM_ADVICE: usize = 1; const K: u32 = 19; -impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; +fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { + assert_eq!(a.len(), b.len()); + let a = ctx.assign_witnesses(a); + let b = ctx.assign_witnesses(b); - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure(meta, GateStrategy::Vertical, &[NUM_ADVICE], 1, 0, K as usize) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let a = (0..5).map(|_| Witness(Value::known(Fr::random(OsRng)))).collect_vec(); - let b = (0..5).map(|_| Witness(Value::known(Fr::random(OsRng)))).collect_vec(); - - for _ in 0..(1 << K) / 16 - 10 { - config.inner_product(ctx, a.clone(), b.clone()); - } - - Ok(()) - }, - ) + let chip = GateChip::default(); + for _ in 0..(1 << K) / 16 - 10 { + chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); } } fn bench(c: &mut Criterion) { - let circuit = MyCircuit:: { _marker: PhantomData }; + let k = 19u32; + // create circuit for keygen + let mut builder = GateThreadBuilder::new(false); + inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); + builder.config(k as usize, Some(20)); + let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(K, &circuit, vec![]).unwrap().assert_satisfied(); + // check the circuit is correct just in case + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); - let params = ParamsKZG::::setup(K, OsRng); + let params = ParamsKZG::::setup(k, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.break_points.take(); + drop(circuit); + let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); group.bench_with_input( - BenchmarkId::new("inner_product", K), + BenchmarkId::new("inner_product", k), &(¶ms, &pk), - |b, &(params, pk)| { - b.iter(|| { - let circuit = MyCircuit:: { _marker: PhantomData }; - let rng = OsRng; + |bencher, &(params, pk)| { + bencher.iter(|| { + let mut builder = GateThreadBuilder::new(true); + let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + inner_prod_bench(builder.main(0), a, b); + let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -112,7 +83,7 @@ fn bench(c: &mut Criterion) { _, Blake2bWrite, G1Affine, Challenge255<_>>, _, - >(params, pk, &[circuit], &[&[]], rng, &mut transcript) + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) .expect("prover should not fail"); }) }, diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 6698ae99..16687e08 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,9 +1,7 @@ -use halo2_base::gates::{ - flex_gate::{FlexGateConfig, GateStrategy}, - GateInstructions, -}; +use ff::Field; +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ - circuit::*, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, poly::kzg::{ @@ -12,11 +10,8 @@ use halo2_base::halo2_proofs::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use halo2_base::{ - Context, ContextParams, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::utils::ScalarField; +use halo2_base::Context; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -26,92 +21,43 @@ use pprof::criterion::{Output, PProfProfiler}; // Thanks to the example provided by @jebbow in his article // https://www.jibbow.com/posts/criterion-flamegraphs/ -#[derive(Clone, Default)] -struct MyCircuit { - a: Value, - b: Value, - c: Value, -} - -const NUM_ADVICE: usize = 1; const K: u32 = 9; -impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; +fn mul_bench(ctx: &mut Context, inputs: [F; 2]) { + let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); + let chip = GateChip::default(); - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure(meta, GateStrategy::PlonkPlus, &[NUM_ADVICE], 1, 0, K as usize) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let (_a_cell, b_cell, c_cell) = { - let cells = config.assign_region_smart( - ctx, - vec![Witness(self.a), Witness(self.b), Witness(self.c)], - vec![], - vec![], - vec![], - ); - (cells[0].clone(), cells[1].clone(), cells[2].clone()) - }; - - for _ in 0..120 { - config.mul(ctx, Existing(&c_cell), Existing(&b_cell)); - } - - Ok(()) - }, - ) + for _ in 0..120 { + chip.mul(ctx, a, b); } } fn bench(c: &mut Criterion) { - let circuit = MyCircuit:: { - a: Value::known(Fr::from(10u64)), - b: Value::known(Fr::from(12u64)), - c: Value::known(Fr::from(120u64)), - }; + // create circuit for keygen + let mut builder = GateThreadBuilder::new(false); + mul_bench(builder.main(0), [Fr::zero(); 2]); + builder.config(K as usize, Some(9)); + let circuit = GateCircuitBuilder::keygen(builder); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.break_points.take(); + + let a = Fr::random(OsRng); + let b = Fr::random(OsRng); // native multiplication 120 times c.bench_with_input( BenchmarkId::new("native mul", K), - &(¶ms, &pk, &circuit), - |b, &(params, pk, circuit)| { - b.iter(|| { - let rng = OsRng; + &(¶ms, &pk, [a, b]), + |bencher, &(params, pk, inputs)| { + bencher.iter(|| { + let mut builder = GateThreadBuilder::new(true); + // do the computation + mul_bench(builder.main(0), inputs); + let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -120,8 +66,8 @@ fn bench(c: &mut Criterion) { _, Blake2bWrite, G1Affine, Challenge255<_>>, _, - >(params, pk, &[circuit.clone()], &[&[]], rng, &mut transcript) - .expect("prover should not fail"); + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .unwrap(); }) }, ); diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs new file mode 100644 index 00000000..8572817e --- /dev/null +++ b/halo2-base/examples/inner_product.rs @@ -0,0 +1,95 @@ +#![allow(unused_imports)] +#![allow(unused_variables)] +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; +use halo2_base::halo2_proofs::{ + arithmetic::Field, + circuit::*, + dev::MockProver, + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::*, + poly::kzg::multiopen::VerifierSHPLONK, + poly::kzg::strategy::SingleStrategy, + poly::kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::ProverSHPLONK, + }, + transcript::{Blake2bRead, TranscriptReadBuffer}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, +}; +use halo2_base::utils::ScalarField; +use halo2_base::{ + Context, + QuantumCell::{Existing, Witness}, + SKIP_FIRST_PASS, +}; +use itertools::Itertools; +use rand::rngs::OsRng; +use std::marker::PhantomData; + +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; + +use pprof::criterion::{Output, PProfProfiler}; +// Thanks to the example provided by @jebbow in his article +// https://www.jibbow.com/posts/criterion-flamegraphs/ + +const K: u32 = 19; + +fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { + assert_eq!(a.len(), b.len()); + let a = ctx.assign_witnesses(a); + let b = ctx.assign_witnesses(b); + + let chip = GateChip::default(); + for _ in 0..(1 << K) / 16 - 10 { + chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); + } +} + +fn main() { + let k = 10u32; + // create circuit for keygen + let mut builder = GateThreadBuilder::new(false); + inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); + builder.config(k as usize, Some(20)); + let circuit = GateCircuitBuilder::mock(builder); + + // check the circuit is correct just in case + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + + let params = ParamsKZG::::setup(k, OsRng); + let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + + let break_points = circuit.break_points.take(); + + let mut builder = GateThreadBuilder::new(true); + let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + inner_prod_bench(builder.main(0), a, b); + let circuit = GateCircuitBuilder::prover(builder, break_points); + + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + + let strategy = SingleStrategy::new(¶ms); + let proof = transcript.finalize(); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + _, + >(¶ms, pk.get_vk(), strategy, &[&[]], &mut transcript) + .unwrap(); +} diff --git a/halo2-base/proptest-regressions/gates/tests/prop_test.txt b/halo2-base/proptest-regressions/gates/tests/prop_test.txt new file mode 100644 index 00000000..aa4e1000 --- /dev/null +++ b/halo2-base/proptest-regressions/gates/tests/prop_test.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 8489bbcc3439950355c90ecbc92546a66e4b57eae0a3856e7a4ccb59bf74b4ce # shrinks to k = 1, len = 1, idx = 0, witness_vals = [0x0000000000000000000000000000000000000000000000000000000000000000] +cc b18c4f5e502fe36dbc2471f89a6ffb389beaf473b280e844936298ab1cf9b74e # shrinks to (k, len, idx, witness_vals) = (8, 2, 1, [0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000001]) +cc 4528fb02e7227f85116c2a16aef251b9c3b6d9c340ddb50b936c2140d7856cc4 # shrinks to inputs = ([], []) +cc 79bfe42c93b5962a38b2f831f1dd438d8381a24a6ce15bfb89a8562ce9af0a2d # shrinks to (k, len, idx, witness_vals) = (8, 62, 0, [0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000]) +cc d0e10a06108cb58995a8ae77a91b299fb6230e9e6220121c48f2488e5d199e82 # shrinks to input = (0x000000000000000000000000000000000000000000000000070a95cb0607bef9, 4096) diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs new file mode 100644 index 00000000..22c2ce93 --- /dev/null +++ b/halo2-base/src/gates/builder.rs @@ -0,0 +1,796 @@ +use super::{ + flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, + range::{RangeConfig, RangeStrategy}, +}; +use crate::{ + halo2_proofs::{ + circuit::{self, Layouter, Region, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector}, + }, + utils::ScalarField, + AssignedValue, Context, SKIP_FIRST_PASS, +}; +use serde::{Deserialize, Serialize}; +use std::{ + cell::RefCell, + collections::{HashMap, HashSet}, + env::{set_var, var}, +}; + +mod parallelize; +pub use parallelize::*; + +/// Vector of thread advice column break points +pub type ThreadBreakPoints = Vec; +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +/// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading +#[derive(Clone, Debug, Default)] +pub struct KeygenAssignments { + /// Advice assignments + pub assigned_advices: HashMap<(usize, usize), (circuit::Cell, usize)>, // (key = ContextCell, value = (circuit::Cell, row offset)) + /// Constant assignments in Fixes Assignments + pub assigned_constants: HashMap, // (key = constant, value = circuit::Cell) + /// Advice column break points for threads in each phase. + pub break_points: MultiPhaseThreadBreakPoints, +} + +/// Builds the process for gate threading +#[derive(Clone, Debug, Default)] +pub struct GateThreadBuilder { + /// Threads for each challenge phase + pub threads: [Vec>; MAX_PHASE], + /// Max number of threads + thread_count: usize, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + pub witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + use_unknown: bool, +} + +impl GateThreadBuilder { + /// Creates a new [GateThreadBuilder] and spawns a main thread in phase 0. + /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool) -> Self { + let mut threads = [(); MAX_PHASE].map(|_| vec![]); + // start with a main thread in phase 0 + threads[0].push(Context::new(witness_gen_only, 0)); + Self { threads, thread_count: 1, witness_gen_only, use_unknown: false } + } + + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. + /// + /// Performs the witness assignment computations and then checks using normal programming logic whether the gate constraints are all satisfied. + pub fn mock() -> Self { + Self::new(false) + } + + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. + /// + /// Performs the witness assignment computations and generates prover and verifier keys. + pub fn keygen() -> Self { + Self::new(false) + } + + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to true. + /// + /// Performs the witness assignment computations and then runs the proving system. + pub fn prover() -> Self { + Self::new(true) + } + + /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(self, use_unknown: bool) -> Self { + Self { use_unknown, ..self } + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + if self.threads[phase].is_empty() { + self.new_thread(phase) + } else { + self.threads[phase].last_mut().unwrap() + } + } + + /// Returns the `witness_gen_only` flag. + pub fn witness_gen_only(&self) -> bool { + self.witness_gen_only + } + + /// Returns the `use_unknown` flag. + pub fn use_unknown(&self) -> bool { + self.use_unknown + } + + /// Returns the current number of threads in the [GateThreadBuilder]. + pub fn thread_count(&self) -> usize { + self.thread_count + } + + /// Creates a new thread id by incrementing the `thread count` + pub fn get_new_thread_id(&mut self) -> usize { + let thread_id = self.thread_count; + self.thread_count += 1; + thread_id + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + let thread_id = self.thread_count; + self.thread_count += 1; + self.threads[phase].push(Context::new(self.witness_gen_only, thread_id)); + self.threads[phase].last_mut().unwrap() + } + + /// Auto-calculates configuration parameters for the circuit + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let total_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) + .collect::>(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_lookup_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .collect::>(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { + threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) + })) + .len(); + let num_fixed = (total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { + strategy: GateStrategy::Vertical, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + }; + #[cfg(feature = "display")] + { + for phase in 0..MAX_PHASE { + if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { + println!( + "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", + phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], + ); + } + } + println!("Total {total_fixed} fixed cells"); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); + params + } + + /// Assigns all advice and fixed cells, turns on selectors, and imposes equality constraints. + /// + /// Returns the assigned advices, and constants in the form of [KeygenAssignments]. + /// + /// Assumes selector and advice columns are already allocated and of the same length. + /// + /// Note: `assign_all()` **should** be called during keygen or if using mock prover. It also works for the real prover, but there it is more optimal to use [`assign_threads_in`] instead. + /// * `config`: The [FlexGateConfig] of the circuit. + /// * `lookup_advice`: The lookup advice columns. + /// * `q_lookup`: The lookup advice selectors. + /// * `region`: The [Region] of the circuit. + /// * `assigned_advices`: The assigned advice cells. + /// * `assigned_constants`: The assigned fixed cells. + /// * `break_points`: The break points of the circuit. + pub fn assign_all( + &self, + config: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + region: &mut Region, + KeygenAssignments { + mut assigned_advices, + mut assigned_constants, + mut break_points + }: KeygenAssignments, + ) -> KeygenAssignments { + let use_unknown = self.use_unknown; + let max_rows = config.max_rows; + let mut fixed_col = 0; + let mut fixed_offset = 0; + for (phase, threads) in self.threads.iter().enumerate() { + let mut break_point = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + let mut basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = *region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); + + // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. + if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_point.push(row_offset); + row_offset = 0; + gate_index += 1; + + // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + + #[cfg(feature = "halo2-axiom")] + { + let ncell = region.assign_advice(column, row_offset, value); + region.constrain_equal(ncell.cell(), &cell); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let ncell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + region.constrain_equal(ncell, cell).unwrap(); + } + } + + if q { + basic_gate + .q_enable + .enable(region, row_offset) + .expect("enable selector should not fail"); + } + + row_offset += 1; + } + // Assign fixed cells + for (c, _) in ctx.constant_equality_constraints.iter() { + if assigned_constants.get(c).is_none() { + #[cfg(feature = "halo2-axiom")] + let cell = + region.assign_fixed(config.constants[fixed_col], fixed_offset, c); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_fixed( + || "", + config.constants[fixed_col], + fixed_offset, + || Value::known(*c), + ) + .unwrap() + .cell(); + assigned_constants.insert(*c, cell); + fixed_col += 1; + if fixed_col >= config.constants.len() { + fixed_col = 0; + fixed_offset += 1; + } + } + } + } + break_points.push(break_point); + } + // we constrain equality constraints in a separate loop in case context `i` contains references to context `j` for `j > i` + for (phase, threads) in self.threads.iter().enumerate() { + let mut lookup_offset = 0; + let mut lookup_col = 0; + for ctx in threads { + for (left, right) in &ctx.advice_equality_constraints { + let (left, _) = assigned_advices[&(left.context_id, left.offset)]; + let (right, _) = assigned_advices[&(right.context_id, right.offset)]; + #[cfg(feature = "halo2-axiom")] + region.constrain_equal(&left, &right); + #[cfg(not(feature = "halo2-axiom"))] + region.constrain_equal(left, right).unwrap(); + } + for (left, right) in &ctx.constant_equality_constraints { + let left = assigned_constants[left]; + let (right, _) = assigned_advices[&(right.context_id, right.offset)]; + #[cfg(feature = "halo2-axiom")] + region.constrain_equal(&left, &right); + #[cfg(not(feature = "halo2-axiom"))] + region.constrain_equal(left, right).unwrap(); + } + + for advice in &ctx.cells_to_lookup { + // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled + let cell = advice.cell.unwrap(); + let (acell, row_offset) = assigned_advices[&(cell.context_id, cell.offset)]; + if let Some(q_lookup) = q_lookup[phase] { + assert_eq!(config.basic_gates[phase].len(), 1); + q_lookup.enable(region, row_offset).unwrap(); + continue; + } + // otherwise, we copy the advice value to the special lookup_advice columns + if lookup_offset >= max_rows { + lookup_offset = 0; + lookup_col += 1; + } + let value = advice.value; + let value = if use_unknown { Value::unknown() } else { Value::known(value) }; + let column = lookup_advice[phase][lookup_col]; + + #[cfg(feature = "halo2-axiom")] + { + let bcell = region.assign_advice(column, lookup_offset, value); + region.constrain_equal(&acell, bcell.cell()); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let bcell = region + .assign_advice(|| "", column, lookup_offset, || value) + .expect("assign_advice should not fail") + .cell(); + region.constrain_equal(acell, bcell).unwrap(); + } + lookup_offset += 1; + } + } + } + KeygenAssignments { assigned_advices, assigned_constants, break_points } + } +} + +/// Assigns threads to regions of advice column. +/// +/// Uses preprocessed `break_points` to assign where to divide the advice column into a new column for each thread. +/// +/// Performs only witness generation, so should only be evoked during proving not keygen. +/// +/// Assumes that the advice columns are already assigned. +/// * `phase` - the phase of the circuit +/// * `threads` - [Vec] threads to assign +/// * `config` - immutable reference to the configuration of the circuit +/// * `lookup_advice` - Slice of lookup advice columns +/// * `region` - mutable reference to the region to assign threads to +/// * `break_points` - the preprocessed break points for the threads +pub fn assign_threads_in( + phase: usize, + threads: Vec>, + config: &FlexGateConfig, + lookup_advice: &[Column], + region: &mut Region, + break_points: ThreadBreakPoints, +) { + if config.basic_gates[phase].is_empty() { + assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); + return; + } + + let mut break_points = break_points.into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = config.basic_gates[phase][gate_index].value; + let mut row_offset = 0; + + let mut lookup_offset = 0; + let mut lookup_advice = lookup_advice.iter(); + let mut lookup_column = lookup_advice.next(); + for ctx in threads { + // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns + if lookup_column.is_some() { + for advice in ctx.cells_to_lookup { + if lookup_offset >= config.max_rows { + lookup_offset = 0; + lookup_column = lookup_advice.next(); + } + // Assign the lookup advice values to the lookup_column + let value = advice.value; + let lookup_column = *lookup_column.unwrap(); + #[cfg(feature = "halo2-axiom")] + region.assign_advice(lookup_column, lookup_offset, Value::known(value)); + #[cfg(not(feature = "halo2-axiom"))] + region + .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) + .unwrap(); + + lookup_offset += 1; + } + } + // Assign advice values to the advice columns in each [Context] + for advice in ctx.advice { + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = config.basic_gates[phase][gate_index].value; + + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + } + + row_offset += 1; + } + } +} + +/// A Config struct defining the parameters for a FlexGate circuit. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FlexGateConfigParams { + /// The gate strategy used for the advice column of the circuit and applied at every row. + pub strategy: GateStrategy, + /// Security parameter `k` used for the keygen. + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of advice columns that do not have lookup enabled per phase + pub num_lookup_advice_per_phase: Vec, + /// The number of fixed columns per phase + pub num_fixed: usize, +} + +/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. +#[derive(Clone, Debug)] +pub struct GateCircuitBuilder { + /// The Thread Builder for the circuit + pub builder: RefCell>, // `RefCell` is just to trick circuit `synthesize` to take ownership of the inner builder + /// Break points for threads within the circuit + pub break_points: RefCell, // `RefCell` allows the circuit to record break points in a keygen call of `synthesize` for use in later witness gen +} + +impl GateCircuitBuilder { + /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to true. + pub fn keygen(builder: GateThreadBuilder) -> Self { + Self { builder: RefCell::new(builder.unknown(true)), break_points: RefCell::new(vec![]) } + } + + /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to false. + pub fn mock(builder: GateThreadBuilder) -> Self { + Self { builder: RefCell::new(builder.unknown(false)), break_points: RefCell::new(vec![]) } + } + + /// Creates a new [GateCircuitBuilder]. + pub fn prover( + builder: GateThreadBuilder, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self { builder: RefCell::new(builder), break_points: RefCell::new(break_points) } + } + + /// Synthesizes from the [GateCircuitBuilder] by populating the advice column and assigning new threads if witness generation is performed. + pub fn sub_synthesize( + &self, + gate: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + layouter: &mut impl Layouter, + ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { + let mut first_pass = SKIP_FIRST_PASS; + let mut assigned_advices = HashMap::new(); + layouter + .assign_region( + || "GateCircuitBuilder generated circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize + // If we are not performing witness generation only, we can skip the first pass and assign threads directly + if !self.builder.borrow().witness_gen_only { + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = self.builder.borrow().clone(); + for threads in builder.threads.iter().skip(1) { + assert!( + threads.is_empty(), + "GateCircuitBuilder only supports FirstPhase for now" + ); + } + let assignments = builder.assign_all( + gate, + lookup_advice, + q_lookup, + &mut region, + Default::default(), + ); + *self.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + } else { + // If we are only generating witness, we can skip the first pass and assign threads directly + let builder = self.builder.take(); + let break_points = self.break_points.take(); + for (phase, (threads, break_points)) in builder + .threads + .into_iter() + .zip(break_points.into_iter()) + .enumerate() + .take(1) + { + assign_threads_in( + phase, + threads, + gate, + lookup_advice.get(phase).unwrap_or(&vec![]), + &mut region, + break_points, + ); + } + } + Ok(()) + }, + ) + .unwrap(); + assigned_advices + } +} + +impl Circuit for GateCircuitBuilder { + type Config = FlexGateConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the circuit without withnesses filled in. + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config]. + fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase: _, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + self.sub_synthesize(&config, &[], &[], &mut layouter); + Ok(()) + } +} + +/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. +#[derive(Clone, Debug)] +pub struct RangeCircuitBuilder(pub GateCircuitBuilder); + +impl RangeCircuitBuilder { + /// Creates an instance of the [RangeCircuitBuilder] and executes in keygen mode. + pub fn keygen(builder: GateThreadBuilder) -> Self { + Self(GateCircuitBuilder::keygen(builder)) + } + + /// Creates a mock instance of the [RangeCircuitBuilder]. + pub fn mock(builder: GateThreadBuilder) -> Self { + Self(GateCircuitBuilder::mock(builder)) + } + + /// Creates an instance of the [RangeCircuitBuilder] and executes in prover mode. + pub fn prover( + builder: GateThreadBuilder, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self(GateCircuitBuilder::prover(builder, break_points)) + } +} + +impl Circuit for RangeCircuitBuilder { + type Config = RangeConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + let strategy = match strategy { + GateStrategy::Vertical => RangeStrategy::Vertical, + }; + let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); + RangeConfig::configure( + meta, + strategy, + &num_advice_per_phase, + &num_lookup_advice_per_phase, + num_fixed, + lookup_bits, + k, + ) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !config.q_lookup.iter().all(|q| q.is_none()) + { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); + Ok(()) + } +} + +/// Configuration with [`RangeConfig`] and a single public instance column. +#[derive(Clone, Debug)] +pub struct RangeWithInstanceConfig { + /// The underlying range configuration + pub range: RangeConfig, + /// The public instance column + pub instance: Column, +} + +/// This is an extension of [`RangeCircuitBuilder`] that adds support for public instances (aka public inputs+outputs) +/// +/// The intended design is that a [`GateThreadBuilder`] is populated and then produces some assigned instances, which are supplied as `assigned_instances` to this struct. +/// The [`Circuit`] implementation for this struct will then expose these instances and constrain them using the Halo2 API. +#[derive(Clone, Debug)] +pub struct RangeWithInstanceCircuitBuilder { + /// The underlying circuit builder + pub circuit: RangeCircuitBuilder, + /// The assigned instances to expose publicly at the end of circuit synthesis + pub assigned_instances: Vec>, +} + +impl RangeWithInstanceCircuitBuilder { + /// See [`RangeCircuitBuilder::keygen`] + pub fn keygen( + builder: GateThreadBuilder, + assigned_instances: Vec>, + ) -> Self { + Self { circuit: RangeCircuitBuilder::keygen(builder), assigned_instances } + } + + /// See [`RangeCircuitBuilder::mock`] + pub fn mock(builder: GateThreadBuilder, assigned_instances: Vec>) -> Self { + Self { circuit: RangeCircuitBuilder::mock(builder), assigned_instances } + } + + /// See [`RangeCircuitBuilder::prover`] + pub fn prover( + builder: GateThreadBuilder, + assigned_instances: Vec>, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self { circuit: RangeCircuitBuilder::prover(builder, break_points), assigned_instances } + } + + /// Creates a new instance of the [RangeWithInstanceCircuitBuilder]. + pub fn new(circuit: RangeCircuitBuilder, assigned_instances: Vec>) -> Self { + Self { circuit, assigned_instances } + } + + /// Calls [`GateThreadBuilder::config`] + pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { + self.circuit.0.builder.borrow().config(k as usize, minimum_rows) + } + + /// Gets the break points of the circuit. + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.circuit.0.break_points.borrow().clone() + } + + /// Gets the number of instances. + pub fn instance_count(&self) -> usize { + self.assigned_instances.len() + } + + /// Gets the instances. + pub fn instance(&self) -> Vec { + self.assigned_instances.iter().map(|v| *v.value()).collect() + } +} + +impl Circuit for RangeWithInstanceCircuitBuilder { + type Config = RangeWithInstanceConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let range = RangeCircuitBuilder::configure(meta); + let instance = meta.instance_column(); + meta.enable_equality(instance); + RangeWithInstanceConfig { range, instance } + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances + let range = config.range; + let circuit = &self.circuit.0; + // only load lookup table if we are actually doing lookups + if range.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !range.q_lookup.iter().all(|q| q.is_none()) + { + range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + // we later `take` the builder, so we need to save this value + let witness_gen_only = circuit.builder.borrow().witness_gen_only(); + let assigned_advices = circuit.sub_synthesize( + &range.gate, + &range.lookup_advice, + &range.q_lookup, + &mut layouter, + ); + + if !witness_gen_only { + // expose public instances + let mut layouter = layouter.namespace(|| "expose"); + for (i, instance) in self.assigned_instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let (cell, _) = assigned_advices + .get(&(cell.context_id, cell.offset)) + .expect("instance not assigned"); + layouter.constrain_instance(*cell, config.instance, i); + } + } + Ok(()) + } +} + +/// Defines stage of the circuit builder. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CircuitBuilderStage { + /// Keygen phase + Keygen, + /// Prover Circuit + Prover, + /// Mock Circuit + Mock, +} diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs new file mode 100644 index 00000000..ab9171d5 --- /dev/null +++ b/halo2-base/src/gates/builder/parallelize.rs @@ -0,0 +1,38 @@ +use itertools::Itertools; +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::GateThreadBuilder; + +/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. +pub fn parallelize_in( + phase: usize, + builder: &mut GateThreadBuilder, + input: Vec, + f: FR, +) -> Vec +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context, T) -> R + Send + Sync, +{ + let witness_gen_only = builder.witness_gen_only(); + // to prevent concurrency issues with context id, we generate all the ids first + let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); + let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input + .into_par_iter() + .zip(ctx_ids.into_par_iter()) + .map(|(input, ctx_id)| { + // create new context + let mut ctx = Context::new(witness_gen_only, ctx_id); + let output = f(&mut ctx, input); + (output, ctx) + }) + .unzip(); + // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused + builder.threads[phase].append(&mut ctxs); + + outputs +} diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index fdbd8652..1907521e 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -1,57 +1,56 @@ -use super::{ - AssignedValue, Context, GateInstructions, - QuantumCell::{self, Constant, Existing, Witness}, -}; -use crate::halo2_proofs::{ - circuit::Value, - plonk::{ - Advice, Assigned, Column, ConstraintSystem, FirstPhase, Fixed, SecondPhase, Selector, - ThirdPhase, +use crate::{ + halo2_proofs::{ + plonk::{ + Advice, Assigned, Column, ConstraintSystem, FirstPhase, Fixed, SecondPhase, Selector, + ThirdPhase, + }, + poly::Rotation, }, - poly::Rotation, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::{self, Constant, Existing, Witness, WitnessFraction}, }; -use crate::utils::ScalarField; -use itertools::Itertools; +use serde::{Deserialize, Serialize}; use std::{ - iter::{self, once}, + iter::{self}, marker::PhantomData, }; -/// The maximum number of phases halo2 currently supports +/// The maximum number of phases in halo2. pub const MAX_PHASE: usize = 3; -#[derive(Clone, Copy, Debug, PartialEq)] +/// Specifies the gate strategy for the gate chip +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub enum GateStrategy { + /// # Vertical Gate Strategy: + /// `q_0 * (a + b * c - d) = 0` + /// where + /// * a = value[0], b = value[1], c = value[2], d = value[3] + /// * q = q_enable[0] + /// * q is either 0 or 1 so this is just a simple selector + /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. Vertical, - PlonkPlus, } +/// A configuration for a basic gate chip describing the selector, and advice column values. #[derive(Clone, Debug)] pub struct BasicGateConfig { + /// [Selector] column that stores selector values that are used to activate gates in the advice column. // `q_enable` will have either length 1 or 2, depending on the strategy - - // If strategy is Vertical, then this is the basic vertical gate - // `q_0 * (a + b * c - d) = 0` - // where - // * a = value[0], b = value[1], c = value[2], d = value[3] - // * q = q_enable[0] - // * q_i is either 0 or 1 so this is just a simple selector - // We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate - - // If strategy is PlonkPlus, then this is a slightly extended version of the vanilla plonk (vertical) gate - // `q_io * (a + q_left * b + q_right * c + q_mul * b * c - d)` - // where - // * a = value[0], b = value[1], c = value[2], d = value[3] - // * the q_{} can be any fixed values in F, placed in two fixed columns - // * it is crucial that q_io goes in its own selector column! we need it to be 0, 1 to turn on/off the gate pub q_enable: Selector, - pub q_enable_plus: Vec>, - // one column to store the inputs and outputs of the gate + /// [Column] that stores the advice values of the gate. pub value: Column, + /// Marker for the field type. _marker: PhantomData, } impl BasicGateConfig { + /// Instantiates a new [BasicGateConfig]. + /// + /// Assumes `phase` is in the range [0, MAX_PHASE). + /// * `meta`: [ConstraintSystem] used for the gate + /// * `strategy`: The [GateStrategy] to use for the gate + /// * `phase`: The phase to add the gate to pub fn configure(meta: &mut ConstraintSystem, strategy: GateStrategy, phase: u8) -> Self { let value = match phase { 0 => meta.advice_column_in(FirstPhase), @@ -65,22 +64,17 @@ impl BasicGateConfig { match strategy { GateStrategy::Vertical => { - let config = Self { q_enable, q_enable_plus: vec![], value, _marker: PhantomData }; + let config = Self { q_enable, value, _marker: PhantomData }; config.create_gate(meta); config } - GateStrategy::PlonkPlus => { - let q_aux = meta.fixed_column(); - let config = - Self { q_enable, q_enable_plus: vec![q_aux], value, _marker: PhantomData }; - config.create_plonk_gate(meta); - config - } } } + /// Wrapper for [ConstraintSystem].create_gate(name, meta) creates a gate form [q * (a + b * c - out)]. + /// * `meta`: [ConstraintSystem] used for the gate fn create_gate(&self, meta: &mut ConstraintSystem) { - meta.create_gate("1 column a * b + c = out", |meta| { + meta.create_gate("1 column a + b * c = out", |meta| { let q = meta.query_selector(self.q_enable); let a = meta.query_advice(self.value, Rotation::cur()); @@ -91,53 +85,41 @@ impl BasicGateConfig { vec![q * (a + b * c - out)] }) } - - fn create_plonk_gate(&self, meta: &mut ConstraintSystem) { - meta.create_gate("plonk plus", |meta| { - // q_io * (a + q_left * b + q_right * c + q_mul * b * c - d) - // the gate is turned "off" as long as q_io = 0 - let q_io = meta.query_selector(self.q_enable); - - let q_mul = meta.query_fixed(self.q_enable_plus[0], Rotation::cur()); - let q_left = meta.query_fixed(self.q_enable_plus[0], Rotation::next()); - let q_right = meta.query_fixed(self.q_enable_plus[0], Rotation(2)); - - let a = meta.query_advice(self.value, Rotation::cur()); - let b = meta.query_advice(self.value, Rotation::next()); - let c = meta.query_advice(self.value, Rotation(2)); - let d = meta.query_advice(self.value, Rotation(3)); - - vec![q_io * (a + q_left * b.clone() + q_right * c.clone() + q_mul * b * c - d)] - }) - } } +/// Defines a configuration for a flex gate chip describing the selector, and advice column values for the chip. #[derive(Clone, Debug)] pub struct FlexGateConfig { + /// A [Vec] of [BasicGateConfig] that define gates for each halo2 phase. pub basic_gates: [Vec>; MAX_PHASE], - // `constants` is a vector of fixed columns for allocating constant values + /// A [Vec] of [Fixed] [Column]s for allocating constant values. pub constants: Vec>, + /// Number of advice columns for each halo2 phase. pub num_advice: [usize; MAX_PHASE], - strategy: GateStrategy, - gate_len: usize, - pub context_id: usize, + /// [GateStrategy] for the flex gate. + _strategy: GateStrategy, + /// Max number of rows in flex gate. pub max_rows: usize, - - pub pow_of_two: Vec, - /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements - pub field_element_cache: Vec, } impl FlexGateConfig { + /// Generates a new [FlexGateConfig] + /// + /// Assumes `num_advice` is a [Vec] of length [MAX_PHASE] + /// * `meta`: [ConstraintSystem] of the circuit + /// * `strategy`: [GateStrategy] of the flex gate + /// * `num_advice`: Number of [Advice] [Column]s in each phase + /// * `num_fixed`: Number of [Fixed] [Column]s in each phase + /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, strategy: GateStrategy, num_advice: &[usize], num_fixed: usize, - context_id: usize, // log2_ceil(# rows in circuit) circuit_degree: usize, ) -> Self { + // create fixed (constant) columns and enable equality constraints let mut constants = Vec::with_capacity(num_fixed); for _i in 0..num_fixed { let c = meta.fixed_column(); @@ -145,17 +127,9 @@ impl FlexGateConfig { // meta.enable_constant(c); constants.push(c); } - let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); - let two = F::from(2); - pow_of_two.push(F::one()); - pow_of_two.push(two); - for _ in 2..F::NUM_BITS { - pow_of_two.push(two * pow_of_two.last().unwrap()); - } - let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); match strategy { - GateStrategy::Vertical | GateStrategy::PlonkPlus => { + GateStrategy::Vertical => { let mut basic_gates = [(); MAX_PHASE].map(|_| vec![]); let mut num_advice_array = [0usize; MAX_PHASE]; for ((phase, &num_columns), gates) in @@ -170,528 +144,879 @@ impl FlexGateConfig { basic_gates, constants, num_advice: num_advice_array, - strategy, - gate_len: 4, - context_id, + _strategy: strategy, /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created max_rows: (1 << circuit_degree) - meta.minimum_rows(), - pow_of_two, - field_element_cache, } } } } +} + +/// Trait that defines basic arithmetic operations for a gate. +pub trait GateInstructions { + /// Returns the [GateStrategy] for the gate. + fn strategy(&self) -> GateStrategy; + + /// Returns a slice of the [ScalarField] field elements 2^i for i in 0..F::NUM_BITS. + fn pow_of_two(&self) -> &[F]; - pub fn inner_product_simple<'a, 'b: 'a>( + /// Converts a [u64] into a scalar field element [ScalarField]. + fn get_field_element(&self, n: u64) -> F; + + /// Constrains and returns `a + b * 1 = out`. + /// + /// Defines a vertical gate of form | a | b | 1 | a + b | where (a + b) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to add to 'a` + fn add( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - let mut sum; - let mut a = a.into_iter(); - let mut b = b.into_iter().peekable(); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = *a.value() + b.value(); + ctx.assign_region_last([a, b, Constant(F::one()), Witness(out_val)], [0]) + } - let cells = if matches!(b.peek(), Some(Constant(c)) if c == &F::one()) { - b.next(); - let start_a = a.next().unwrap(); - sum = start_a.value().copied(); - iter::once(start_a) - } else { - sum = Value::known(F::zero()); - iter::once(Constant(F::zero())) - } - .chain(a.zip(b).flat_map(|(a, b)| { - sum = sum + a.value().zip(b.value()).map(|(a, b)| *a * b); - [a, b, Witness(sum)] - })); + /// Constrains and returns `a + b * (-1) = out`. + /// + /// Defines a vertical gate of form | a - b | b | 1 | a |, where (a - b) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to subtract from 'a' + fn sub( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = *a.value() - b.value(); + // slightly better to not have to compute -F::one() since F::one() is cached + ctx.assign_region([Witness(out_val), b, Constant(F::one()), a], [0]); + ctx.get(-4) + } - let (lo, hi) = cells.size_hint(); - debug_assert_eq!(Some(lo), hi); - let len = lo / 3; - let gate_offsets = (0..len).map(|i| (3 * i as isize, None)); - self.assign_region_last(ctx, cells, gate_offsets) + /// Constrains and returns `a * (-1) = out`. + /// + /// Defines a vertical gate of form | a | -a | 1 | 0 |, where (-a) = out. + /// * `ctx`: the [Context] to add the constraints to + /// * `a`: [QuantumCell] value to negate + fn neg(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + let a = a.into(); + let out_val = -*a.value(); + ctx.assign_region([a, Witness(out_val), Constant(F::one()), Constant(F::zero())], [0]); + ctx.get(-3) } - pub fn inner_product_simple_with_assignments<'a, 'b: 'a>( + /// Constrains and returns `0 + a * b = out`. + /// + /// Defines a vertical gate of form | 0 | a | b | a * b |, where (a * b) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to multiply 'a' by + fn mul( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> (Vec>, AssignedValue<'b, F>) { - let mut sum; - let mut a = a.into_iter(); - let mut b = b.into_iter().peekable(); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = *a.value() * b.value(); + ctx.assign_region_last([Constant(F::zero()), a, b, Witness(out_val)], [0]) + } - let cells = if matches!(b.peek(), Some(Constant(c)) if c == &F::one()) { - b.next(); - let start_a = a.next().unwrap(); - sum = start_a.value().copied(); - iter::once(start_a) - } else { - sum = Value::known(F::zero()); - iter::once(Constant(F::zero())) - } - .chain(a.zip(b).flat_map(|(a, b)| { - sum = sum + a.value().zip(b.value()).map(|(a, b)| *a * b); - [a, b, Witness(sum)] - })); + /// Constrains and returns `a * b + c = out`. + /// + /// Defines a vertical gate of form | c | a | b | a * b + c |, where (a * b + c) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to multiply 'a' by + /// * `c`: [QuantumCell] value to add to 'a * b' + fn mul_add( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let c = c.into(); + let out_val = *a.value() * b.value() + c.value(); + ctx.assign_region_last([c, a, b, Witness(out_val)], [0]) + } - let (lo, hi) = cells.size_hint(); - debug_assert_eq!(Some(lo), hi); - let len = lo / 3; - let gate_offsets = (0..len).map(|i| (3 * i as isize, None)); - let mut assignments = self.assign_region(ctx, cells, gate_offsets); - let last = assignments.pop().unwrap(); - (assignments, last) + /// Constrains and returns `(1 - a) * b = b - a * b`. + /// + /// Defines a vertical gate of form | (1 - a) * b | a | b | b |, where (1 - a) * b = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to multiply 'a' by + fn mul_not( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = (F::one() - a.value()) * b.value(); + ctx.assign_region_smart([Witness(out_val), a, b, b], [0], [(2, 3)], []); + ctx.get(-4) + } + + /// Constrains that x is boolean (e.g. 0 or 1). + /// + /// Defines a vertical gate of form | 0 | x | x | x |. + /// * `ctx`: [Context] to add the constraints to + /// * `x`: [QuantumCell] value to constrain + fn assert_bit(&self, ctx: &mut Context, x: AssignedValue) { + ctx.assign_region([Constant(F::zero()), Existing(x), Existing(x), Existing(x)], [0]); } - fn inner_product_with_assignments<'a, 'b: 'a>( + /// Constrains and returns a / b = 0. + /// + /// Defines a vertical gate of form | 0 | b^1 * a | b | a |, where b^1 * a = out. + /// + /// Assumes `b != 0`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to divide 'a' by + fn div_unsafe( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> (Vec>, AssignedValue<'b, F>) { - // we will do special handling of the cases where one of the vectors is all constants - match self.strategy { - GateStrategy::PlonkPlus => { - let vec_a = a.into_iter().collect::>(); - let vec_b = b.into_iter().collect::>(); - if vec_b.iter().all(|b| matches!(b, Constant(_))) { - let vec_b: Vec = vec_b - .into_iter() - .map(|b| if let Constant(c) = b { c } else { unreachable!() }) - .collect(); - let k = vec_a.len(); - let gate_segment = self.gate_len - 2; - - // Say a = [a0, .., a4] for example - // Then to compute we use transpose of - // | 0 | a0 | a1 | x | a2 | a3 | y | a4 | 0 | | - // while letting q_enable equal transpose of - // | * | | | * | | | * | | | | - // | 0 | b0 | b1 | 0 | b2 | b3 | 0 | b4 | 0 | - - // we effect a small optimization if we know the constant b0 == 1: then instead of starting from 0 we can start from a0 - // this is a peculiarity of our plonk-plus gate - let start_ida: usize = (vec_b[0] == F::one()).into(); - if start_ida == 1 && k == 1 { - // this is just a0 * 1 = a0; you're doing nothing, why are you calling this function? - return (vec![], self.assign_region_last(ctx, vec_a, vec![])); - } - let k_chunks = (k - start_ida + gate_segment - 1) / gate_segment; - let mut cells = Vec::with_capacity(1 + (gate_segment + 1) * k_chunks); - let mut gate_offsets = Vec::with_capacity(k_chunks); - let mut running_sum = - if start_ida == 1 { vec_a[0].clone() } else { Constant(F::zero()) }; - cells.push(running_sum.clone()); - for i in 0..k_chunks { - let window = (start_ida + i * gate_segment) - ..std::cmp::min(k, start_ida + (i + 1) * gate_segment); - // we add a 0 at the start for q_mul = 0 - let mut c_window = [&[F::zero()], &vec_b[window.clone()]].concat(); - c_window.extend((c_window.len()..(gate_segment + 1)).map(|_| F::zero())); - // c_window should have length gate_segment + 1 - gate_offsets.push(( - (i * (gate_segment + 1)) as isize, - Some(c_window.try_into().expect("q_coeff should be correct len")), - )); - - cells.extend(window.clone().map(|j| vec_a[j].clone())); - cells.extend((window.len()..gate_segment).map(|_| Constant(F::zero()))); - running_sum = Witness( - window.into_iter().fold(running_sum.value().copied(), |sum, j| { - sum + Value::known(vec_b[j]) * vec_a[j].value() - }), - ); - cells.push(running_sum.clone()); - } - let mut assignments = self.assign_region(ctx, cells, gate_offsets); - let last = assignments.pop().unwrap(); - (assignments, last) - } else if vec_a.iter().all(|a| matches!(a, Constant(_))) { - self.inner_product_with_assignments(ctx, vec_b, vec_a) - } else { - self.inner_product_simple_with_assignments(ctx, vec_a, vec_b) - } - } - _ => self.inner_product_simple_with_assignments(ctx, a, b), - } + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + // TODO: if really necessary, make `c` of type `Assigned` + // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort + let c = b.value().invert().unwrap() * a.value(); + ctx.assign_region([Constant(F::zero()), Witness(c), b, a], [0]); + ctx.get(-3) } -} -impl GateInstructions for FlexGateConfig { - fn strategy(&self) -> GateStrategy { - self.strategy + /// Constrains that `a` is equal to `constant` value. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `constant`: constant value to constrain `a` to be equal to + fn assert_is_const(&self, ctx: &mut Context, a: &AssignedValue, constant: &F) { + if !ctx.witness_gen_only { + ctx.constant_equality_constraints.push((*constant, a.cell.unwrap())); + } } - fn context_id(&self) -> usize { - self.context_id + + /// Constrains and returns the inner product of ``. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take inner product of `a` by + fn inner_product( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> AssignedValue + where + QA: Into>; + + /// Returns the inner product of `` and the last element of `a` now assigned, i.e. `(inner_product_, last_element_a)`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] of the circuit + /// * `a`: Iterator of [QuantumCell]s + /// * `b`: Iterator of [QuantumCell]s to take inner product of `a` by + fn inner_product_left_last( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, AssignedValue) + where + QA: Into>; + + /// Calculates and constrains the inner product. + /// + /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to calculate the partial sums of the inner product of `a` by. + fn inner_product_with_sums<'thread, QA>( + &self, + ctx: &'thread mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> Box> + 'thread> + where + QA: Into>; + + /// Constrains and returns the sum of [QuantumCell]'s in iterator `a`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values to sum + fn sum(&self, ctx: &mut Context, a: impl IntoIterator) -> AssignedValue + where + Q: Into>, + { + let mut a = a.into_iter().peekable(); + let start = a.next(); + if start.is_none() { + return ctx.load_zero(); + } + let start = start.unwrap().into(); + if a.peek().is_none() { + return ctx.assign_region_last([start], []); + } + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + + let mut sum = *start.value(); + let cells = iter::once(start).chain(a.flat_map(|a| { + let a = a.into(); + sum += a.value(); + [a, Constant(F::one()), Witness(sum)] + })); + ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } - fn pow_of_two(&self) -> &[F] { - &self.pow_of_two + + /// Calculates and constrains the sum of the elements of `a`. + /// + /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j]`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values to sum + fn partial_sums<'thread, Q>( + &self, + ctx: &'thread mut Context, + a: impl IntoIterator, + ) -> Box> + 'thread> + where + Q: Into>, + { + let mut a = a.into_iter().peekable(); + let start = a.next(); + if start.is_none() { + return Box::new(iter::once(ctx.load_zero())); + } + let start = start.unwrap().into(); + if a.peek().is_none() { + return Box::new(iter::once(ctx.assign_region_last([start], []))); + } + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + + let mut sum = *start.value(); + let cells = iter::once(start).chain(a.flat_map(|a| { + let a = a.into(); + sum += a.value(); + [a, Constant(F::one()), Witness(sum)] + })); + ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); + Box::new((0..=len).rev().map(|i| ctx.get(-1 - 3 * (i as isize)))) } - fn get_field_element(&self, n: u64) -> F { - let get = self.field_element_cache.get(n as usize); - if let Some(fe) = get { - *fe + + /// Calculates and constrains the accumulated product of 'a' and 'b' i.e. `x_i = b_1 * (a_1...a_{i - 1}) + /// + b_2 * (a_2...a_{i - 1}) + /// + ... + /// + b_i` + /// + /// Returns the assignment trace where `output[i]` is the running accumulated product x_i. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take the accumulated product of `a` by + fn accumulated_product( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> Vec> + where + QA: Into>, + QB: Into>, + { + let mut b = b.into_iter(); + let mut a = a.into_iter(); + let b_first = b.next(); + if let Some(b_first) = b_first { + let b_first = ctx.assign_region_last([b_first], []); + std::iter::successors(Some(b_first), |x| { + a.next().zip(b.next()).map(|(a, b)| self.mul_add(ctx, Existing(*x), a, b)) + }) + .collect() } else { - F::from(n) + vec![] } } - /// All indices in `gate_offsets` are with respect to `inputs` indices - /// * `gate_offsets` specifies indices to enable selector for the gate - /// * `gate_offsets` specifies (index, Option<[q_left, q_right, q_mul, q_const, q_out]>) - /// * second coordinate should only be set if using strategy PlonkPlus; if not set, default to [1, 0, 0] - /// * allow the index in `gate_offsets` to be negative in case we want to do advanced overlapping - /// * gate_index can either be set if you know the specific column you want to assign to, or None if you want to auto-select index - /// * only selects from advice columns in `ctx.current_phase` - // same as `assign_region` except you can specify the `phase` to assign in - fn assign_region_in<'a, 'b: 'a>( + + /// Constrains and returns the sum of products of `coeff * (a * b)` defined in `values` plus a variable `var` e.g. + /// `x = var + values[0].0 * (values[0].1 * values[0].2) + values[1].0 * (values[1].1 * values[1].2) + ... + values[n].0 * (values[n].1 * values[n].2)`. + /// * `ctx`: [Context] to add the constraints to. + /// * `values`: Iterator of tuples `(coeff, a, b)` where `coeff` is a field element, `a` and `b` are [QuantumCell]'s. + /// * `var`: [QuantumCell] that represents the value of a variable added to the sum. + fn sum_products_with_coeff_and_var( &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> Vec> { - // We enforce the pattern that you should assign everything in current phase at once and then move onto next phase - debug_assert_eq!(phase, ctx.current_phase()); - - let inputs = inputs.into_iter(); - let (len, hi) = inputs.size_hint(); - debug_assert_eq!(Some(len), hi); - // we index into `advice_alloc` twice so this assert should save a bound check - assert!(self.context_id < ctx.advice_alloc.len(), "context id out of bounds"); - - let (gate_index, row_offset) = { - let alloc = ctx.advice_alloc.get_mut(self.context_id).unwrap(); - - if alloc.1 + len >= ctx.max_rows { - alloc.1 = 0; - alloc.0 += 1; - } - *alloc + ctx: &mut Context, + values: impl IntoIterator, QuantumCell)>, + var: QuantumCell, + ) -> AssignedValue; + + /// Constrains and returns `a || b`, assuming `a` and `b` are boolean. + /// + /// Defines a vertical gate of form `| 1 - b | 1 | b | 1 | b | a | 1 - b | out |`, where `out = a + b - a * b`. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + fn or( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let not_b_val = F::one() - b.value(); + let out_val = *a.value() + b.value() - *a.value() * b.value(); + let cells = [ + Witness(not_b_val), + Constant(F::one()), + b, + Constant(F::one()), + b, + a, + Witness(not_b_val), + Witness(out_val), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() + } + + /// Constrains and returns `a & b`, assumeing `a` and `b` are boolean. + /// + /// Defines a vertical gate of form | 0 | a | b | out |, where out = a * b. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + fn and( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + self.mul(ctx, a, b) + } + + /// Constrains and returns `!a` assumeing `a` is boolean. + /// + /// Defines a vertical gate of form | 1 - a | a | 1 | 1 |, where 1 - a = out. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + fn not(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.sub(ctx, Constant(F::one()), a) + } + + /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. + /// + /// Defines a vertical gate of form `| 1 - sel | sel | 1 | a | 1 - sel | sel | 1 | b | out |`, where out = sel * a + (1 - sel) * b. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + /// * `sel`: [QuantumCell] that contains a boolean value. + fn select( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + sel: impl Into>, + ) -> AssignedValue; + + /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. + /// + /// Defines a vertical gate of form `| 1 - b c | b | c | 1 | a - 1 | 1 - b c | out | a - 1 | 1 | 1 | a |`, where out = a + b * c - a * b * c. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + /// * `c`: [QuantumCell] that contains a boolean value. + fn or_and( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue; + + /// Constrains and returns an indicator vector from a slice of boolean values, where `output[idx] = 1` iff idx = (the number represented by `bits` in binary little endian), otherwise `output[idx] = 0`. + /// * `ctx`: [Context] to add the constraints to + /// * `bits`: slice of [QuantumCell]'s that contains boolean values + /// + /// # Assumptions + /// * `bits` is non-empty + fn bits_to_indicator( + &self, + ctx: &mut Context, + bits: &[AssignedValue], + ) -> Vec> { + let k = bits.len(); + assert!(k > 0, "bits_to_indicator: bits must be non-empty"); + + // (inv_last_bit, last_bit) = (1, 0) if bits[k - 1] = 0 + let (inv_last_bit, last_bit) = { + ctx.assign_region( + [ + Witness(F::one() - bits[k - 1].value()), + Existing(bits[k - 1]), + Constant(F::one()), + Constant(F::one()), + ], + [0], + ); + (ctx.get(-4), ctx.get(-3)) }; + let mut indicator = Vec::with_capacity(2 * (1 << k) - 2); + let mut offset = 0; + indicator.push(inv_last_bit); + indicator.push(last_bit); + for (idx, bit) in bits.iter().rev().enumerate().skip(1) { + for old_idx in 0..(1 << idx) { + // inv_prod_val = (1 - bit) * indicator[offset + old_idx] + let inv_prod_val = (F::one() - bit.value()) * indicator[offset + old_idx].value(); + ctx.assign_region( + [ + Witness(inv_prod_val), + Existing(indicator[offset + old_idx]), + Existing(*bit), + Existing(indicator[offset + old_idx]), + ], + [0], + ); + indicator.push(ctx.get(-4)); - let basic_gate = self.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}")); - let column = basic_gate.value; - let assignments = inputs - .enumerate() - .map(|(i, input)| { - ctx.assign_cell( - input, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset + i, - #[cfg(feature = "halo2-pse")] - (phase as u8), - ) - }) - .collect::>(); - - for (i, q_coeff) in gate_offsets.into_iter() { - basic_gate - .q_enable - .enable(&mut ctx.region, (row_offset as isize + i) as usize) - .expect("enable selector should not fail"); - - if self.strategy == GateStrategy::PlonkPlus { - let q_coeff = q_coeff.unwrap_or([F::one(), F::zero(), F::zero()]); - for (j, q_coeff) in q_coeff.into_iter().enumerate() { - #[cfg(feature = "halo2-axiom")] - { - ctx.region.assign_fixed( - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - Assigned::Trivial(q_coeff), - ); - } - #[cfg(feature = "halo2-pse")] - { - ctx.region - .assign_fixed( - || "", - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - || Value::known(q_coeff), - ) - .unwrap(); - } - } + // prod = bit * indicator[offset + old_idx] + let prod = self.mul(ctx, Existing(indicator[offset + old_idx]), Existing(*bit)); + indicator.push(prod); } + offset += 1 << idx; } + indicator.split_off((1 << k) - 2) + } - ctx.advice_alloc[self.context_id].1 += assignments.len(); - - #[cfg(feature = "display")] - { - ctx.total_advice += assignments.len(); - } + /// Constrains and returns a [Vec] `indicator` of length `len`, where `indicator[i] == 1 if i == idx otherwise 0`, if `idx >= len` then `indicator` is all zeros. + /// + /// Assumes `len` is greater than 0. + /// * `ctx`: [Context] to add the constraints to + /// * `idx`: [QuantumCell] index of the indicator vector to be set to 1 + /// * `len`: length of the `indicator` vector + fn idx_to_indicator( + &self, + ctx: &mut Context, + idx: impl Into>, + len: usize, + ) -> Vec> { + let mut idx = idx.into(); + (0..len) + .map(|i| { + // need to use assigned idx after i > 0 so equality constraint holds + if i == 0 { + // unroll `is_zero` to make sure if `idx == Witness(_)` it is replaced by `Existing(_)` in later iterations + let x = idx.value(); + let (is_zero, inv) = if x.is_zero_vartime() { + (F::one(), Assigned::Trivial(F::one())) + } else { + (F::zero(), Assigned::Rational(F::one(), *x)) + }; + let cells = [ + Witness(is_zero), + idx, + WitnessFraction(inv), + Constant(F::one()), + Constant(F::zero()), + idx, + Witness(is_zero), + Constant(F::zero()), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (1, 5)], []); // note the two `idx` need to be constrained equal: (1, 5) + idx = Existing(ctx.get(-3)); // replacing `idx` with Existing cell so future loop iterations constrain equality of all `idx`s + ctx.get(-2) + } else { + self.is_equal(ctx, idx, Constant(self.get_field_element(i as u64))) + } + }) + .collect() + } - assignments + /// Constrains the inner product of `a` and `indicator` and returns `a[idx]` (e.g. the value of `a` at `idx`). + /// + /// Assumes that `a` and `indicator` are non-empty iterators of the same length, the values of `indicator` are boolean, + /// and that `indicator` has at most one `1` bit. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell]'s that contains field elements + /// * `indicator`: Iterator of [AssignedValue]'s where indicator[i] == 1 if i == `idx`, otherwise 0 + fn select_by_indicator( + &self, + ctx: &mut Context, + a: impl IntoIterator, + indicator: impl IntoIterator>, + ) -> AssignedValue + where + Q: Into>, + { + let mut sum = F::zero(); + let a = a.into_iter(); + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + + let cells = std::iter::once(Constant(F::zero())).chain( + a.zip(indicator.into_iter()).flat_map(|(a, ind)| { + let a = a.into(); + sum = if ind.value().is_zero_vartime() { sum } else { *a.value() }; + [a, Existing(ind), Witness(sum)] + }), + ); + ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } - fn assign_region_last_in<'a, 'b: 'a>( + /// Constrains and returns `cells[idx]` if `idx < cells.len()`, otherwise return 0. + /// + /// Assumes that `cells` and `idx` are non-empty iterators of the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `cells`: Iterator of [QuantumCell]s to select from + /// * `idx`: [QuantumCell] with value `idx` where `idx` is the index of the cell to be selected + fn select_from_idx( &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> AssignedValue<'b, F> { - // We enforce the pattern that you should assign everything in current phase at once and then move onto next phase - debug_assert_eq!(phase, ctx.current_phase()); - - let inputs = inputs.into_iter(); - let (len, hi) = inputs.size_hint(); - debug_assert_eq!(hi, Some(len)); - debug_assert_ne!(len, 0); - // we index into `advice_alloc` twice so this assert should save a bound check - assert!(self.context_id < ctx.advice_alloc.len(), "context id out of bounds"); - - let (gate_index, row_offset) = { - let alloc = ctx.advice_alloc.get_mut(self.context_id).unwrap(); - - if alloc.1 + len >= ctx.max_rows { - alloc.1 = 0; - alloc.0 += 1; - } - *alloc + ctx: &mut Context, + cells: impl IntoIterator, + idx: impl Into>, + ) -> AssignedValue + where + Q: Into>, + { + let cells = cells.into_iter(); + let (len, hi) = cells.size_hint(); + assert_eq!(Some(len), hi); + + let ind = self.idx_to_indicator(ctx, idx, len); + self.select_by_indicator(ctx, cells, ind) + } + + /// Constrains that a cell is equal to 0 and returns `1` if `a = 0`, otherwise `0`. + /// + /// Defines a vertical gate of form `| out | a | inv | 1 | 0 | a | out | 0 |`, where out = 1 if a = 0, otherwise out = 0. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value to be constrained + fn is_zero(&self, ctx: &mut Context, a: AssignedValue) -> AssignedValue { + let x = a.value(); + let (is_zero, inv) = if x.is_zero_vartime() { + (F::one(), Assigned::Trivial(F::one())) + } else { + (F::zero(), Assigned::Rational(F::one(), *x)) }; - let basic_gate = self.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}")); - let column = basic_gate.value; - let mut out = None; - for (i, input) in inputs.enumerate() { - out = Some(ctx.assign_cell( - input, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset + i, - #[cfg(feature = "halo2-pse")] - (phase as u8), - )); - } + let cells = [ + Witness(is_zero), + Existing(a), + WitnessFraction(inv), + Constant(F::one()), + Constant(F::zero()), + Existing(a), + Witness(is_zero), + Constant(F::zero()), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6)], []); + ctx.get(-2) + } + + /// Constrains that the value of two cells are equal: b - a = 0, returns `1` if `a = b`, otherwise `0`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to compare to `a` + fn is_equal( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let diff = self.sub(ctx, a, b); + self.is_zero(ctx, diff) + } - for (i, q_coeff) in gate_offsets.into_iter() { - basic_gate - .q_enable - .enable(&mut ctx.region, (row_offset as isize + i) as usize) - .expect("selector enable should not fail"); - - if self.strategy == GateStrategy::PlonkPlus { - let q_coeff = q_coeff.unwrap_or([F::one(), F::zero(), F::zero()]); - for (j, q_coeff) in q_coeff.into_iter().enumerate() { - #[cfg(feature = "halo2-axiom")] - { - ctx.region.assign_fixed( - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - Assigned::Trivial(q_coeff), - ); - } - #[cfg(feature = "halo2-pse")] - { - ctx.region - .assign_fixed( - || "", - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - || Value::known(q_coeff), - ) - .unwrap(); - } + /// Constrains and returns little-endian bit vector representation of `a`. + /// + /// Assumes `range_bits <= number of bits in a`. + /// * `a`: [QuantumCell] of the value to convert + /// * `range_bits`: range of bits needed to represent `a` + fn num_to_bits( + &self, + ctx: &mut Context, + a: AssignedValue, + range_bits: usize, + ) -> Vec>; + + /// Performs and constrains Lagrange interpolation on `coords` and evaluates the resulting polynomial at `x`. + /// + /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords) - 1` polynomial such that `f(x_i) = y_i` for all `i`. + /// + /// Returns: + /// (f(x), Prod_i(x - x_i)) + /// * `ctx`: [Context] to add the constraints to + /// * `coords`: immutable reference to a slice of tuples of [AssignedValue]s representing the points to interpolate over such that `coords[i] = (x_i, y_i)` + /// * `x`: x-coordinate of the point to evaluate `f` at + /// + /// # Assumptions + /// * `coords` is non-empty + fn lagrange_and_eval( + &self, + ctx: &mut Context, + coords: &[(AssignedValue, AssignedValue)], + x: AssignedValue, + ) -> (AssignedValue, AssignedValue) { + assert!(!coords.is_empty(), "coords should not be empty"); + let mut z = self.sub(ctx, Existing(x), Existing(coords[0].0)); + for coord in coords.iter().skip(1) { + let sub = self.sub(ctx, Existing(x), Existing(coord.0)); + z = self.mul(ctx, Existing(z), Existing(sub)); + } + let mut eval = None; + for i in 0..coords.len() { + // compute (x - x_i) * Prod_{j != i} (x_i - x_j) + let mut denom = self.sub(ctx, Existing(x), Existing(coords[i].0)); + for j in 0..coords.len() { + if i == j { + continue; } + let sub = self.sub(ctx, coords[i].0, coords[j].0); + denom = self.mul(ctx, denom, sub); } + // TODO: batch inversion + let is_zero = self.is_zero(ctx, denom); + self.assert_is_const(ctx, &is_zero, &F::zero()); + + // y_i / denom + let quot = self.div_unsafe(ctx, coords[i].1, denom); + eval = if let Some(eval) = eval { + let eval = self.add(ctx, eval, quot); + Some(eval) + } else { + Some(quot) + }; } + let out = self.mul(ctx, eval.unwrap(), z); + (out, z) + } +} - ctx.advice_alloc[self.context_id].1 += len; +/// A chip that implements the [GateInstructions] trait supporting basic arithmetic operations. +#[derive(Clone, Debug)] +pub struct GateChip { + /// The [GateStrategy] used when declaring gates. + strategy: GateStrategy, + /// The field elements 2^i for i in 0..F::NUM_BITS. + pub pow_of_two: Vec, + /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements. + pub field_element_cache: Vec, +} + +impl Default for GateChip { + fn default() -> Self { + Self::new(GateStrategy::Vertical) + } +} - #[cfg(feature = "display")] - { - ctx.total_advice += len; +impl GateChip { + /// Returns a new [GateChip] with the given [GateStrategy]. + pub fn new(strategy: GateStrategy) -> Self { + let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); + let two = F::from(2); + pow_of_two.push(F::one()); + pow_of_two.push(two); + for _ in 2..F::NUM_BITS { + pow_of_two.push(two * pow_of_two.last().unwrap()); } + let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); - out.unwrap() + Self { strategy, pow_of_two, field_element_cache } } - // Takes two vectors of `QuantumCell` and constrains a witness output to the inner product of `` - // outputs are (assignments except last, out_cell) - // Currently the only places `assignments` is used are: `num_to_bits, range_check, carry_mod, check_carry_mod_to_zero` - fn inner_product<'a, 'b: 'a>( + /// Calculates and constrains the inner product of ``. + /// + /// Returns `true` if `b` start with `Constant(F::one())`, and `false` otherwise. + /// + /// Assumes `a` and `b` are the same length. + /// * `ctx`: [Context] of the circuit + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take inner product of `a` by + fn inner_product_simple( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - // we will do special handling of the cases where one of the vectors is all constants - match self.strategy { - GateStrategy::PlonkPlus => { - let (_, out) = self.inner_product_with_assignments(ctx, a, b); - out - } - _ => self.inner_product_simple(ctx, a, b), + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> bool + where + QA: Into>, + { + let mut sum; + let mut a = a.into_iter(); + let mut b = b.into_iter().peekable(); + + let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::one()); + let cells = if b_starts_with_one { + b.next(); + let start_a = a.next().unwrap().into(); + sum = *start_a.value(); + iter::once(start_a) + } else { + sum = F::zero(); + iter::once(Constant(F::zero())) } + .chain(a.zip(b).flat_map(|(a, b)| { + let a = a.into(); + sum += *a.value() * b.value(); + [a, b, Witness(sum)] + })); + + if ctx.witness_gen_only() { + ctx.assign_region(cells, vec![]); + } else { + let cells = cells.collect::>(); + let lo = cells.len(); + let len = lo / 3; + ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); + }; + b_starts_with_one } +} - fn inner_product_with_sums<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> Box> + 'b> { - let mut b = b.into_iter().peekable(); - let flag = matches!(b.peek(), Some(&Constant(c)) if c == F::one()); - let (assignments_without_last, last) = - self.inner_product_simple_with_assignments(ctx, a, b); - if flag { - Box::new(assignments_without_last.into_iter().step_by(3).chain(once(last))) +impl GateInstructions for GateChip { + /// Returns the [GateStrategy] the [GateChip]. + fn strategy(&self) -> GateStrategy { + self.strategy + } + + /// Returns a slice of the [ScalarField] elements 2i for i in 0..F::NUM_BITS. + fn pow_of_two(&self) -> &[F] { + &self.pow_of_two + } + + /// Returns the the value of `n` as a [ScalarField] element. + /// * `n`: the [u64] value to convert + fn get_field_element(&self, n: u64) -> F { + let get = self.field_element_cache.get(n as usize); + if let Some(fe) = get { + *fe } else { - // in this case the first assignment is 0 so we skip it - Box::new(assignments_without_last.into_iter().step_by(3).skip(1).chain(once(last))) + F::from(n) } } - fn inner_product_left<'a, 'b: 'a>( + /// Constrains and returns the inner product of ``. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take inner product of `a` by + fn inner_product( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - a_assigned: &mut Vec>, - ) -> AssignedValue<'b, F> { - match self.strategy { - GateStrategy::PlonkPlus => { - let a = a.into_iter(); - let (len, _) = a.size_hint(); - let (assignments, acc) = self.inner_product_with_assignments(ctx, a, b); - let mut assignments = assignments.into_iter(); - a_assigned.clear(); - assert!(a_assigned.capacity() >= len); - a_assigned.extend( - iter::once(assignments.next().unwrap()) - .chain( - assignments - .chunks(3) - .into_iter() - .flat_map(|chunk| chunk.into_iter().take(2)), - ) - .take(len), - ); - acc - } - _ => { - let mut a = a.into_iter(); - let mut b = b.into_iter().peekable(); - let (len, hi) = b.size_hint(); - debug_assert_eq!(Some(len), hi); - // we do not use `assign_region` and implement directly to avoid `collect`ing the vector of assignments - let phase = ctx.current_phase(); - assert!(self.context_id < ctx.advice_alloc.len(), "context id out of bounds"); - - let (gate_index, mut row_offset) = { - let alloc = ctx.advice_alloc.get_mut(self.context_id).unwrap(); - if alloc.1 + 3 * len + 1 >= ctx.max_rows { - alloc.1 = 0; - alloc.0 += 1; - } - *alloc - }; - let basic_gate = self.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}")); - let column = basic_gate.value; - let q_enable = basic_gate.q_enable; - - let mut right_one = false; - let start = ctx.assign_cell( - if matches!(b.peek(), Some(&Constant(x)) if x == F::one()) { - right_one = true; - b.next(); - a.next().unwrap() - } else { - Constant(F::zero()) - }, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset, - #[cfg(feature = "halo2-pse")] - (phase as u8), - ); - - row_offset += 1; - let mut acc = start.value().copied(); - a_assigned.clear(); - assert!(a_assigned.capacity() >= len); - if right_one { - a_assigned.push(start); - } - let mut last = None; - - for (a, b) in a.zip(b) { - q_enable - .enable(&mut ctx.region, row_offset - 1) - .expect("enable selector should not fail"); - - acc = acc + a.value().zip(b.value()).map(|(a, b)| *a * b); - let [a, _, c] = [(a, 0), (b, 1), (Witness(acc), 2)].map(|(qcell, idx)| { - ctx.assign_cell( - qcell, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset + idx, - #[cfg(feature = "halo2-pse")] - (phase as u8), - ) - }); - last = Some(c); - row_offset += 3; - a_assigned.push(a); - } - ctx.advice_alloc[self.context_id].1 = row_offset; + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> AssignedValue + where + QA: Into>, + { + self.inner_product_simple(ctx, a, b); + ctx.last().unwrap() + } - #[cfg(feature = "display")] - { - ctx.total_advice += 3 * (len - usize::from(right_one)) + 1; - } - last.unwrap_or_else(|| a_assigned[0].clone()) + /// Returns the inner product of `` and returns a tuple of the last item of `a` after it is assigned and the item to its left `(left_a, last_a)`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] of the circuit + /// * `a`: Iterator of [QuantumCell]s + /// * `b`: Iterator of [QuantumCell]s to take inner product of `a` by + fn inner_product_left_last( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, AssignedValue) + where + QA: Into>, + { + let a = a.into_iter(); + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + let a_last = if b_starts_with_one { + if len == 1 { + ctx.get(row_offset as isize) + } else { + ctx.get((row_offset + 1 + 3 * (len - 2)) as isize) } + } else { + ctx.get((row_offset + 1 + 3 * (len - 1)) as isize) + }; + (ctx.last().unwrap(), a_last) + } + + /// Calculates and constrains the inner product. + /// + /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to calculate the partial sums of the inner product of `a` by + fn inner_product_with_sums<'thread, QA>( + &self, + ctx: &'thread mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> Box> + 'thread> + where + QA: Into>, + { + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + if b_starts_with_one { + Box::new((row_offset..ctx.advice.len()).step_by(3).map(|i| ctx.get(i as isize))) + } else { + // in this case the first assignment is 0 so we skip it + Box::new((row_offset..ctx.advice.len()).step_by(3).skip(1).map(|i| ctx.get(i as isize))) } } - fn sum_products_with_coeff_and_var<'a, 'b: 'a>( + /// Constrains and returns the sum of products of `coeff * (a * b)` defined in `values` plus a variable `var` e.g. + /// `x = var + values[0].0 * (values[0].1 * values[0].2) + values[1].0 * (values[1].1 * values[1].2) + ... + values[n].0 * (values[n].1 * values[n].2)`. + /// * `ctx`: [Context] to add the constraints to + /// * `values`: Iterator of tuples `(coeff, a, b)` where `coeff` is a field element, `a` and `b` are [QuantumCell]'s + /// * `var`: [QuantumCell] that represents the value of a variable added to the sum + fn sum_products_with_coeff_and_var( &self, - ctx: &mut Context<'_, F>, - values: impl IntoIterator, QuantumCell<'a, 'b, F>)>, - var: QuantumCell<'a, 'b, F>, - ) -> AssignedValue<'b, F> { - // TODO: optimize + ctx: &mut Context, + values: impl IntoIterator, QuantumCell)>, + var: QuantumCell, + ) -> AssignedValue { + // TODO: optimizer match self.strategy { - GateStrategy::PlonkPlus => { - let mut cells = Vec::new(); - let mut gate_offsets = Vec::new(); - let mut acc = var.value().copied(); - cells.push(var); - for (i, (c, a, b)) in values.into_iter().enumerate() { - acc = acc + Value::known(c) * a.value() * b.value(); - cells.append(&mut vec![a, b, Witness(acc)]); - gate_offsets.push((3 * i as isize, Some([c, F::zero(), F::zero()]))); - } - self.assign_region_last(ctx, cells, gate_offsets) - } GateStrategy::Vertical => { + // Create an iterator starting with `var` and let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::one()))) .chain(values.into_iter().filter_map(|(c, va, vb)| { if c == F::one() { Some((va, vb)) } else if c != F::zero() { let prod = self.mul(ctx, va, vb); - Some((QuantumCell::ExistingOwned(prod), Constant(c))) + Some((QuantumCell::Existing(prod), Constant(c))) } else { None } @@ -702,74 +1027,67 @@ impl GateInstructions for FlexGateConfig { } } - /// assumes sel is boolean - /// returns - /// a * sel + b * (1 - sel) - fn select<'v>( + /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. + /// + /// Defines a vertical gate of form `| 1 - sel | sel | 1 | a | 1 - sel | sel | 1 | b | out |`, where out = sel * a + (1 - sel) * b. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] that contains a boolean value + /// * `b`: [QuantumCell] that contains a boolean value + /// * `sel`: [QuantumCell] that contains a boolean value + fn select( &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - sel: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let diff_val: Value = a.value().zip(b.value()).map(|(a, b)| *a - b); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + sel: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let sel = sel.into(); + let diff_val = *a.value() - b.value(); let out_val = diff_val * sel.value() + b.value(); match self.strategy { // | a - b | 1 | b | a | // | b | sel | a - b | out | GateStrategy::Vertical => { - let cells = vec![ + let cells = [ Witness(diff_val), Constant(F::one()), - b.clone(), + b, a, b, sel, Witness(diff_val), Witness(out_val), ]; - let mut assigned_cells = - self.assign_region_smart(ctx, cells, vec![0, 4], vec![(0, 6), (2, 4)], vec![]); - assigned_cells.pop().unwrap() - } - // | 0 | a | a - b | b | sel | a - b | out | - // selectors - // | 1 | 0 | 0 | 1 | 0 | 0 - // | 0 | 1 | -1 | 1 | 0 | 0 - GateStrategy::PlonkPlus => { - let mut assignments = self.assign_region( - ctx, - vec![ - Constant(F::zero()), - a, - Witness(diff_val), - b, - sel, - Witness(diff_val), - Witness(out_val), - ], - vec![(0, Some([F::zero(), F::one(), -F::one()])), (3, None)], - ); - ctx.region.constrain_equal(assignments[2].cell(), assignments[5].cell()); - assignments.pop().unwrap() + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() } } } - /// returns: a || (b && c) - // | 1 - b c | b | c | 1 | a - 1 | 1 - b c | out | a - 1 | 1 | 1 | a | - fn or_and<'v>( + /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. + /// + /// Defines a vertical gate of form `| 1 - b c | b | c | 1 | a - 1 | 1 - b c | out | a - 1 | 1 | 1 | a |`, where out = a + b * c - a * b * c. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] that contains a boolean value + /// * `b`: [QuantumCell] that contains a boolean value + /// * `c`: [QuantumCell] that contains a boolean value + fn or_and( &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - c: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let bc_val = b.value().zip(c.value()).map(|(b, c)| *b * c); - let not_bc_val = bc_val.map(|x| F::one() - x); - let not_a_val = a.value().map(|x| *x - F::one()); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let c = c.into(); + let bc_val = *b.value() * c.value(); + let not_bc_val = F::one() - bc_val; + let not_a_val = *a.value() - F::one(); let out_val = bc_val + a.value() - bc_val * a.value(); - let cells = vec![ + let cells = [ Witness(not_bc_val), b, c, @@ -782,52 +1100,39 @@ impl GateInstructions for FlexGateConfig { Constant(F::one()), a, ]; - let assigned_cells = - self.assign_region_smart(ctx, cells, vec![0, 3, 7], vec![(4, 7), (0, 5)], vec![]); - assigned_cells.into_iter().nth(6).unwrap() + ctx.assign_region_smart(cells, [0, 3, 7], [(4, 7), (0, 5)], []); + ctx.get(-5) } - // returns little-endian bit vectors - fn num_to_bits<'v>( + /// Constrains and returns little-endian bit vector representation of `a`. + /// + /// Assumes `range_bits >= number of bits in a`. + /// * `a`: [QuantumCell] of the value to convert + /// * `range_bits`: range of bits needed to represent `a`. Assumes `range_bits > 0`. + fn num_to_bits( &self, - ctx: &mut Context<'_, F>, - a: &AssignedValue<'v, F>, + ctx: &mut Context, + a: AssignedValue, range_bits: usize, - ) -> Vec> { - let bits = a - .value() - .map(|a| { - a.to_repr() - .as_ref() - .iter() - .flat_map(|byte| (0..8).map(|i| (*byte as u64 >> i) & 1)) - .take(range_bits) - .map(|x| F::from(x)) - .collect::>() - }) - .transpose_vec(range_bits); + ) -> Vec> { + let bits = a.value().to_u64_limbs(range_bits, 1).into_iter().map(|x| Witness(F::from(x))); let mut bit_cells = Vec::with_capacity(range_bits); - - let acc = self.inner_product_left( + let row_offset = ctx.advice.len(); + let acc = self.inner_product( ctx, - bits.into_iter().map(|x| Witness(x)), + bits, self.pow_of_two[..range_bits].iter().map(|c| Constant(*c)), - &mut bit_cells, ); - ctx.region.constrain_equal(a.cell(), acc.cell()); + ctx.constrain_equal(&a, &acc); + debug_assert!(range_bits > 0); + bit_cells.push(ctx.get(row_offset as isize)); + for i in 1..range_bits { + bit_cells.push(ctx.get((row_offset + 1 + 3 * (i - 1)) as isize)); + } for bit_cell in &bit_cells { - self.assign_region( - ctx, - vec![ - Constant(F::zero()), - Existing(bit_cell), - Existing(bit_cell), - Existing(bit_cell), - ], - vec![(0, None)], - ); + self.assert_bit(ctx, *bit_cell); } bit_cells } diff --git a/halo2-base/src/gates/mod.rs b/halo2-base/src/gates/mod.rs index 52706772..3e96bdba 100644 --- a/halo2-base/src/gates/mod.rs +++ b/halo2-base/src/gates/mod.rs @@ -1,864 +1,13 @@ -use self::{flex_gate::GateStrategy, range::RangeStrategy}; -use super::{ - utils::ScalarField, - AssignedValue, Context, - QuantumCell::{self, Constant, Existing, ExistingOwned, Witness, WitnessFraction}, -}; -use crate::{ - halo2_proofs::{circuit::Value, plonk::Assigned}, - utils::{biguint_to_fe, bit_length, fe_to_biguint, PrimeField}, -}; -use core::iter; -use num_bigint::BigUint; -use num_integer::Integer; -use num_traits::{One, Zero}; -use std::ops::Shl; - +/// Module that helps auto-build circuits +pub mod builder; +/// Module implementing our simple custom gate and common functions using it pub mod flex_gate; +/// Module using a single lookup table for range checks pub mod range; -pub trait GateInstructions { - fn strategy(&self) -> GateStrategy; - fn context_id(&self) -> usize; - - fn pow_of_two(&self) -> &[F]; - fn get_field_element(&self, n: u64) -> F; - - fn assign_region<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - ) -> Vec> { - self.assign_region_in(ctx, inputs, gate_offsets, ctx.current_phase()) - } - - fn assign_region_in<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> Vec>; - - /// Only returns the last assigned cell - /// - /// Does not collect the vec, saving heap allocation - fn assign_region_last<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - ) -> AssignedValue<'b, F> { - self.assign_region_last_in(ctx, inputs, gate_offsets, ctx.current_phase()) - } - - fn assign_region_last_in<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> AssignedValue<'b, F>; - - /// Only call this if ctx.region is not in shape mode, i.e., if not using simple layouter or ctx.first_pass = false - /// - /// All indices in `gate_offsets`, `equality_offsets`, `external_equality` are with respect to `inputs` indices - /// - `gate_offsets` specifies indices to enable selector for the gate; assume `gate_offsets` is sorted in increasing order - /// - `equality_offsets` specifies pairs of indices to constrain equality - /// - `external_equality` specifies an existing cell to constrain equality with the cell at a certain index - fn assign_region_smart<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator, - equality_offsets: impl IntoIterator, - external_equality: Vec<(&AssignedValue, usize)>, - ) -> Vec> { - let assignments = - self.assign_region(ctx, inputs, gate_offsets.into_iter().map(|i| (i as isize, None))); - for (offset1, offset2) in equality_offsets.into_iter() { - ctx.region.constrain_equal(assignments[offset1].cell(), assignments[offset2].cell()); - } - for (assigned, eq_offset) in external_equality.into_iter() { - ctx.region.constrain_equal(assigned.cell(), assignments[eq_offset].cell()); - } - assignments - } - - fn assign_witnesses<'v>( - &self, - ctx: &mut Context<'_, F>, - witnesses: impl IntoIterator>, - ) -> Vec> { - self.assign_region(ctx, witnesses.into_iter().map(Witness), []) - } - - fn load_witness<'v>( - &self, - ctx: &mut Context<'_, F>, - witness: Value, - ) -> AssignedValue<'v, F> { - self.assign_region_last(ctx, [Witness(witness)], []) - } - - fn load_constant<'a>(&self, ctx: &mut Context<'_, F>, c: F) -> AssignedValue<'a, F> { - self.assign_region_last(ctx, [Constant(c)], []) - } - - fn load_zero<'a>(&self, ctx: &mut Context<'a, F>) -> AssignedValue<'a, F> { - if let Some(zcell) = &ctx.zero_cell { - return zcell.clone(); - } - let zero_cell = self.assign_region_last(ctx, [Constant(F::zero())], []); - ctx.zero_cell = Some(zero_cell.clone()); - zero_cell - } - - /// Copies a, b and constrains `a + b * 1 = out` - // | a | b | 1 | a + b | - fn add<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a + b); - self.assign_region_last( - ctx, - vec![a, b, Constant(F::one()), Witness(out_val)], - vec![(0, None)], - ) - } - - /// Copies a, b and constrains `a + b * (-1) = out` - // | a - b | b | 1 | a | - fn sub<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a - b); - // slightly better to not have to compute -F::one() since F::one() is cached - let assigned_cells = self.assign_region( - ctx, - vec![Witness(out_val), b, Constant(F::one()), a], - vec![(0, None)], - ); - assigned_cells.into_iter().next().unwrap() - } - - // | a | -a | 1 | 0 | - fn neg<'v>(&self, ctx: &mut Context<'_, F>, a: QuantumCell<'_, 'v, F>) -> AssignedValue<'v, F> { - let out_val = a.value().map(|v| -*v); - let assigned_cells = self.assign_region( - ctx, - vec![a, Witness(out_val), Constant(F::one()), Constant(F::zero())], - vec![(0, None)], - ); - assigned_cells.into_iter().nth(1).unwrap() - } - - /// Copies a, b and constrains `0 + a * b = out` - // | 0 | a | b | a * b | - fn mul<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a * b); - self.assign_region_last( - ctx, - vec![Constant(F::zero()), a, b, Witness(out_val)], - vec![(0, None)], - ) - } - - /// a * b + c - fn mul_add<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - c: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a * b) + c.value(); - self.assign_region_last(ctx, vec![c, a, b, Witness(out_val)], vec![(0, None)]) - } - - /// (1 - a) * b = b - a * b - fn mul_not<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| (F::one() - a) * b); - let assignments = - self.assign_region(ctx, vec![Witness(out_val), a, b.clone(), b], vec![(0, None)]); - ctx.region.constrain_equal(assignments[2].cell(), assignments[3].cell()); - assignments.into_iter().next().unwrap() - } - - /// Constrain x is 0 or 1. - fn assert_bit(&self, ctx: &mut Context<'_, F>, x: &AssignedValue) { - self.assign_region_last( - ctx, - [Constant(F::zero()), Existing(x), Existing(x), Existing(x)], - [(0, None)], - ); - } - - fn div_unsafe<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - // TODO: if really necessary, make `c` of type `Assigned` - // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort - let c = a.value().zip(b.value()).map(|(a, b)| b.invert().unwrap() * a); - let assignments = - self.assign_region(ctx, vec![Constant(F::zero()), Witness(c), b, a], vec![(0, None)]); - assignments.into_iter().nth(1).unwrap() - } - - fn assert_equal(&self, ctx: &mut Context<'_, F>, a: QuantumCell, b: QuantumCell) { - if let (Existing(a), Existing(b)) = (&a, &b) { - ctx.region.constrain_equal(a.cell(), b.cell()); - } else { - self.assign_region_smart( - ctx, - vec![Constant(F::zero()), a, Constant(F::one()), b], - vec![0], - vec![], - vec![], - ); - } - } - - fn assert_is_const(&self, ctx: &mut Context<'_, F>, a: &AssignedValue, constant: F) { - let c_cell = ctx.assign_fixed(constant); - #[cfg(feature = "halo2-axiom")] - ctx.region.constrain_equal(a.cell(), &c_cell); - #[cfg(feature = "halo2-pse")] - ctx.region.constrain_equal(a.cell(), c_cell).unwrap(); - } - - /// Returns `(assignments, output)` where `output` is the inner product of `` - /// - /// `assignments` is for internal use - fn inner_product<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> AssignedValue<'b, F>; - - /// very specialized for optimal range check, not for general consumption - /// - `a_assigned` is expected to have capacity a.len() - /// - we re-use `a_assigned` to save memory allocation - fn inner_product_left<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - a_assigned: &mut Vec>, - ) -> AssignedValue<'b, F>; - - /// Returns an iterator with the partial sums `sum_{j=0..=i} a[j] * b[j]`. - fn inner_product_with_sums<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> Box> + 'b>; - - fn sum<'a, 'b: 'a>( - &self, - ctx: &mut Context<'b, F>, - a: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - let mut a = a.into_iter().peekable(); - let start = a.next(); - if start.is_none() { - return self.load_zero(ctx); - } - let start = start.unwrap(); - if a.peek().is_none() { - return self.assign_region_last(ctx, [start], []); - } - let (len, hi) = a.size_hint(); - debug_assert_eq!(Some(len), hi); - - let mut sum = start.value().copied(); - let cells = iter::once(start).chain(a.flat_map(|a| { - sum = sum + a.value(); - [a, Constant(F::one()), Witness(sum)] - })); - self.assign_region_last(ctx, cells, (0..len).map(|i| (3 * i as isize, None))) - } - - /// Returns the assignment trace where `output[3 * i]` has the running sum `sum_{j=0..=i} a[j]` - fn sum_with_assignments<'a, 'b: 'a>( - &self, - ctx: &mut Context<'b, F>, - a: impl IntoIterator>, - ) -> Vec> { - let mut a = a.into_iter().peekable(); - let start = a.next(); - if start.is_none() { - return vec![self.load_zero(ctx)]; - } - let start = start.unwrap(); - if a.peek().is_none() { - return self.assign_region(ctx, [start], []); - } - let (len, hi) = a.size_hint(); - debug_assert_eq!(Some(len), hi); - - let mut sum = start.value().copied(); - let cells = iter::once(start).chain(a.flat_map(|a| { - sum = sum + a.value(); - [a, Constant(F::one()), Witness(sum)] - })); - self.assign_region(ctx, cells, (0..len).map(|i| (3 * i as isize, None))) - } - - // requires b.len() == a.len() + 1 - // returns - // x_i = b_1 * (a_1...a_{i - 1}) - // + b_2 * (a_2...a_{i - 1}) - // + ... - // + b_i - // Returns [x_1, ..., x_{b.len()}] - fn accumulated_product<'a, 'v: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> Vec> { - let mut b = b.into_iter(); - let mut a = a.into_iter(); - let b_first = b.next(); - if let Some(b_first) = b_first { - let b_first = self.assign_region_last(ctx, [b_first], []); - std::iter::successors(Some(b_first), |x| { - a.next().zip(b.next()).map(|(a, b)| self.mul_add(ctx, Existing(x), a, b)) - }) - .collect() - } else { - vec![] - } - } - - fn sum_products_with_coeff_and_var<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - values: impl IntoIterator, QuantumCell<'a, 'b, F>)>, - var: QuantumCell<'a, 'b, F>, - ) -> AssignedValue<'b, F>; - - // | 1 - b | 1 | b | 1 | b | a | 1 - b | out | - fn or<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let not_b_val = b.value().map(|x| F::one() - x); - let out_val = a.value().zip(b.value()).map(|(a, b)| *a + b) - - a.value().zip(b.value()).map(|(a, b)| *a * b); - let cells = vec![ - Witness(not_b_val), - Constant(F::one()), - b.clone(), - Constant(F::one()), - b, - a, - Witness(not_b_val), - Witness(out_val), - ]; - let mut assigned_cells = - self.assign_region_smart(ctx, cells, vec![0, 4], vec![(0, 6), (2, 4)], vec![]); - assigned_cells.pop().unwrap() - } - - // | 0 | a | b | out | - fn and<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - self.mul(ctx, a, b) - } - - fn not<'v>(&self, ctx: &mut Context<'_, F>, a: QuantumCell<'_, 'v, F>) -> AssignedValue<'v, F> { - self.sub(ctx, Constant(F::one()), a) - } - - fn select<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - sel: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F>; - - fn or_and<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - c: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F>; - - /// assume bits has boolean values - /// returns vec[idx] with vec[idx] = 1 if and only if bits == idx as a binary number - fn bits_to_indicator<'v>( - &self, - ctx: &mut Context<'_, F>, - bits: &[AssignedValue<'v, F>], - ) -> Vec> { - let k = bits.len(); - - let (inv_last_bit, last_bit) = { - let mut assignments = self - .assign_region( - ctx, - vec![ - Witness(bits[k - 1].value().map(|b| F::one() - b)), - Existing(&bits[k - 1]), - Constant(F::one()), - Constant(F::one()), - ], - vec![(0, None)], - ) - .into_iter(); - (assignments.next().unwrap(), assignments.next().unwrap()) - }; - let mut indicator = Vec::with_capacity(2 * (1 << k) - 2); - let mut offset = 0; - indicator.push(inv_last_bit); - indicator.push(last_bit); - for (idx, bit) in bits.iter().rev().enumerate().skip(1) { - for old_idx in 0..(1 << idx) { - let inv_prod_val = indicator[offset + old_idx] - .value() - .zip(bit.value()) - .map(|(a, b)| (F::one() - b) * a); - let inv_prod = self - .assign_region_smart( - ctx, - vec![ - Witness(inv_prod_val), - Existing(&indicator[offset + old_idx]), - Existing(bit), - Existing(&indicator[offset + old_idx]), - ], - vec![0], - vec![], - vec![], - ) - .into_iter() - .next() - .unwrap(); - indicator.push(inv_prod); - - let prod = self.mul(ctx, Existing(&indicator[offset + old_idx]), Existing(bit)); - indicator.push(prod); - } - offset += 1 << idx; - } - indicator.split_off((1 << k) - 2) - } - - // returns vec with vec.len() == len such that: - // vec[i] == 1{i == idx} - fn idx_to_indicator<'v>( - &self, - ctx: &mut Context<'_, F>, - mut idx: QuantumCell<'_, 'v, F>, - len: usize, - ) -> Vec> { - let ind = self.assign_region( - ctx, - (0..len).map(|i| { - Witness(idx.value().map(|x| { - if x.get_lower_32() == i as u32 { - F::one() - } else { - F::zero() - } - })) - }), - vec![], - ); - - // check ind[i] * (i - idx) == 0 - for (i, ind) in ind.iter().enumerate() { - let val = ind.value().zip(idx.value()).map(|(ind, idx)| *ind * idx); - let assignments = self.assign_region( - ctx, - vec![ - Constant(F::zero()), - Existing(ind), - idx, - Witness(val), - Constant(-F::from(i as u64)), - Existing(ind), - Constant(F::zero()), - ], - vec![(0, None), (3, None)], - ); - // need to use assigned idx after i > 0 so equality constraint holds - idx = ExistingOwned(assignments.into_iter().nth(2).unwrap()); - } - ind - } - - // performs inner product on a, indicator - // `indicator` values are all boolean - /// Assumes for witness generation that only one element of `indicator` has non-zero value and that value is `F::one()`. - fn select_by_indicator<'a, 'i, 'b: 'a + 'i>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - indicator: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - let mut sum = Value::known(F::zero()); - let a = a.into_iter(); - let (len, hi) = a.size_hint(); - debug_assert_eq!(Some(len), hi); - - let cells = - std::iter::once(Constant(F::zero())).chain(a.zip(indicator).flat_map(|(a, ind)| { - sum = sum.zip(a.value().zip(ind.value())).map(|(sum, (a, ind))| { - if ind.is_zero_vartime() { - sum - } else { - *a - } - }); - [a, Existing(ind), Witness(sum)] - })); - self.assign_region_last(ctx, cells, (0..len).map(|i| (3 * i as isize, None))) - } - - fn select_from_idx<'a, 'v: 'a>( - &self, - ctx: &mut Context<'_, F>, - cells: impl IntoIterator>, - idx: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let cells = cells.into_iter(); - let (len, hi) = cells.size_hint(); - debug_assert_eq!(Some(len), hi); - - let ind = self.idx_to_indicator(ctx, idx, len); - let out = self.select_by_indicator(ctx, cells, &ind); - out - } - - // | out | a | inv | 1 | 0 | a | out | 0 - fn is_zero<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &AssignedValue<'v, F>, - ) -> AssignedValue<'v, F> { - let (is_zero, inv) = a - .value() - .map(|x| { - if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) - } else { - (F::zero(), Assigned::Rational(F::one(), *x)) - } - }) - .unzip(); - - let cells = vec![ - Witness(is_zero), - Existing(a), - WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), - Existing(a), - Witness(is_zero), - Constant(F::zero()), - ]; - let assigned_cells = self.assign_region_smart(ctx, cells, vec![0, 4], vec![(0, 6)], vec![]); - assigned_cells.into_iter().next().unwrap() - } - - fn is_equal<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let diff = self.sub(ctx, a, b); - self.is_zero(ctx, &diff) - } - - // returns little-endian bit vectors - fn num_to_bits<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &AssignedValue<'v, F>, - range_bits: usize, - ) -> Vec>; - - /// given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords)` polynomial such that `f(x_i) = y_i` for all `i`. - /// - /// input: coords, x - /// - /// output: (f(x), Prod_i (x - x_i)) - /// - /// constrains all x_i and x are distinct - fn lagrange_and_eval<'v>( - &self, - ctx: &mut Context<'_, F>, - coords: &[(AssignedValue<'v, F>, AssignedValue<'v, F>)], - x: &AssignedValue<'v, F>, - ) -> (AssignedValue<'v, F>, AssignedValue<'v, F>) { - let mut z = self.sub(ctx, Existing(x), Existing(&coords[0].0)); - for coord in coords.iter().skip(1) { - let sub = self.sub(ctx, Existing(x), Existing(&coord.0)); - z = self.mul(ctx, Existing(&z), Existing(&sub)); - } - let mut eval = None; - for i in 0..coords.len() { - // compute (x - x_i) * Prod_{j != i} (x_i - x_j) - let mut denom = self.sub(ctx, Existing(x), Existing(&coords[i].0)); - for j in 0..coords.len() { - if i == j { - continue; - } - let sub = self.sub(ctx, Existing(&coords[i].0), Existing(&coords[j].0)); - denom = self.mul(ctx, Existing(&denom), Existing(&sub)); - } - // TODO: batch inversion - let is_zero = self.is_zero(ctx, &denom); - self.assert_is_const(ctx, &is_zero, F::zero()); - - // y_i / denom - let quot = self.div_unsafe(ctx, Existing(&coords[i].1), Existing(&denom)); - eval = if let Some(eval) = eval { - let eval = self.add(ctx, Existing(&eval), Existing(")); - Some(eval) - } else { - Some(quot) - }; - } - let out = self.mul(ctx, Existing(&eval.unwrap()), Existing(&z)); - (out, z) - } -} - -pub trait RangeInstructions { - type Gate: GateInstructions; - - fn gate(&self) -> &Self::Gate; - fn strategy(&self) -> RangeStrategy; - - fn lookup_bits(&self) -> usize; - - fn range_check<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - ); - - fn check_less_than<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, - num_bits: usize, - ); - - /// Checks that `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `bit_length(b)` bits. - fn check_less_than_safe<'a>(&self, ctx: &mut Context<'a, F>, a: &AssignedValue<'a, F>, b: u64) { - let range_bits = - (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.check_less_than( - ctx, - Existing(a), - Constant(self.gate().get_field_element(b)), - range_bits, - ) - } - - /// Checks that `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `bit_length(b)` bits. - fn check_big_less_than_safe<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - b: BigUint, - ) where - F: PrimeField, - { - let range_bits = - (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.check_less_than(ctx, Existing(a), Constant(biguint_to_fe(&b)), range_bits) - } - - /// Returns whether `a` is in `[0, b)`. - /// - /// Warning: This may fail silently if `a` or `b` have more than `num_bits` bits - fn is_less_than<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, - num_bits: usize, - ) -> AssignedValue<'a, F>; - - /// Returns whether `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `range_bits` bits. - fn is_less_than_safe<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - b: u64, - ) -> AssignedValue<'a, F> { - let range_bits = - (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, Existing(a), Constant(F::from(b)), range_bits) - } - - /// Returns whether `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `range_bits` bits. - fn is_big_less_than_safe<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - b: BigUint, - ) -> AssignedValue<'a, F> - where - F: PrimeField, - { - let range_bits = - (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, Existing(a), Constant(biguint_to_fe(&b)), range_bits) - } - - /// Returns `(c, r)` such that `a = b * c + r`. - /// - /// Assumes that `b != 0`. - fn div_mod<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: impl Into, - a_num_bits: usize, - ) -> (AssignedValue<'a, F>, AssignedValue<'a, F>) - where - F: PrimeField, - { - let b = b.into(); - let mut a_val = BigUint::zero(); - a.value().map(|v| a_val = fe_to_biguint(v)); - let (div, rem) = a_val.div_mod_floor(&b); - let [div, rem] = [div, rem].map(|v| biguint_to_fe(&v)); - let assigned = self.gate().assign_region( - ctx, - vec![ - Witness(Value::known(rem)), - Constant(biguint_to_fe(&b)), - Witness(Value::known(div)), - a, - ], - vec![(0, None)], - ); - self.check_big_less_than_safe( - ctx, - &assigned[2], - BigUint::one().shl(a_num_bits as u32) / &b + BigUint::one(), - ); - self.check_big_less_than_safe(ctx, &assigned[0], b); - (assigned[2].clone(), assigned[0].clone()) - } - - /// Returns `(c, r)` such that `a = b * c + r`. - /// - /// Assumes that `b != 0`. - /// - /// Let `X = 2 ** b_num_bits`. - /// Write `a = a1 * X + a0` and `c = c1 * X + c0`. - /// If we write `b * c0 + r = d1 * X + d0` then - /// `b * c + r = (b * c1 + d1) * X + d0` - fn div_mod_var<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, - a_num_bits: usize, - b_num_bits: usize, - ) -> (AssignedValue<'a, F>, AssignedValue<'a, F>) - where - F: PrimeField, - { - let mut a_val = BigUint::zero(); - a.value().map(|v| a_val = fe_to_biguint(v)); - let mut b_val = BigUint::one(); - b.value().map(|v| b_val = fe_to_biguint(v)); - let (div, rem) = a_val.div_mod_floor(&b_val); - let x = BigUint::one().shl(b_num_bits as u32); - let (div_hi, div_lo) = div.div_mod_floor(&x); - - let x_fe = self.gate().pow_of_two()[b_num_bits]; - let [div, div_hi, div_lo, rem] = [div, div_hi, div_lo, rem].map(|v| biguint_to_fe(&v)); - let assigned = self.gate().assign_region( - ctx, - vec![ - Witness(Value::known(div_lo)), - Witness(Value::known(div_hi)), - Constant(x_fe), - Witness(Value::known(div)), - Witness(Value::known(rem)), - ], - vec![(0, None)], - ); - self.range_check(ctx, &assigned[0], b_num_bits); - self.range_check(ctx, &assigned[1], a_num_bits.saturating_sub(b_num_bits)); - - let (bcr0_hi, bcr0_lo) = { - let bcr0 = - self.gate().mul_add(ctx, b.clone(), Existing(&assigned[0]), Existing(&assigned[4])); - self.div_mod(ctx, Existing(&bcr0), x.clone(), a_num_bits) - }; - let bcr_hi = - self.gate().mul_add(ctx, b.clone(), Existing(&assigned[1]), Existing(&bcr0_hi)); - - let (a_hi, a_lo) = self.div_mod(ctx, a, x, a_num_bits); - ctx.constrain_equal(&bcr_hi, &a_hi); - ctx.constrain_equal(&bcr0_lo, &a_lo); - - self.range_check(ctx, &assigned[4], b_num_bits); - self.check_less_than(ctx, Existing(&assigned[4]), b, b_num_bits); - (assigned[3].clone(), assigned[4].clone()) - } -} - -#[cfg(test)] +/// Tests +#[cfg(any(test, feature = "test-utils"))] pub mod tests; + +pub use flex_gate::{GateChip, GateInstructions}; +pub use range::{RangeChip, RangeInstructions}; diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range.rs index 07033ee7..7a6b6173 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range.rs @@ -1,13 +1,5 @@ use crate::{ - gates::{ - flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, - GateInstructions, - }, - utils::{decompose_fe_to_u64_limbs, value_to_option, ScalarField}, - AssignedValue, - QuantumCell::{self, Constant, Existing, Witness}, -}; -use crate::{ + gates::flex_gate::{FlexGateConfig, GateInstructions, GateStrategy, MAX_PHASE}, halo2_proofs::{ circuit::{Layouter, Value}, plonk::{ @@ -15,44 +7,70 @@ use crate::{ }, poly::Rotation, }, - utils::PrimeField, + utils::{ + biguint_to_fe, bit_length, decompose_fe_to_u64_limbs, fe_to_biguint, BigPrimeField, + ScalarField, + }, + AssignedValue, Context, + QuantumCell::{self, Constant, Existing, Witness}, }; -use std::cmp::Ordering; +use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::One; +use std::{cmp::Ordering, ops::Shl}; -use super::{Context, RangeInstructions}; +use super::flex_gate::GateChip; +/// Specifies the gate strategy for the range chip #[derive(Clone, Copy, Debug, PartialEq)] pub enum RangeStrategy { + /// # Vertical Gate Strategy: + /// `q_0 * (a + b * c - d) = 0` + /// where + /// * a = value[0], b = value[1], c = value[2], d = value[3] + /// * q = q_lookup[0] + /// * q is either 0 or 1 so this is just a simple selector + /// + /// Using `a + b * c` instead of `a * b + c` allows for "chaining" of gates, i.e., the output of one gate becomes `a` in the next gate. Vertical, // vanilla implementation with vertical basic gate(s) - // CustomVerticalShort, // vertical basic gate(s) and vertical custom range gates of length 2,3 - PlonkPlus, - // CustomHorizontal, // vertical basic gate and dedicated horizontal custom gate } +/// Configuration for Range Chip #[derive(Clone, Debug)] pub struct RangeConfig { - // `lookup_advice` are special advice columns only used for lookups - // - // If `strategy` is `Vertical` or `CustomVertical`: - // * If `gate` has only 1 advice column, enable lookups for that column, in which case `lookup_advice` is empty - // * Otherwise, add some user-specified number of `lookup_advice` columns - // * In this case, we don't even need a selector so `q_lookup` is empty - // If `strategy` is `CustomHorizontal`: - // * TODO + /// Underlying Gate Configuration + pub gate: FlexGateConfig, + /// Special advice (witness) Columns used only for lookup tables. + /// + /// Each phase of a halo2 circuit has a distinct lookup_advice column. + /// + /// * If `gate` has only 1 advice column, lookups are enabled for that column, in which case `lookup_advice` is empty + /// * If `gate` has more than 1 advice column some number of user-specified `lookup_advice` columns are added + /// * In this case, we don't need a selector so `q_lookup` is empty pub lookup_advice: [Vec>; MAX_PHASE], + /// Selector values for the lookup table. pub q_lookup: Vec>, + /// Column for lookup table values. pub lookup: TableColumn, - pub lookup_bits: usize, - pub limb_bases: Vec>, - // selector for custom range gate - // `q_range[k][i]` stores the selector for a custom range gate of length `k` - // pub q_range: HashMap>, - pub gate: FlexGateConfig, - strategy: RangeStrategy, - pub context_id: usize, + /// Defines the number of bits represented in the lookup table [0,2^lookup_bits). + lookup_bits: usize, + /// Gate Strategy used for specifying advice values. + _strategy: RangeStrategy, } impl RangeConfig { + /// Generates a new [RangeConfig] with the specified parameters. + /// + /// If `num_columns` is 0, then we assume you do not want to perform any lookups in that phase. + /// + /// Panics if `lookup_bits` > 28. + /// * `meta`: [ConstraintSystem] of the circuit + /// * `range_strategy`: [GateStrategy] of the range chip + /// * `num_advice`: Number of [Advice] [Column]s without lookup enabled in each phase + /// * `num_lookup_advice`: Number of `lookup_advice` [Column]s in each phase + /// * `num_fixed`: Number of fixed [Column]s in each phase + /// * `lookup_bits`: Number of bits represented in the LookUp table [0,2^lookup_bits) + /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, range_strategy: RangeStrategy, @@ -60,7 +78,6 @@ impl RangeConfig { num_lookup_advice: &[usize], num_fixed: usize, lookup_bits: usize, - context_id: usize, // params.k() circuit_degree: usize, ) -> Self { @@ -71,11 +88,9 @@ impl RangeConfig { meta, match range_strategy { RangeStrategy::Vertical => GateStrategy::Vertical, - RangeStrategy::PlonkPlus => GateStrategy::PlonkPlus, }, num_advice, num_fixed, - context_id, circuit_degree, ); @@ -101,31 +116,29 @@ impl RangeConfig { } } - let limb_base = F::from(1u64 << lookup_bits); - let mut running_base = limb_base; - let num_bases = F::NUM_BITS as usize / lookup_bits; - let mut limb_bases = Vec::with_capacity(num_bases + 1); - limb_bases.extend([Constant(F::one()), Constant(running_base)]); - for _ in 2..=num_bases { - running_base *= &limb_base; - limb_bases.push(Constant(running_base)); - } + let mut config = + Self { lookup_advice, q_lookup, lookup, lookup_bits, gate, _strategy: range_strategy }; - let config = Self { - lookup_advice, - q_lookup, - lookup, - lookup_bits, - limb_bases, - gate, - strategy: range_strategy, - context_id, - }; - config.create_lookup(meta); + // sanity check: only create lookup table if there are lookup_advice columns + if !num_lookup_advice.is_empty() { + config.create_lookup(meta); + } + config.gate.max_rows = (1 << circuit_degree) - meta.minimum_rows(); + assert!( + (1 << lookup_bits) <= config.gate.max_rows, + "lookup table is too large for the circuit degree plus blinding factors!" + ); config } + /// Returns the number of bits represented in the lookup table [0,2^lookup_bits). + pub fn lookup_bits(&self) -> usize { + self.lookup_bits + } + + /// Instantiates the lookup table of the circuit. + /// * `meta`: [ConstraintSystem] of the circuit fn create_lookup(&self, meta: &mut ConstraintSystem) { for (phase, q_l) in self.q_lookup.iter().enumerate() { if let Some(q) = q_l { @@ -138,6 +151,7 @@ impl RangeConfig { }); } } + //if multiple columns for la in self.lookup_advice.iter().flat_map(|advices| advices.iter()) { meta.lookup("lookup wo selector", |meta| { let a = meta.query_advice(*la, Rotation::cur()); @@ -146,6 +160,8 @@ impl RangeConfig { } } + /// Loads the lookup table into the circuit using the provided `layouter`. + /// * `layouter`: layouter for the circuit pub fn load_lookup_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { layouter.assign_table( || format!("{} bit lookup", self.lookup_bits), @@ -163,194 +179,397 @@ impl RangeConfig { )?; Ok(()) } +} + +/// Trait that implements methods to constrain a field element number `x` is within a range of bits. +pub trait RangeInstructions { + /// The type of Gate used within the instructions. + type Gate: GateInstructions; + + /// Returns the type of gate used. + fn gate(&self) -> &Self::Gate; + + /// Returns the [GateStrategy] for this range. + fn strategy(&self) -> RangeStrategy; + + /// Returns the number of bits the lookup table represents. + fn lookup_bits(&self) -> usize; + + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// Assumes that both `a`<= `range_bits` bits. + /// * a: [AssignedValue] value to be range checked + /// * range_bits: number of bits to represent the range + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize); + + /// Constrains that 'a' is less than 'b'. + /// + /// Assumes that `a` and `b` have bit length <= num_bits bits. + /// + /// Note: This may fail silently if a or b have more than num_bits. + /// * a: [QuantumCell] value to check + /// * b: upper bound expressed as a [QuantumCell] + /// * num_bits: number of bits used to represent the values of `a` and `b` + fn check_less_than( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + num_bits: usize, + ); - /// Call this at the end of a phase to assign cells to special columns for lookup arguments + /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. /// - /// returns total number of lookup cells assigned - pub fn finalize(&self, ctx: &mut Context<'_, F>) -> usize { - ctx.copy_and_lookup_cells(self.lookup_advice[ctx.current_phase].clone()) + /// * a: [AssignedValue] value to check + /// * b: upper bound expressed as a [u64] value + fn check_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: u64) { + let range_bits = + (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); + + self.range_check(ctx, a, range_bits); + self.check_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) } - /// assuming this is called when ctx.region is not in shape mode - /// `offset` is the offset of the cell in `ctx.region` - /// `offset` is only used if there is a single advice column - fn enable_lookup<'a>(&self, ctx: &mut Context<'a, F>, acell: AssignedValue<'a, F>) { - let phase = ctx.current_phase(); - if let Some(q) = &self.q_lookup[phase] { - q.enable(&mut ctx.region, acell.row()).expect("enable selector should not fail"); - } else { - ctx.cells_to_lookup.push(acell); - } + /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. + /// + /// * a: [AssignedValue] value to check + /// * b: upper bound expressed as a [BigUint] value + fn check_big_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: BigUint) + where + F: BigPrimeField, + { + let range_bits = + (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); + + self.range_check(ctx, a, range_bits); + self.check_less_than(ctx, a, Constant(biguint_to_fe(&b)), range_bits) } - // returns the limbs - fn range_check_simple<'a>( + /// Constrains whether `a` is in `[0, b)`, and returns 1 if `a` < `b`, otherwise 0. + /// + /// Assumes that`a` and `b` are known to have <= num_bits bits. + /// * a: first [QuantumCell] to compare + /// * b: second [QuantumCell] to compare + /// * num_bits: number of bits to represent the values + fn is_less_than( &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - limbs_assigned: &mut Vec>, - ) { - let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; - // println!("range check {} bits {} len", range_bits, k); - let rem_bits = range_bits % self.lookup_bits; + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + num_bits: usize, + ) -> AssignedValue; - assert!(self.limb_bases.len() >= k); - if k == 1 { - limbs_assigned.clear(); - limbs_assigned.push(a.clone()) - } else { - let acc = match value_to_option(a.value()) { - Some(a) => { - let limbs = decompose_fe_to_u64_limbs(a, k, self.lookup_bits) - .into_iter() - .map(|x| Witness(Value::known(F::from(x)))); - self.gate.inner_product_left( - ctx, - limbs, - self.limb_bases[..k].iter().cloned(), - limbs_assigned, - ) - } - _ => self.gate.inner_product_left( - ctx, - vec![Witness(Value::unknown()); k], - self.limb_bases[..k].iter().cloned(), - limbs_assigned, - ), - }; - // the inner product above must equal `a` - ctx.region.constrain_equal(a.cell(), acc.cell()); - }; - assert_eq!(limbs_assigned.len(), k); + /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then constrains that `a` is in `[0,b)`. + /// + /// Returns 1 if `a` < `b`, otherwise 0. + /// + /// * a: [AssignedValue] value to check + /// * b: upper bound as [u64] value + fn is_less_than_safe( + &self, + ctx: &mut Context, + a: AssignedValue, + b: u64, + ) -> AssignedValue { + let range_bits = + (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - // range check all the limbs - for limb in limbs_assigned.iter() { - self.enable_lookup(ctx, limb.clone()); - } + self.range_check(ctx, a, range_bits); + self.is_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + } - // additional constraints for the last limb if rem_bits != 0 - match rem_bits.cmp(&1) { - // we want to check x := limbs[k-1] is boolean - // we constrain x*(x-1) = 0 + x * x - x == 0 - // | 0 | x | x | x | - Ordering::Equal => { - self.gate.assert_bit(ctx, &limbs_assigned[k - 1]); - } - Ordering::Greater => { - let mult_val = self.gate.get_field_element(1u64 << (self.lookup_bits - rem_bits)); - let check = self.gate.assign_region_last( - ctx, - vec![ - Constant(F::zero()), - Existing(&limbs_assigned[k - 1]), - Constant(mult_val), - Witness(limbs_assigned[k - 1].value().map(|limb| mult_val * limb)), - ], - vec![(0, None)], - ); - self.enable_lookup(ctx, check); - } - _ => {} - } + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. + /// + /// Returns 1 if `a` < `b`, otherwise 0. + /// + /// * a: [AssignedValue] value to check + /// * b: upper bound as [BigUint] value + /// + /// For the current implementation using [`is_less_than`], we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` + fn is_big_less_than_safe( + &self, + ctx: &mut Context, + a: AssignedValue, + b: BigUint, + ) -> AssignedValue + where + F: BigPrimeField, + { + let range_bits = + (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); + + self.range_check(ctx, a, range_bits); + self.is_less_than(ctx, a, Constant(biguint_to_fe(&b)), range_bits) } - /// breaks up `a` into smaller pieces to lookup and stores them in `limbs_assigned` + /// Constrains and returns `(c, r)` such that `a = b * c + r`. /// - /// this is an internal function to avoid memory re-allocation of `limbs_assigned` - pub fn range_check_limbs<'a>( + /// Assumes that `b != 0` and that `a` has <= `a_num_bits` bits. + /// * a: [QuantumCell] value to divide + /// * b: [BigUint] value to divide by + /// * a_num_bits: number of bits needed to represent the value of `a` + fn div_mod( &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - limbs_assigned: &mut Vec>, - ) { - assert_ne!(range_bits, 0); - #[cfg(feature = "display")] - { - let key = format!( - "range check length {}", - (range_bits + self.lookup_bits - 1) / self.lookup_bits - ); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - } - match self.strategy { - RangeStrategy::Vertical | RangeStrategy::PlonkPlus => { - self.range_check_simple(ctx, a, range_bits, limbs_assigned) - } + ctx: &mut Context, + a: impl Into>, + b: impl Into, + a_num_bits: usize, + ) -> (AssignedValue, AssignedValue) + where + F: BigPrimeField, + { + let a = a.into(); + let b = b.into(); + let a_val = fe_to_biguint(a.value()); + let (div, rem) = a_val.div_mod_floor(&b); + let [div, rem] = [div, rem].map(|v| biguint_to_fe(&v)); + ctx.assign_region([Witness(rem), Constant(biguint_to_fe(&b)), Witness(div), a], [0]); + let rem = ctx.get(-4); + let div = ctx.get(-2); + // Constrain that a_num_bits fulfills `div < 2 ** a_num_bits / b`. + self.check_big_less_than_safe( + ctx, + div, + BigUint::one().shl(a_num_bits as u32) / &b + BigUint::one(), + ); + // Constrain that remainder is less than divisor (i.e. `r < b`). + self.check_big_less_than_safe(ctx, rem, b); + (div, rem) + } + + /// Constrains and returns `(c, r)` such that `a = b * c + r`. + /// + /// Assumes: + /// that `b != 0`. + /// that `a` has <= `a_num_bits` bits. + /// that `b` has <= `b_num_bits` bits. + /// + /// Note: + /// Let `X = 2 ** b_num_bits` + /// Write `a = a1 * X + a0` and `c = c1 * X + c0` + /// If we write `b * c0 + r = d1 * X + d0` then + /// `b * c + r = (b * c1 + d1) * X + d0` + /// * a: [QuantumCell] value to divide + /// * b: [QuantumCell] value to divide by + /// * a_num_bits: number of bits needed to represent the value of `a` + /// * b_num_bits: number of bits needed to represent the value of `b` + /// + fn div_mod_var( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + a_num_bits: usize, + b_num_bits: usize, + ) -> (AssignedValue, AssignedValue) + where + F: BigPrimeField, + { + let a = a.into(); + let b = b.into(); + let a_val = fe_to_biguint(a.value()); + let b_val = fe_to_biguint(b.value()); + let (div, rem) = a_val.div_mod_floor(&b_val); + let x = BigUint::one().shl(b_num_bits as u32); + let (div_hi, div_lo) = div.div_mod_floor(&x); + + let x_fe = self.gate().pow_of_two()[b_num_bits]; + let [div, div_hi, div_lo, rem] = [div, div_hi, div_lo, rem].map(|v| biguint_to_fe(&v)); + ctx.assign_region( + [Witness(div_lo), Witness(div_hi), Constant(x_fe), Witness(div), Witness(rem)], + [0], + ); + let [div_lo, div_hi, div, rem] = [-5, -4, -2, -1].map(|i| ctx.get(i)); + self.range_check(ctx, div_lo, b_num_bits); + if a_num_bits <= b_num_bits { + self.gate().assert_is_const(ctx, &div_hi, &F::zero()); + } else { + self.range_check(ctx, div_hi, a_num_bits - b_num_bits); } + + let (bcr0_hi, bcr0_lo) = { + let bcr0 = self.gate().mul_add(ctx, b, Existing(div_lo), Existing(rem)); + self.div_mod(ctx, Existing(bcr0), x.clone(), a_num_bits) + }; + let bcr_hi = self.gate().mul_add(ctx, b, Existing(div_hi), Existing(bcr0_hi)); + + let (a_hi, a_lo) = self.div_mod(ctx, a, x, a_num_bits); + ctx.constrain_equal(&bcr_hi, &a_hi); + ctx.constrain_equal(&bcr0_lo, &a_lo); + + self.range_check(ctx, rem, b_num_bits); + self.check_less_than(ctx, Existing(rem), b, b_num_bits); + (div, rem) } - /// assume `a` has been range checked already to `limb_bits` bits - pub fn get_last_bit<'a>( + /// Constrains and returns the last bit of the value of `a`. + /// + /// Assume `a` has been range checked already to `limb_bits` bits. + /// * a: [AssignedValue] value to get the last bit of + /// * limb_bits: number of bits in a limb + fn get_last_bit( &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, + ctx: &mut Context, + a: AssignedValue, limb_bits: usize, - ) -> AssignedValue<'a, F> { - let a_v = a.value(); - let bit_v = a_v.map(|a| { - let a = a.get_lower_32(); - if a ^ 1 == 0 { - F::zero() - } else { - F::one() - } + ) -> AssignedValue { + let a_big = fe_to_biguint(a.value()); + let bit_v = F::from(a_big.bit(0)); + let two = self.gate().get_field_element(2u64); + let h_v = F::from_bytes_le(&(a_big >> 1usize).to_bytes_le()); + + ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); + let half = ctx.get(-3); + let bit = ctx.get(-4); + + self.range_check(ctx, half, limb_bits - 1); + self.gate().assert_bit(ctx, bit); + bit + } +} + +/// A chip that implements RangeInstructions which provides methods to constrain a field element `x` is within a range of bits. +#[derive(Clone, Debug)] +pub struct RangeChip { + /// # RangeChip + /// Provides methods to constrain a field element `x` is within a range of bits. + /// Declares a lookup table of [0, 2lookup_bits) and constrains whether a field element appears in this table. + + /// [GateStrategy] for advice values in this chip. + strategy: RangeStrategy, + /// Underlying [GateChip] for this chip. + pub gate: GateChip, + /// Defines the number of bits represented in the lookup table [0,2lookup_bits). + pub lookup_bits: usize, + /// [Vec] of powers of `2 ** lookup_bits` represented as [QuantumCell::Constant]. + /// These are precomputed and cached as a performance optimization for later limb decompositions. We precompute up to the higher power that fits in `F`, which is `2 ** ((F::CAPACITY / lookup_bits) * lookup_bits)`. + pub limb_bases: Vec>, +} + +impl RangeChip { + /// Creates a new [RangeChip] with the given strategy and lookup_bits. + /// * strategy: [GateStrategy] for advice values in this chip + /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) + pub fn new(strategy: RangeStrategy, lookup_bits: usize) -> Self { + let limb_base = F::from(1u64 << lookup_bits); + let mut running_base = limb_base; + let num_bases = F::CAPACITY as usize / lookup_bits; + let mut limb_bases = Vec::with_capacity(num_bases + 1); + limb_bases.extend([Constant(F::one()), Constant(running_base)]); + for _ in 2..=num_bases { + running_base *= &limb_base; + limb_bases.push(Constant(running_base)); + } + let gate = GateChip::new(match strategy { + RangeStrategy::Vertical => GateStrategy::Vertical, }); - let two = self.gate.get_field_element(2u64); - let h_v = a.value().zip(bit_v).map(|(a, b)| (*a - b) * two.invert().unwrap()); - let assignments = self.gate.assign_region_smart( - ctx, - vec![Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], - vec![0], - vec![], - vec![], - ); - self.range_check(ctx, &assignments[1], limb_bits - 1); - assignments.into_iter().next().unwrap() + Self { strategy, gate, lookup_bits, limb_bases } + } + + /// Creates a new [RangeChip] with the default strategy and provided lookup_bits. + /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) + pub fn default(lookup_bits: usize) -> Self { + Self::new(RangeStrategy::Vertical, lookup_bits) } } -impl RangeInstructions for RangeConfig { - type Gate = FlexGateConfig; +impl RangeInstructions for RangeChip { + type Gate = GateChip; + /// The type of Gate used in this chip. fn gate(&self) -> &Self::Gate { &self.gate } + + /// Returns the [GateStrategy] for this range. fn strategy(&self) -> RangeStrategy { self.strategy } + /// Defines the number of bits represented in the lookup table [0,2lookup_bits). fn lookup_bits(&self) -> usize { self.lookup_bits } - fn range_check<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - ) { - let tmp = ctx.preallocated_vec_to_assign(); - self.range_check_limbs(ctx, a, range_bits, &mut tmp.as_ref().borrow_mut()); + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// This is done by decomposing `a` into `k` limbs, where `k = ceil(range_bits / lookup_bits)`. + /// Each limb is constrained to be within the range [0, 2lookup_bits). + /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. + /// + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range + /// * `lookup_bits`: number of bits in the lookup table + /// + /// # Assumptions + /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + // the number of limbs + let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; + // println!("range check {} bits {} len", range_bits, k); + let rem_bits = range_bits % self.lookup_bits; + + debug_assert!(self.limb_bases.len() >= k); + + if k == 1 { + ctx.cells_to_lookup.push(a); + } else { + let limbs = decompose_fe_to_u64_limbs(a.value(), k, self.lookup_bits) + .into_iter() + .map(|x| Witness(F::from(x))); + let row_offset = ctx.advice.len() as isize; + let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..k].to_vec()); + // the inner product above must equal `a` + ctx.constrain_equal(&a, &acc); + // we fetch the cells to lookup by getting the indices where `limbs` were assigned in `inner_product`. Because `limb_bases[0]` is 1, the progression of indices is 0,1,4,...,4+3*i + ctx.cells_to_lookup.push(ctx.get(row_offset)); + for i in 0..k - 1 { + ctx.cells_to_lookup.push(ctx.get(row_offset + 1 + 3 * i as isize)); + } + }; + + // additional constraints for the last limb if rem_bits != 0 + match rem_bits.cmp(&1) { + // we want to check x := limbs[k-1] is boolean + // we constrain x*(x-1) = 0 + x * x - x == 0 + // | 0 | x | x | x | + Ordering::Equal => { + self.gate.assert_bit(ctx, *ctx.cells_to_lookup.last().unwrap()); + } + Ordering::Greater => { + let mult_val = self.gate.pow_of_two[self.lookup_bits - rem_bits]; + let check = + self.gate.mul(ctx, *ctx.cells_to_lookup.last().unwrap(), Constant(mult_val)); + ctx.cells_to_lookup.push(check); + } + _ => {} + } } - /// Warning: This may fail silently if a or b have more than num_bits - fn check_less_than<'a>( + /// Constrains that 'a' is less than 'b'. + /// + /// Assumes that`a` and `b` are known to have <= num_bits bits. + /// + /// Note: This may fail silently if a or b have more than num_bits + /// * a: [QuantumCell] value to check + /// * b: upper bound expressed as a [QuantumCell] + /// * num_bits: number of bits to represent the values + fn check_less_than( &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, num_bits: usize, ) { + let a = a.into(); + let b = b.into(); let pow_of_two = self.gate.pow_of_two[num_bits]; let check_cell = match self.strategy { RangeStrategy::Vertical => { - let shift_a_val = a.value().map(|av| pow_of_two + av); + let shift_a_val = pow_of_two + a.value(); // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | - let cells = vec![ + let cells = [ Witness(shift_a_val - b.value()), b, Constant(F::one()), @@ -359,48 +578,47 @@ impl RangeInstructions for RangeConfig { Constant(F::one()), a, ]; - let assigned_cells = - self.gate.assign_region(ctx, cells, vec![(0, None), (3, None)]); - assigned_cells.into_iter().next().unwrap() - } - RangeStrategy::PlonkPlus => { - // | a | 1 | b | a + 2^{num_bits} - b | - // selectors: - // | 1 | 0 | 0 | - // | 0 | 2^{num_bits} | -1 | - let out_val = Value::known(pow_of_two) + a.value() - b.value(); - let assigned_cells = self.gate.assign_region( - ctx, - vec![a, Constant(F::one()), b, Witness(out_val)], - vec![(0, Some([F::zero(), pow_of_two, -F::one()]))], - ); - assigned_cells.into_iter().nth(3).unwrap() + ctx.assign_region(cells, [0, 3]); + ctx.get(-7) } }; - self.range_check(ctx, &check_cell, num_bits); + self.range_check(ctx, check_cell, num_bits); } - /// Warning: This may fail silently if a or b have more than num_bits - fn is_less_than<'a>( + /// Constrains whether `a` is in `[0, b)`, and returns 1 if `a` < `b`, otherwise 0. + /// + /// * a: first [QuantumCell] to compare + /// * b: second [QuantumCell] to compare + /// * num_bits: number of bits to represent the values + /// + /// # Assumptions + /// * `a` and `b` are known to have `<= num_bits` bits. + /// * (`ceil(num_bits / lookup_bits) + 1) * lookup_bits <= F::CAPACITY` + fn is_less_than( &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, num_bits: usize, - ) -> AssignedValue<'a, F> { - // TODO: optimize this for PlonkPlus strategy + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let k = (num_bits + self.lookup_bits - 1) / self.lookup_bits; let padded_bits = k * self.lookup_bits; + debug_assert!( + padded_bits + self.lookup_bits <= F::CAPACITY as usize, + "num_bits is too large for this is_less_than implementation" + ); let pow_padded = self.gate.pow_of_two[padded_bits]; - let shift_a_val = a.value().map(|av| pow_padded + av); + let shift_a_val = pow_padded + a.value(); let shifted_val = shift_a_val - b.value(); let shifted_cell = match self.strategy { RangeStrategy::Vertical => { - let assignments = self.gate.assign_region_smart( - ctx, - vec![ + ctx.assign_region( + [ Witness(shifted_val), b, Constant(F::one()), @@ -409,29 +627,16 @@ impl RangeInstructions for RangeConfig { Constant(F::one()), a, ], - vec![0, 3], - vec![], - vec![], + [0, 3], ); - assignments.into_iter().next().unwrap() + ctx.get(-7) } - RangeStrategy::PlonkPlus => self.gate.assign_region_last( - ctx, - vec![a, Constant(pow_padded), b, Witness(shifted_val)], - vec![(0, Some([F::zero(), F::one(), -F::one()]))], - ), }; // check whether a - b + 2^padded_bits < 2^padded_bits ? // since assuming a, b < 2^padded_bits we are guaranteed a - b + 2^padded_bits < 2^{padded_bits + 1} - let limbs = ctx.preallocated_vec_to_assign(); - self.range_check_limbs( - ctx, - &shifted_cell, - padded_bits + self.lookup_bits, - &mut limbs.borrow_mut(), - ); - let res = self.gate().is_zero(ctx, limbs.borrow().get(k).unwrap()); - res + self.range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); + // ctx.cells_to_lookup.last() will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` + self.gate.is_zero(ctx, *ctx.cells_to_lookup.last().unwrap()) } } diff --git a/halo2-base/src/gates/tests.rs b/halo2-base/src/gates/tests.rs deleted file mode 100644 index c4e811a3..00000000 --- a/halo2-base/src/gates/tests.rs +++ /dev/null @@ -1,463 +0,0 @@ -use super::{ - flex_gate::{FlexGateConfig, GateStrategy}, - range, GateInstructions, RangeInstructions, -}; -use crate::halo2_proofs::{circuit::*, dev::MockProver, halo2curves::bn256::Fr, plonk::*}; -use crate::{ - Context, ContextParams, - QuantumCell::{Constant, Existing, Witness}, - SKIP_FIRST_PASS, -}; - -#[derive(Default)] -struct MyCircuit { - a: Value, - b: Value, - c: Value, -} - -const NUM_ADVICE: usize = 2; - -impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure( - meta, - GateStrategy::Vertical, - &[NUM_ADVICE], - 1, - 0, - 6, /* params K */ - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let (a_cell, b_cell, c_cell) = { - let cells = config.assign_region_smart( - ctx, - vec![Witness(self.a), Witness(self.b), Witness(self.c)], - vec![], - vec![], - vec![], - ); - (cells[0].clone(), cells[1].clone(), cells[2].clone()) - }; - - // test add - { - config.add(ctx, Existing(&a_cell), Existing(&b_cell)); - } - - // test sub - { - config.sub(ctx, Existing(&a_cell), Existing(&b_cell)); - } - - // test multiply - { - config.mul(ctx, Existing(&c_cell), Existing(&b_cell)); - } - - // test idx_to_indicator - { - config.idx_to_indicator(ctx, Constant(Fr::from(3u64)), 4); - } - - { - let bits = config.assign_witnesses( - ctx, - vec![Value::known(Fr::zero()), Value::known(Fr::one())], - ); - config.bits_to_indicator(ctx, &bits); - } - - #[cfg(feature = "display")] - { - println!("total advice cells: {}", ctx.total_advice); - let const_rows = ctx.fixed_offset + 1; - println!("maximum rows used by a fixed column: {const_rows}"); - } - - Ok(()) - }, - ) - } -} - -#[test] -fn test_gates() { - let k = 6; - let circuit = MyCircuit:: { - a: Value::known(Fr::from(10u64)), - b: Value::known(Fr::from(12u64)), - c: Value::known(Fr::from(120u64)), - }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - // assert_eq!(prover.verify(), Ok(())); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_gates() { - let k = 5; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Gates Layout", ("sans-serif", 60)).unwrap(); - - let circuit = MyCircuit::::default(); - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); -} - -#[derive(Default)] -struct RangeTestCircuit { - range_bits: usize, - lt_bits: usize, - a: Value, - b: Value, -} - -impl Circuit for RangeTestCircuit { - type Config = range::RangeConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - range_bits: self.range_bits, - lt_bits: self.lt_bits, - a: Value::unknown(), - b: Value::unknown(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - range::RangeConfig::configure( - meta, - range::RangeStrategy::Vertical, - &[NUM_ADVICE], - &[1], - 1, - 3, - 0, - 11, /* params K */ - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_lookup_table(&mut layouter)?; - - /* - // let's try a separate layouter for loading private inputs - let (a, b) = layouter.assign_region( - || "load private inputs", - |region| { - let mut aux = Context::new( - region, - ContextParams { - num_advice: vec![("default".to_string(), NUM_ADVICE)], - fixed_columns: config.gate.constants.clone(), - }, - ); - let cells = config.gate.assign_region_smart( - &mut aux, - vec![Witness(self.a), Witness(self.b)], - vec![], - vec![], - vec![], - )?; - Ok((cells[0].clone(), cells[1].clone())) - }, - )?; */ - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "range", - |region| { - // If we uncomment out the line below, get_shape will be empty and the layouter will try to assign at row 0, but "load private inputs" has already assigned to row 0, so this will panic and fail - - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.gate.max_rows, - num_context_ids: 1, - fixed_columns: config.gate.constants.clone(), - }, - ); - let ctx = &mut aux; - - let (a, b) = { - let cells = config.gate.assign_region_smart( - ctx, - vec![Witness(self.a), Witness(self.b)], - vec![], - vec![], - vec![], - ); - (cells[0].clone(), cells[1].clone()) - }; - - { - config.range_check(ctx, &a, self.range_bits); - } - { - config.check_less_than(ctx, Existing(&a), Existing(&b), self.lt_bits); - } - { - config.is_less_than(ctx, Existing(&a), Existing(&b), self.lt_bits); - } - { - config.is_less_than(ctx, Existing(&b), Existing(&a), self.lt_bits); - } - { - config.gate().is_equal(ctx, Existing(&b), Existing(&a)); - } - { - config.gate().is_zero(ctx, &a); - } - - config.finalize(ctx); - - #[cfg(feature = "display")] - { - println!("total advice cells: {}", ctx.total_advice); - let const_rows = ctx.fixed_offset + 1; - println!("maximum rows used by a fixed column: {const_rows}"); - println!("lookup cells used: {}", ctx.cells_to_lookup.len()); - } - Ok(()) - }, - ) - } -} - -#[test] -fn test_range() { - let k = 11; - let circuit = RangeTestCircuit:: { - range_bits: 8, - lt_bits: 8, - a: Value::known(Fr::from(100u64)), - b: Value::known(Fr::from(101u64)), - }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - //assert_eq!(prover.verify(), Ok(())); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_range() { - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); - - let circuit = RangeTestCircuit:: { - range_bits: 8, - lt_bits: 8, - a: Value::unknown(), - b: Value::unknown(), - }; - - halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); -} - -mod lagrange { - use crate::halo2_proofs::{ - arithmetic::Field, - halo2curves::bn256::{Bn256, G1Affine}, - poly::{ - commitment::{Params, ParamsProver}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, - }; - use ark_std::{end_timer, start_timer}; - use rand::rngs::OsRng; - - use super::*; - - #[derive(Default)] - struct MyCircuit { - coords: Vec>, - a: Value, - } - - const NUM_ADVICE: usize = 6; - - impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - coords: self.coords.iter().map(|_| Value::unknown()).collect(), - a: Value::unknown(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure(meta, GateStrategy::PlonkPlus, &[NUM_ADVICE], 1, 0, 14) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let x = - config.assign_witnesses(ctx, self.coords.iter().map(|c| c.map(|c| c.0))); - let y = - config.assign_witnesses(ctx, self.coords.iter().map(|c| c.map(|c| c.1))); - - let a = config.assign_witnesses(ctx, vec![self.a]).pop().unwrap(); - - config.lagrange_and_eval( - ctx, - &x.into_iter().zip(y.into_iter()).collect::>(), - &a, - ); - - #[cfg(feature = "display")] - { - println!("total advice cells: {}", ctx.total_advice); - } - - Ok(()) - }, - ) - } - } - - #[test] - fn test_lagrange() -> Result<(), Box> { - let k = 14; - let mut rng = OsRng; - let circuit = MyCircuit:: { - coords: (0..100) - .map(|i: u64| Value::known((Fr::from(i), Fr::random(&mut rng)))) - .collect(), - a: Value::known(Fr::from(100u64)), - }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - - let fd = std::fs::File::open(format!("../halo2_ecc/params/kzg_bn254_{k}.srs").as_str()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - ParamsKZG::::setup(k, &mut rng) - }; - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); - end_timer!(verify_time); - - Ok(()) - } -} diff --git a/halo2-base/src/gates/tests/README.md b/halo2-base/src/gates/tests/README.md new file mode 100644 index 00000000..24f34537 --- /dev/null +++ b/halo2-base/src/gates/tests/README.md @@ -0,0 +1,9 @@ +# Tests + +For tests that use `GateCircuitBuilder` or `RangeCircuitBuilder`, we currently must use environmental variables `FLEX_GATE_CONFIG` and `LOOKUP_BITS` to pass circuit configuration parameters to the `Circuit::configure` function. This is troublesome when Rust executes tests in parallel, so we to make sure all tests pass, run + +``` +cargo test -- --test-threads=1 +``` + +to force serial execution. diff --git a/halo2-base/src/gates/tests/flex_gate_tests.rs b/halo2-base/src/gates/tests/flex_gate_tests.rs new file mode 100644 index 00000000..b6d3e5ec --- /dev/null +++ b/halo2-base/src/gates/tests/flex_gate_tests.rs @@ -0,0 +1,266 @@ +use super::*; +use crate::halo2_proofs::dev::MockProver; +use crate::halo2_proofs::dev::VerifyFailure; +use crate::utils::ScalarField; +use crate::QuantumCell::Witness; +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder}, + flex_gate::{GateChip, GateInstructions}, + }, + QuantumCell, +}; +use test_case::test_case; + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "add(): 1 + 1 == 2")] +pub fn test_add(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.add(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub(): 1 - 1 == 0")] +pub fn test_sub(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.sub(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] +pub fn test_neg(a: QuantumCell) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.neg(ctx, a); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] +pub fn test_mul(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] +pub fn test_mul_add(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "mul_not(): 1 * 1 == 0")] +pub fn test_mul_not(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul_not(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Fr::from(1) => Ok(()); "assert_bit(): 1 == bit")] +pub fn test_assert_bit(input: F) -> Result<(), Vec> { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([input])[0]; + chip.assert_bit(ctx, a); + // auto-tune circuit + builder.config(6, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + MockProver::run(6, &circuit, vec![]).unwrap().verify() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] +pub fn test_div_unsafe(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.div_unsafe(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from); "assert_is_const()")] +pub fn test_assert_is_const(inputs: &[F]) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([inputs[0]])[0]; + chip.assert_is_const(ctx, &a, &inputs[1]); + // auto-tune circuit + builder.config(6, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + MockProver::run(6, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] +pub fn test_inner_product(input: (Vec>, Vec>)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product(ctx, input.0, input.1); + *a.value() +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] +pub fn test_inner_product_left_last( + input: (Vec>, Vec>), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product_left_last(ctx, input.0, input.1); + (*a.0.value(), *a.1.value()) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => vec![Fr::one(), Fr::from(2), Fr::from(3), Fr::from(4), Fr::from(5)]; "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] +pub fn test_inner_product_with_sums( + input: (Vec>, Vec>), +) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product_with_sums(ctx, input.0, input.1); + a.into_iter().map(|x| *x.value()).collect() +} + +#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] +pub fn test_sum_products_with_coeff_and_var( + input: (Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.sum_products_with_coeff_and_var(ctx, input.0, input.1); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] +pub fn test_and(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.and(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Witness(Fr::from(1)) => Fr::zero() ; "not(): !1 == 0")] +pub fn test_not(a: QuantumCell) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.not(ctx, a); + *a.value() +} + +#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "select(): 2 ? 3 : 1 == 2")] +pub fn test_select(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.select(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "or_and(): 1 || 1 && 1 == 1")] +pub fn test_or_and(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.or_and(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(Fr::zero() => vec![Fr::one(), Fr::zero()]; "bits_to_indicator(): 0 -> [1, 0]")] +pub fn test_bits_to_indicator(bits: F) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([bits])[0]; + let a = chip.bits_to_indicator(ctx, &[a]); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case((Witness(Fr::zero()), 3) => vec![Fr::one(), Fr::zero(), Fr::zero()] ; "idx_to_indicator(): 0 -> [1, 0, 0]")] +pub fn test_idx_to_indicator(input: (QuantumCell, usize)) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.0, input.1); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_by_indicator(): [0, 1, 2] -> 1")] +pub fn test_select_by_indicator(input: (Vec>, QuantumCell)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); + let a = chip.select_by_indicator(ctx, input.0, a); + *a.value() +} + +#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_from_idx(): [0, 1, 2] -> 1")] +pub fn test_select_from_idx(input: (Vec>, QuantumCell)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); + let a = chip.select_by_indicator(ctx, input.0, a); + *a.value() +} + +#[test_case(Fr::zero() => Fr::from(1) ; "is_zero(): 0 -> 1")] +pub fn test_is_zero(x: F) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([x])[0]; + let a = chip.is_zero(ctx, a); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one() ; "is_equal(): 1 == 1")] +pub fn test_is_equal(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.is_equal(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case((Fr::from(6u64), 3) => vec![Fr::zero(), Fr::one(), Fr::one()] ; "num_to_bits(): 6")] +pub fn test_num_to_bits(input: (F, usize)) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([input.0])[0]; + let a = chip.num_to_bits(ctx, a, input.1); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case(&[0, 1, 2].map(Fr::from) => (Fr::one(), Fr::from(2)) ; "lagrange_eval(): constant fn")] +pub fn test_lagrange_eval(input: &[F]) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let input = ctx.assign_witnesses(input.iter().copied()); + let a = chip.lagrange_and_eval(ctx, &[(input[0], input[1])], input[2]); + (*a.0.value(), *a.1.value()) +} + +#[test_case(1 => Fr::one(); "inner_product_simple(): 1 -> 1")] +pub fn test_get_field_element(n: u64) -> F { + let chip = GateChip::default(); + chip.get_field_element(n) +} diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs new file mode 100644 index 00000000..61b4f870 --- /dev/null +++ b/halo2-base/src/gates/tests/general.rs @@ -0,0 +1,170 @@ +use super::*; +use crate::gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, + flex_gate::{GateChip, GateInstructions}, + range::{RangeChip, RangeInstructions}, +}; +use crate::halo2_proofs::dev::MockProver; +use crate::utils::{BigPrimeField, ScalarField}; +use crate::{Context, QuantumCell::Constant}; +use ff::Field; +use rayon::prelude::*; + +fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { + let [a, b, c]: [_; 3] = ctx.assign_witnesses(inputs).try_into().unwrap(); + let chip = GateChip::default(); + + // test add + chip.add(ctx, a, b); + + // test sub + chip.sub(ctx, a, b); + + // test multiply + chip.mul(ctx, c, b); + + // test idx_to_indicator + chip.idx_to_indicator(ctx, Constant(F::from(3u64)), 4); + + let bits = ctx.assign_witnesses([F::zero(), F::one()]); + chip.bits_to_indicator(ctx, &bits); + + chip.is_equal(ctx, b, a); + + chip.is_zero(ctx, a); +} + +#[test] +fn test_gates() { + let k = 6; + let inputs = [10u64, 12u64, 120u64].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + gate_tests(builder.main(0), inputs); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_multithread_gates() { + let k = 6; + let inputs = [10u64, 12u64, 120u64].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + gate_tests(builder.main(0), inputs); + + let thread_ids = (0..4usize).map(|_| builder.get_new_thread_id()).collect::>(); + let new_threads = thread_ids + .into_par_iter() + .map(|id| { + let mut ctx = Context::new(builder.witness_gen_only(), id); + gate_tests(&mut ctx, [(); 3].map(|_| Fr::random(OsRng))); + ctx + }) + .collect::>(); + builder.threads[0].extend(new_threads); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_gates() { + let k = 5; + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Gates Layout", ("sans-serif", 60)).unwrap(); + + let inputs = [Fr::zero(); 3]; + let builder = GateThreadBuilder::new(false); + gate_tests(builder.main(0), inputs); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::keygen(builder); + halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); +} + +fn range_tests( + ctx: &mut Context, + lookup_bits: usize, + inputs: [F; 2], + range_bits: usize, + lt_bits: usize, +) { + let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); + let chip = RangeChip::default(lookup_bits); + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + + chip.range_check(ctx, a, range_bits); + + chip.check_less_than(ctx, a, b, lt_bits); + + chip.is_less_than(ctx, a, b, lt_bits); + + chip.is_less_than(ctx, b, a, lt_bits); + + chip.div_mod(ctx, a, 7u64, lt_bits); +} + +#[test] +fn test_range_single() { + let k = 11; + let inputs = [100, 101].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + range_tests(builder.main(0), 3, inputs, 8, 8); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_range_multicolumn() { + let k = 5; + let inputs = [100, 101].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + range_tests(builder.main(0), 3, inputs, 8, 8); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_range() { + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); + + let k = 11; + let inputs = [0, 0].map(Fr::from); + let mut builder = GateThreadBuilder::new(false); + range_tests(builder.main(0), 3, inputs, 8, 8); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::keygen(builder); + halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); +} diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs new file mode 100644 index 00000000..4db68e3e --- /dev/null +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -0,0 +1,119 @@ +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder}, + GateChip, GateInstructions, + }, + halo2_proofs::{ + plonk::keygen_pk, + plonk::{keygen_vk, Assigned}, + poly::kzg::commitment::ParamsKZG, + }, +}; + +use ff::Field; +use itertools::Itertools; +use rand::{thread_rng, Rng}; + +use super::*; +use crate::QuantumCell::Witness; + +// soundness checks for `idx_to_indicator` function +fn test_idx_to_indicator_gen(k: u32, len: usize) { + // first create proving and verifying key + let mut builder = GateThreadBuilder::keygen(); + let gate = GateChip::default(); + let dummy_idx = Witness(Fr::zero()); + let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); + // get the offsets of the indicator cells for later 'pranking' + let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); + // set env vars + builder.config(k as usize, Some(9)); + let circuit = GateCircuitBuilder::keygen(builder); + + let params = ParamsKZG::setup(k, OsRng); + // generate proving key + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = pk.get_vk(); // pk consumed vk + + // now create different proofs to test the soundness of the circuit + + let gen_pf = |idx: usize, ind_witnesses: &[Fr]| { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::default(); + let idx = Witness(Fr::from(idx as u64)); + gate.idx_to_indicator(builder.main(0), idx, len); + // prank the indicator cells + for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { + builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + } + let circuit = GateCircuitBuilder::prover(builder, vec![vec![]]); // no break points + gen_proof(¶ms, &pk, circuit) + }; + + // expected answer + for idx in 0..len { + let mut ind_witnesses = vec![Fr::zero(); len]; + ind_witnesses[idx] = Fr::one(); + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, true); + } + + let mut rng = thread_rng(); + // bad cases + for idx in 0..len { + let mut ind_witnesses = vec![Fr::zero(); len]; + // all zeros is bad! + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, false); + + // ind[idx] != 1 is bad! + for _ in 0..100usize { + ind_witnesses.fill(Fr::zero()); + ind_witnesses[idx] = Fr::random(OsRng); + if ind_witnesses[idx] == Fr::one() { + continue; + } + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, false); + } + + if len < 2 { + continue; + } + // nonzeros where there should be zeros is bad! + for _ in 0..100usize { + ind_witnesses.fill(Fr::zero()); + ind_witnesses[idx] = Fr::one(); + let num_nonzeros = rng.gen_range(1..len); + let mut count = 0usize; + for _ in 0..num_nonzeros { + let index = rng.gen_range(0..len); + if index == idx { + continue; + } + ind_witnesses[index] = Fr::random(&mut rng); + count += 1; + } + if count == 0usize { + continue; + } + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, false); + } + } +} + +#[test] +fn test_idx_to_indicator() { + test_idx_to_indicator_gen(8, 1); + test_idx_to_indicator_gen(8, 4); + test_idx_to_indicator_gen(8, 10); + test_idx_to_indicator_gen(8, 20); +} + +#[test] +#[ignore = "takes too long"] +fn test_idx_to_indicator_large() { + test_idx_to_indicator_gen(11, 100); +} diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs new file mode 100644 index 00000000..a12adeba --- /dev/null +++ b/halo2-base/src/gates/tests/mod.rs @@ -0,0 +1,73 @@ +#![allow(clippy::type_complexity)] +use crate::halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, +}; +use rand::rngs::OsRng; + +#[cfg(test)] +mod flex_gate_tests; +#[cfg(test)] +mod general; +#[cfg(test)] +mod idx_to_indicator; +#[cfg(test)] +mod neg_prop_tests; +#[cfg(test)] +mod pos_prop_tests; +#[cfg(test)] +mod range_gate_tests; +#[cfg(test)] +mod test_ground_truths; + +/// helper function to generate a proof with real prover +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() +} + +/// helper function to verify a proof +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[&[]], &mut transcript); + + if expect_satisfied { + assert!(res.is_ok()); + } else { + assert!(res.is_err()); + } +} diff --git a/halo2-base/src/gates/tests/neg_prop_tests.rs b/halo2-base/src/gates/tests/neg_prop_tests.rs new file mode 100644 index 00000000..226a01f9 --- /dev/null +++ b/halo2-base/src/gates/tests/neg_prop_tests.rs @@ -0,0 +1,398 @@ +use std::env::set_var; + +use ff::Field; +use itertools::Itertools; +use num_bigint::BigUint; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::OsRng; + +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, FieldExt}, + plonk::Assigned, +}; +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, + range::{RangeChip, RangeInstructions}, + tests::{ + pos_prop_tests::{rand_bin_witness, rand_fr, rand_witness}, + test_ground_truths, + }, + GateChip, GateInstructions, + }, + utils::{biguint_to_fe, bit_length, fe_to_biguint, ScalarField}, + QuantumCell, + QuantumCell::Witness, +}; + +// Strategies for generating random witnesses +prop_compose! { + // length == 1 is just selecting [0] which should be covered in unit test + fn idx_to_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, idx_val in prop::sample::select(vec![Fr::zero(), Fr::one(), Fr::random(OsRng)]), len in 2usize..=max_size) + (k in Just(k), idx in 0..len, idx_val in Just(idx_val), len in Just(len), mut witness_vals in arb_indicator::(len)) + -> (usize, usize, usize, Vec) { + witness_vals[idx] = idx_val; + (k, len, idx, witness_vals) + } +} + +prop_compose! { + fn select_strat(k_bounds: (usize, usize)) + (k in k_bounds.0..=k_bounds.1, a in rand_witness(), b in rand_witness(), sel in rand_bin_witness(), rand_output in rand_fr()) + -> (usize, QuantumCell, QuantumCell, QuantumCell, Fr) { + (k, a, b, sel, rand_output) + } +} + +prop_compose! { + fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec>, usize, Fr) { + (k, a, idx, rand_output) + } +} + +prop_compose! { + fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), cells in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec>, usize, Fr) { + (k, cells, idx, rand_output) + } +} + +prop_compose! { + fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in rand_fr()) + -> (usize, Vec>, Vec>, Fr) { + (k, a, b, rand_output) + } +} + +prop_compose! { + fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in (rand_fr(), rand_fr())) + -> (usize, Vec>, Vec>, (Fr, Fr)) { + (k, a, b, rand_output) + } +} + +prop_compose! { + pub fn range_check_strat(k_bounds: (usize, usize), max_range_bits: usize) + (k in k_bounds.0..=k_bounds.1, range_bits in 1usize..=max_range_bits) // lookup_bits must be less than k + (k in Just(k), range_bits in Just(range_bits), lookup_bits in 8..k, + rand_a in prop::sample::select(vec![ + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) - 1usize)), + biguint_to_fe(&BigUint::from(2u64).pow(range_bits as u32)), + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) + 1usize)), + Fr::random(OsRng) + ])) + -> (usize, usize, usize, Fr) { + (k, range_bits, lookup_bits, rand_a) + } +} + +prop_compose! { + fn is_less_than_safe_strat(k_bounds: (usize, usize)) + // compose strat to generate random rand fr in range + (b in any::().prop_filter("not zero", |&i| i != 0), k in k_bounds.0..=k_bounds.1) + (k in Just(k), b in Just(b), lookup_bits in k_bounds.0 - 1..k, rand_a in rand_fr(), out in any::()) + -> (usize, u64, usize, Fr, bool) { + (k, b, lookup_bits, rand_a, out) + } +} + +fn arb_indicator(max_size: usize) -> impl Strategy> { + vec(Just(0), max_size).prop_map(|val| val.iter().map(|&x| F::from(x)).collect::>()) +} + +fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { + // check that: + // the length of the witnes array is correct + // the sum of the witnesses is 1, indicting that there is only one index that is 1 + if ind_witnesses.len() != len + || ind_witnesses.iter().fold(Fr::zero(), |acc, val| acc + *val) != Fr::one() + { + return false; + } + + let idx_val = idx.get_lower_128() as usize; + + // Check that all indexes are zero except for the one at idx + for (i, v) in ind_witnesses.iter().enumerate() { + if i != idx_val && *v != Fr::zero() { + return false; + } + } + true +} + +// verify rand_output == a if sel == 1, rand_output == b if sel == 0 +fn check_select(a: Fr, b: Fr, sel: Fr, rand_output: Fr) -> bool { + if (sel == Fr::zero() && rand_output != b) || (sel == Fr::one() && rand_output != a) { + return false; + } + true +} + +fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset + let dummy_idx = Witness(Fr::from(idx as u64)); + let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); + // get the offsets of the indicator cells for later 'pranking' + builder.config(k, Some(9)); + let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); + // prank the indicator cells + // TODO: prank the entire advice column with random values + for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { + builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + } + // Get idx and indicator from advice column + // Apply check instance function to `idx` and `ind_witnesses` + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_select( + k: usize, + a: QuantumCell, + b: QuantumCell, + sel: QuantumCell, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + // add select gate + let select = gate.select(builder.main(0), a, b, sel); + + // Get the offset of `select`s output for later 'pranking' + builder.config(k, Some(9)); + let select_offset = select.cell.unwrap().offset; + // Prank the output + builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); + + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of output + let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_instance, + // if the proof is invalid, ignore + Err(_) => !is_valid_instance, + } +} + +fn neg_test_select_by_indicator( + k: usize, + a: Vec>, + idx: usize, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let indicator = gate.idx_to_indicator(builder.main(0), Witness(Fr::from(idx as u64)), a.len()); + let a_idx = gate.select_by_indicator(builder.main(0), a.clone(), indicator); + builder.config(k, Some(9)); + + let a_idx_offset = a_idx.cell.unwrap().offset; + builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // retrieve the value of a[idx] and check that it is equal to rand_output + let is_valid_witness = rand_output == *a[idx].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_select_from_idx( + k: usize, + cells: Vec>, + idx: usize, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let idx_val = + gate.select_from_idx(builder.main(0), cells.clone(), Witness(Fr::from(idx as u64))); + builder.config(k, Some(9)); + + let idx_offset = idx_val.cell.unwrap().offset; + builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == *cells[idx].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_inner_product( + k: usize, + a: Vec>, + b: Vec>, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let inner_product = gate.inner_product(builder.main(0), a.clone(), b.clone()); + builder.config(k, Some(9)); + + let inner_product_offset = inner_product.cell.unwrap().offset; + builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == test_ground_truths::inner_product_ground_truth(&(a, b)); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_inner_product_left_last( + k: usize, + a: Vec>, + b: Vec>, + rand_output: (Fr, Fr), +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let inner_product = gate.inner_product_left_last(builder.main(0), a.clone(), b.clone()); + builder.config(k, Some(9)); + + let inner_product_offset = + (inner_product.0.cell.unwrap().offset, inner_product.1.cell.unwrap().offset); + // prank the output cells + builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); + builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // (inner_product_ground_truth, a[a.len()-1]) + let inner_product_ground_truth = + test_ground_truths::inner_product_ground_truth(&(a.clone(), b)); + let is_valid_witness = + rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +// Range Check + +fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = RangeChip::default(lookup_bits); + + let a_witness = builder.main(0).load_witness(rand_a); + gate.range_check(builder.main(0), a_witness, range_bits); + + builder.config(k, Some(9)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; + + MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct +} + +// TODO: expand to prank output of is_less_than_safe() +fn neg_test_is_less_than_safe( + k: usize, + b: u64, + lookup_bits: usize, + rand_a: Fr, + prank_out: bool, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = RangeChip::default(lookup_bits); + let ctx = builder.main(0); + + let a_witness = ctx.load_witness(rand_a); // cannot prank this later because this witness will be copy-constrained + let out = gate.is_less_than_safe(ctx, a_witness, b); + + let out_idx = out.cell.unwrap().offset; + ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); + + builder.config(k, Some(9)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // println!("rand_a: {rand_a:?}, b: {b:?}"); + let a_big = fe_to_biguint(&rand_a); + let is_lt = a_big < BigUint::from(b); + let correct = (is_lt == prank_out) + && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check + MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct +} + +proptest! { + // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. + #[test] + fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { + prop_assert!(neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice())); + } + + #[test] + fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { + prop_assert!(neg_test_select(k, a, b, sel, rand_output)); + } + + #[test] + fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { + prop_assert!(neg_test_select_by_indicator(k, a, idx, rand_output)); + } + + #[test] + fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { + prop_assert!(neg_test_select_from_idx(k, cells, idx, rand_output)); + } + + #[test] + fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { + prop_assert!(neg_test_inner_product(k, a, b, rand_output)); + } + + #[test] + fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { + prop_assert!(neg_test_inner_product_left_last(k, a, b, rand_output)); + } + + #[test] + fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { + prop_assert!(neg_test_range_check(k, range_bits, lookup_bits, rand_a)); + } + + #[test] + fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { + prop_assert!(neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out)); + } +} diff --git a/halo2-base/src/gates/tests/pos_prop_tests.rs b/halo2-base/src/gates/tests/pos_prop_tests.rs new file mode 100644 index 00000000..f110d12f --- /dev/null +++ b/halo2-base/src/gates/tests/pos_prop_tests.rs @@ -0,0 +1,326 @@ +use crate::gates::tests::{flex_gate_tests, range_gate_tests, test_ground_truths::*, Fr}; +use crate::utils::{bit_length, fe_to_biguint}; +use crate::{QuantumCell, QuantumCell::Witness}; +use proptest::{collection::vec, prelude::*}; +//TODO: implement Copy for rand witness and rand fr to allow for array creation +// create vec and convert to array??? +//TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] +prop_compose! { + pub fn rand_fr()(val in any::()) -> Fr { + Fr::from(val) + } +} + +prop_compose! { + pub fn rand_witness()(val in any::()) -> QuantumCell { + Witness(Fr::from(val)) + } +} + +prop_compose! { + pub fn sum_products_with_coeff_and_var_strat(max_length: usize)(val in vec((rand_fr(), rand_witness(), rand_witness()), 1..=max_length), witness in rand_witness()) -> (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell) { + (val, witness) + } +} + +prop_compose! { + pub fn rand_bin_witness()(val in prop::sample::select(vec![Fr::zero(), Fr::one()])) -> QuantumCell { + Witness(val) + } +} + +prop_compose! { + pub fn rand_fr_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> Fr { + Fr::from(val) + } +} + +prop_compose! { + pub fn rand_witness_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> QuantumCell { + Witness(Fr::from(val)) + } +} + +// LEsson here 0..2^range_bits fails with 'Uniform::new called with `low >= high` +// therfore to still have a range of 0..2^range_bits we need on a mod it by 2^range_bits +// note k > lookup_bits +prop_compose! { + fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u32) + (range_bits in 2..=max_range_bits, k in k_lo..=k_hi) + (k in Just(k), lookup_bits in min_lookup_bits..(k-3), a in rand_fr_range(0, range_bits), + range_bits in Just(range_bits)) + -> (usize, usize, Fr, usize) { + (k, lookup_bits, a, range_bits as usize) + } +} + +prop_compose! { + fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) + (num_bits in 2..max_num_bits, k in k_lo..=k_hi) + (k in Just(k), a in rand_witness_range(0, num_bits as u32), b in rand_witness_range(0, num_bits as u32), + num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k) + -> (usize, usize, QuantumCell, QuantumCell, usize) { + (k, lookup_bits, a, b, num_bits) + } +} + +prop_compose! { + fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi) + (k in Just(k), b in any::(), a in rand_fr(), lookup_bits in min_lookup_bits..k) + -> (usize, usize, Fr, u64) { + (k, lookup_bits, a, b) + } +} + +proptest! { + + // Flex Gate Positive Tests + #[test] + fn prop_test_add(input in vec(rand_witness(), 2)) { + let ground_truth = add_ground_truth(input.as_slice()); + let result = flex_gate_tests::test_add(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sub(input in vec(rand_witness(), 2)) { + let ground_truth = sub_ground_truth(input.as_slice()); + let result = flex_gate_tests::test_sub(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_neg(input in rand_witness()) { + let ground_truth = neg_ground_truth(input); + let result = flex_gate_tests::test_neg(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_add(inputs in vec(rand_witness(), 3)) { + let ground_truth = mul_add_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul_add(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_not(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_not_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul_not(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_bit(input in rand_fr()) { + let ground_truth = input == Fr::one() || input == Fr::zero(); + let result = flex_gate_tests::test_assert_bit(input).is_ok(); + prop_assert_eq!(result, ground_truth); + } + + // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. + #[test] + fn prop_test_div_unsafe(inputs in vec(rand_witness().prop_filter("Input cannot be 0",|x| *x.value() != Fr::zero()), 2)) { + let ground_truth = div_unsafe_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_div_unsafe(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_is_const(input in rand_fr()) { + flex_gate_tests::test_assert_is_const(&[input; 2]); + } + + #[test] + fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_left_last_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product_left_last(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_with_sums(inputs in (vec(rand_witness(), 0..=10), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_with_sums_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product_with_sums(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sum_products_with_coeff_and_var(input in sum_products_with_coeff_and_var_strat(100)) { + let expected = sum_products_with_coeff_and_var_ground_truth(&input); + let output = flex_gate_tests::test_sum_products_with_coeff_and_var(input); + prop_assert_eq!(expected, output); + } + + #[test] + fn prop_test_and(inputs in vec(rand_witness(), 2)) { + let ground_truth = and_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_not(input in rand_witness()) { + let ground_truth = not_ground_truth(&input); + let result = flex_gate_tests::test_not(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select(vals in vec(rand_witness(), 2), sel in rand_bin_witness()) { + let inputs = vec![vals[0], vals[1], sel]; + let ground_truth = select_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_select(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_or_and(inputs in vec(rand_witness(), 3)) { + let ground_truth = or_and_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_or_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { + let ground_truth = idx_to_indicator_ground_truth(input); + let result = flex_gate_tests::test_idx_to_indicator((input.0, input.1)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_by_indicator_ground_truth(&inputs); + let result = flex_gate_tests::test_select_by_indicator(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_from_idx_ground_truth(&inputs); + let result = flex_gate_tests::test_select_from_idx(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_zero(x in rand_fr()) { + let ground_truth = is_zero_ground_truth(x); + let result = flex_gate_tests::test_is_zero(x); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_equal(inputs in vec(rand_witness(), 2)) { + let ground_truth = is_equal_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_is_equal(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_num_to_bits(num in any::()) { + let mut tmp = num; + let mut bits = vec![]; + if num == 0 { + bits.push(0); + } + while tmp > 0 { + bits.push(tmp & 1); + tmp /= 2; + } + let result = flex_gate_tests::test_num_to_bits((Fr::from(num), bits.len())); + prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); + } + + /* + #[test] + fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { + } + */ + + #[test] + fn prop_test_get_field_element(n in any::()) { + let ground_truth = get_field_element_ground_truth(n); + let result = flex_gate_tests::test_get_field_element::(n); + prop_assert_eq!(result, ground_truth); + } + + // Range Check Property Tests + + #[test] + fn prop_test_is_less_than(a in rand_witness(), b in any::().prop_filter("not zero", |&x| x != 0), + lookup_bits in 4..=16_usize) { + let bits = std::cmp::max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); + let ground_truth = is_less_than_ground_truth((*a.value(), Fr::from(b))); + let result = range_gate_tests::test_is_less_than(([a, Witness(Fr::from(b))], bits, lookup_bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_less_than_safe(a in rand_fr().prop_filter("not zero", |&x| x != Fr::zero()), + b in any::().prop_filter("not zero", |&x| x != 0), + lookup_bits in 4..=16_usize) { + prop_assume!(fe_to_biguint(&a).bits() as usize <= bit_length(b)); + let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); + let result = range_gate_tests::test_is_less_than_safe((a, b, lookup_bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod(inputs in (rand_witness().prop_filter("Non-zero num", |x| *x.value() != Fr::zero()), any::().prop_filter("Non-zero divisor", |x| *x != 0u64), 1..=16_usize)) { + let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); + let result = range_gate_tests::test_div_mod((inputs.0, inputs.1, inputs.2)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_get_last_bit(input in rand_fr(), pad_bits in 0..10usize) { + let ground_truth = get_last_bit_ground_truth(input); + let bits = fe_to_biguint(&input).bits() as usize + pad_bits; + let result = range_gate_tests::test_get_last_bit((input, bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod_var(inputs in (rand_witness(), any::(), 1..=16_usize, 1..=16_usize)) { + let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); + let result = range_gate_tests::test_div_mod_var((inputs.0, Witness(Fr::from(inputs.1)), inputs.2, inputs.3)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,24), 3, 63)) { + prop_assert_eq!(range_gate_tests::test_range_check(k, lookup_bits, a, range_bits), ()); + } + + #[test] + fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((14,24), 3, 10)) { + prop_assume!(a.value() < b.value()); + prop_assert_eq!(range_gate_tests::test_check_less_than(k, lookup_bits, a, b, num_bits), ()); + } + + #[test] + fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { + prop_assume!(a < Fr::from(b)); + prop_assert_eq!(range_gate_tests::test_check_less_than_safe(k, lookup_bits, a, b), ()); + } + + #[test] + fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { + prop_assume!(a < Fr::from(b)); + prop_assert_eq!(range_gate_tests::test_check_big_less_than_safe(k, lookup_bits, a, b), ()); + } +} diff --git a/halo2-base/src/gates/tests/range_gate_tests.rs b/halo2-base/src/gates/tests/range_gate_tests.rs new file mode 100644 index 00000000..c781af2e --- /dev/null +++ b/halo2-base/src/gates/tests/range_gate_tests.rs @@ -0,0 +1,155 @@ +use std::env::set_var; + +use super::*; +use crate::halo2_proofs::dev::MockProver; +use crate::utils::{biguint_to_fe, ScalarField}; +use crate::QuantumCell::Witness; +use crate::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + range::{RangeChip, RangeInstructions}, + }, + utils::BigPrimeField, + QuantumCell, +}; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] +pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.range_check(ctx, a, range_bits); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] +pub fn test_check_less_than( + k: usize, + lookup_bits: usize, + a: QuantumCell, + b: QuantumCell, + num_bits: usize, +) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + chip.check_less_than(ctx, a, b, num_bits); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] +pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a_val: F, b: u64) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.check_less_than_safe(ctx, a, b); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(10, 8, Fr::zero(), 1; "check_big_less_than_safe() pos")] +pub fn test_check_big_less_than_safe( + k: usize, + lookup_bits: usize, + a_val: F, + b: u64, +) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.check_big_less_than_safe(ctx, a, BigUint::from(b)); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(([0, 1].map(Fr::from).map(Witness), 3, 12) => Fr::from(1) ; "is_less_than() pos")] +pub fn test_is_less_than( + (inputs, bits, lookup_bits): ([QuantumCell; 2], usize, usize), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = chip.is_less_than(ctx, inputs[0], inputs[1], bits); + *a.value() +} + +#[test_case((Fr::zero(), 3, 3) => Fr::from(1) ; "is_less_than_safe() pos")] +pub fn test_is_less_than_safe((a, b, lookup_bits): (F, u64, usize)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.load_witness(a); + let lt = chip.is_less_than_safe(ctx, a, b); + *lt.value() +} + +#[test_case((biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize, 8) => Fr::from(1) ; "is_big_less_than_safe() pos")] +pub fn test_is_big_less_than_safe( + (a, b, lookup_bits): (F, BigUint, usize), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.load_witness(a); + let b = chip.is_big_less_than_safe(ctx, a, b); + *b.value() +} + +#[test_case((Witness(Fr::one()), 1, 2) => (Fr::one(), Fr::zero()) ; "div_mod() pos")] +pub fn test_div_mod( + inputs: (QuantumCell, u64, usize), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = chip.div_mod(ctx, inputs.0, BigUint::from(inputs.1), inputs.2); + (*a.0.value(), *a.1.value()) +} + +#[test_case((Fr::from(3), 8) => Fr::one() ; "get_last_bit(): 3, 8 bits")] +#[test_case((Fr::from(3), 2) => Fr::one() ; "get_last_bit(): 3, 2 bits")] +#[test_case((Fr::from(0), 2) => Fr::zero() ; "get_last_bit(): 0")] +#[test_case((Fr::from(1), 2) => Fr::one() ; "get_last_bit(): 1")] +#[test_case((Fr::from(2), 2) => Fr::zero() ; "get_last_bit(): 2")] +pub fn test_get_last_bit((a, bits): (F, usize)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = ctx.load_witness(a); + let b = chip.get_last_bit(ctx, a, bits); + *b.value() +} + +#[test_case((Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3) => (Fr::one(), Fr::one()) ; "div_mod_var() pos")] +pub fn test_div_mod_var( + inputs: (QuantumCell, QuantumCell, usize, usize), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = chip.div_mod_var(ctx, inputs.0, inputs.1, inputs.2, inputs.3); + (*a.0.value(), *a.1.value()) +} diff --git a/halo2-base/src/gates/tests/test_ground_truths.rs b/halo2-base/src/gates/tests/test_ground_truths.rs new file mode 100644 index 00000000..894ff8c5 --- /dev/null +++ b/halo2-base/src/gates/tests/test_ground_truths.rs @@ -0,0 +1,190 @@ +use num_integer::Integer; + +use crate::utils::biguint_to_fe; +use crate::utils::fe_to_biguint; +use crate::utils::BigPrimeField; +use crate::utils::ScalarField; +use crate::QuantumCell; + +// Ground truth functions + +// Flex Gate Ground Truths + +pub fn add_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() + *inputs[1].value() +} + +pub fn sub_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() - *inputs[1].value() +} + +pub fn neg_ground_truth(input: QuantumCell) -> F { + -(*input.value()) +} + +pub fn mul_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() +} + +pub fn mul_add_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() + *inputs[2].value() +} + +pub fn mul_not_ground_truth(inputs: &[QuantumCell]) -> F { + (F::one() - *inputs[0].value()) * *inputs[1].value() +} + +pub fn div_unsafe_ground_truth(inputs: &[QuantumCell]) -> F { + inputs[1].value().invert().unwrap() * *inputs[0].value() +} + +pub fn inner_product_ground_truth( + inputs: &(Vec>, Vec>), +) -> F { + inputs + .0 + .iter() + .zip(inputs.1.iter()) + .fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b.value())) +} + +pub fn inner_product_left_last_ground_truth( + inputs: &(Vec>, Vec>), +) -> (F, F) { + let product = inner_product_ground_truth(inputs); + let last = *inputs.0.last().unwrap().value(); + (product, last) +} + +pub fn inner_product_with_sums_ground_truth( + input: &(Vec>, Vec>), +) -> Vec { + let (a, b) = &input; + let mut result = Vec::new(); + let mut sum = F::zero(); + // TODO: convert to fold + for (ai, bi) in a.iter().zip(b) { + let product = *ai.value() * *bi.value(); + sum += product; + result.push(sum); + } + result +} + +pub fn sum_products_with_coeff_and_var_ground_truth( + input: &(Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), +) -> F { + let expected = input.0.iter().fold(F::zero(), |acc, (coeff, cell1, cell2)| { + acc + *coeff * *cell1.value() * *cell2.value() + }) + *input.1.value(); + expected +} + +pub fn and_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() +} + +pub fn not_ground_truth(a: &QuantumCell) -> F { + F::one() - *a.value() +} + +pub fn select_ground_truth(inputs: &[QuantumCell]) -> F { + (*inputs[0].value() - inputs[1].value()) * *inputs[2].value() + *inputs[1].value() +} + +pub fn or_and_ground_truth(inputs: &[QuantumCell]) -> F { + let bc_val = *inputs[1].value() * inputs[2].value(); + bc_val + inputs[0].value() - bc_val * inputs[0].value() +} + +pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, usize)) -> Vec { + let (idx, size) = inputs; + let mut indicator = vec![F::zero(); size]; + let mut idx_value = size + 1; + for i in 0..size as u64 { + if F::from(i) == *idx.value() { + idx_value = i as usize; + break; + } + } + if idx_value < size { + indicator[idx_value] = F::one(); + } + indicator +} + +pub fn select_by_indicator_ground_truth( + inputs: &(Vec>, QuantumCell), +) -> F { + let mut idx_value = inputs.0.len() + 1; + let mut indicator = vec![F::zero(); inputs.0.len()]; + for i in 0..inputs.0.len() as u64 { + if F::from(i) == *inputs.1.value() { + idx_value = i as usize; + break; + } + } + if idx_value < inputs.0.len() { + indicator[idx_value] = F::one(); + } + // take cross product of indicator and inputs.0 + inputs.0.iter().zip(indicator.iter()).fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b)) +} + +pub fn select_from_idx_ground_truth( + inputs: &(Vec>, QuantumCell), +) -> F { + let idx = inputs.1.value(); + // Since F does not implement From, we have to iterate and find the matching index + for i in 0..inputs.0.len() as u64 { + if F::from(i) == *idx { + return *inputs.0[i as usize].value(); + } + } + F::zero() +} + +pub fn is_zero_ground_truth(x: F) -> F { + if x.is_zero().into() { + F::one() + } else { + F::zero() + } +} + +pub fn is_equal_ground_truth(inputs: &[QuantumCell]) -> F { + if inputs[0].value() == inputs[1].value() { + F::one() + } else { + F::zero() + } +} + +/* +pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { +} +*/ + +pub fn get_field_element_ground_truth(n: u64) -> F { + F::from(n) +} + +// Range Chip Ground Truths + +pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { + if inputs.0 < inputs.1 { + F::one() + } else { + F::zero() + } +} + +pub fn div_mod_ground_truth(inputs: (F, u64)) -> (F, F) { + let a = fe_to_biguint(&inputs.0); + let (div, rem) = a.div_mod_floor(&inputs.1.into()); + (biguint_to_fe(&div), biguint_to_fe(&rem)) +} + +pub fn get_last_bit_ground_truth(input: F) -> F { + F::from(input.get_lower_32() & 1 == 1) +} diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 13fb664d..289d4057 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,16 +1,19 @@ +//! Base library to build Halo2 circuits. #![feature(stmt_expr_attributes)] #![feature(trait_alias)] #![deny(clippy::perf)] #![allow(clippy::too_many_arguments)] +#![warn(clippy::default_numeric_fallback)] +#![warn(missing_docs)] -// different memory allocator options: -// mimalloc is fastest on Mac M2 +// Different memory allocator options: #[cfg(feature = "jemallocator")] use jemallocator::Jemalloc; #[cfg(feature = "jemallocator")] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; +// mimalloc is fastest on Mac M2 #[cfg(feature = "mimalloc")] use mimalloc::MiMalloc; #[cfg(feature = "mimalloc")] @@ -24,552 +27,385 @@ compile_error!( #[cfg(not(any(feature = "halo2-pse", feature = "halo2-axiom")))] compile_error!("Must enable exactly one of \"halo2-pse\" or \"halo2-axiom\" features to choose which halo2_proofs crate to use."); -use gates::flex_gate::MAX_PHASE; +// use gates::flex_gate::MAX_PHASE; #[cfg(feature = "halo2-pse")] pub use halo2_proofs; #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom as halo2_proofs; -use halo2_proofs::{ - circuit::{AssignedCell, Cell, Region, Value}, - plonk::{Advice, Assigned, Column, Fixed}, -}; -use rustc_hash::FxHashMap; -#[cfg(feature = "halo2-pse")] -use std::marker::PhantomData; -use std::{cell::RefCell, rc::Rc}; +use halo2_proofs::plonk::Assigned; use utils::ScalarField; +/// Module that contains the main API for creating and working with circuits. pub mod gates; -// pub mod hashes; +/// Utility functions for converting between different types of field elements. pub mod utils; +/// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] pub const SKIP_FIRST_PASS: bool = false; +/// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-pse")] pub const SKIP_FIRST_PASS: bool = true; -#[derive(Clone, Debug)] -pub enum QuantumCell<'a, 'b: 'a, F: ScalarField> { - Existing(&'a AssignedValue<'b, F>), - ExistingOwned(AssignedValue<'b, F>), // this is similar to the Cow enum - Witness(Value), - WitnessFraction(Value>), +/// Convenience Enum which abstracts the scenarios under a value is added to an advice column. +#[derive(Clone, Copy, Debug)] +pub enum QuantumCell { + /// An [AssignedValue] already existing in the advice column (e.g., a witness value that was already assigned in a previous cell in the column). + /// * Assigns a new cell into the advice column with value equal to the value of a. + /// * Imposes an equality constraint between the new cell and the cell of a so the Verifier guarantees that these two cells are always equal. + Existing(AssignedValue), + // This is a guard for witness values assigned after pkey generation. We do not use `Value` api anymore. + /// A non-existing witness [ScalarField] value (e.g. private input) to add to an advice column. + Witness(F), + /// A non-existing witness [ScalarField] marked as a fraction for optimization in batch inversion later. + WitnessFraction(Assigned), + /// A known constant value added as a witness value to the advice column and added to the "Fixed" column during circuit creation time. + /// * Visible to both the Prover and the Verifier. + /// * Imposes an equality constraint between the two corresponding cells in the advice and fixed columns. Constant(F), } -impl QuantumCell<'_, '_, F> { - pub fn value(&self) -> Value<&F> { +impl From> for QuantumCell { + /// Converts an [AssignedValue] into a [QuantumCell] of [type Existing(AssignedValue)] + fn from(a: AssignedValue) -> Self { + Self::Existing(a) + } +} + +impl QuantumCell { + /// Returns an immutable reference to the underlying [ScalarField] value of a QuantumCell. + /// + /// Panics if the QuantumCell is of type WitnessFraction. + pub fn value(&self) -> &F { match self { Self::Existing(a) => a.value(), - Self::ExistingOwned(a) => a.value(), - Self::Witness(a) => a.as_ref(), + Self::Witness(a) => a, Self::WitnessFraction(_) => { panic!("Trying to get value of a fraction before batch inversion") } - Self::Constant(a) => Value::known(a), + Self::Constant(a) => a, } } } -#[derive(Clone, Debug)] -pub struct AssignedValue<'a, F: ScalarField> { - #[cfg(feature = "halo2-axiom")] - pub cell: AssignedCell<&'a Assigned, F>, - - #[cfg(feature = "halo2-pse")] - pub cell: Cell, - #[cfg(feature = "halo2-pse")] - pub value: Value, - #[cfg(feature = "halo2-pse")] - pub row_offset: usize, - #[cfg(feature = "halo2-pse")] - pub _marker: PhantomData<&'a F>, - - #[cfg(feature = "display")] +/// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. +#[derive(Clone, Copy, Debug)] +pub struct ContextCell { + /// Identifier of the [Context] that this cell belongs to. pub context_id: usize, + /// Relative offset of the cell within this [Context] advice column. + pub offset: usize, } -impl<'a, F: ScalarField> AssignedValue<'a, F> { - #[cfg(feature = "display")] - pub fn context_id(&self) -> usize { - self.context_id - } - - pub fn row(&self) -> usize { - #[cfg(feature = "halo2-axiom")] - { - self.cell.row_offset() - } - - #[cfg(feature = "halo2-pse")] - { - self.row_offset - } - } - - #[cfg(feature = "halo2-axiom")] - pub fn cell(&self) -> &Cell { - self.cell.cell() - } - #[cfg(feature = "halo2-pse")] - pub fn cell(&self) -> Cell { - self.cell - } +/// Pointer containing cell value and location within [Context]. +/// +/// Note: Performs a copy of the value, should only be used when you are about to assign the value again elsewhere. +#[derive(Clone, Copy, Debug)] +pub struct AssignedValue { + /// Value of the cell. + pub value: Assigned, // we don't use reference to avoid issues with lifetimes (you can't safely borrow from vector and push to it at the same time). + // only needed during vkey, pkey gen to fetch the actual cell from the relevant context + /// [ContextCell] pointer to the cell the value is assigned to within an advice column of a [Context]. + pub cell: Option, +} - pub fn value(&self) -> Value<&F> { - #[cfg(feature = "halo2-axiom")] - { - self.cell.value().map(|a| match *a { - Assigned::Trivial(a) => a, - _ => unreachable!(), - }) - } - #[cfg(feature = "halo2-pse")] - { - self.value.as_ref() +impl AssignedValue { + /// Returns an immutable reference to the underlying value of an AssignedValue. + /// + /// Panics if the AssignedValue is of type WitnessFraction. + pub fn value(&self) -> &F { + match &self.value { + Assigned::Trivial(a) => a, + _ => unreachable!(), // if trying to fetch an un-evaluated fraction, you will have to do something manual } } - - #[cfg(feature = "halo2-axiom")] - pub fn copy_advice<'v>( - &'a self, - region: &mut Region<'_, F>, - column: Column, - offset: usize, - ) -> AssignedCell<&'v Assigned, F> { - let assigned_cell = region - .assign_advice(column, offset, self.cell.value().map(|v| **v)) - .unwrap_or_else(|err| panic!("{err:?}")); - region.constrain_equal(assigned_cell.cell(), self.cell()); - - assigned_cell - } - - #[cfg(feature = "halo2-pse")] - pub fn copy_advice( - &'a self, - region: &mut Region<'_, F>, - column: Column, - offset: usize, - ) -> Cell { - let cell = region - .assign_advice(|| "", column, offset, || self.value) - .expect("assign copy advice should not fail") - .cell(); - region.constrain_equal(cell, self.cell()).expect("constrain equal should not fail"); - - cell - } } -// The reason we have a `Context` is that we will need to mutably borrow `advice_rows` (etc.) to update row count -// The `Circuit` trait takes in `Config` as an input that is NOT mutable, so we must pass around &mut Context everywhere for function calls -// We follow halo2wrong's convention of having `Context` also include the `Region` to be passed around, instead of a `Layouter`, so that everything happens within a single `layouter.assign_region` call. This allows us to circumvent the Halo2 layouter and use our own "pseudo-layouter", which is more specialized (and hence faster) for our specific gates -#[derive(Debug)] -pub struct Context<'a, F: ScalarField> { - pub region: Region<'a, F>, // I don't see a reason to use Box> since we will pass mutable reference of `Context` anyways +/// Represents a single thread of an execution trace. +/// * We keep the naming [Context] for historical reasons. +#[derive(Clone, Debug)] +pub struct Context { + /// Flag to determine whether only witness generation or proving and verification key generation is being performed. + /// * If witness gen is performed many operations can be skipped for optimization. + witness_gen_only: bool, - pub max_rows: usize, + /// Identifier to reference cells from this [Context]. + pub context_id: usize, - // Assigning advice in a "horizontal" first fashion requires getting the column with min rows used each time `assign_region` is called, which takes a toll on witness generation speed, so instead we will just assigned a column all the way down until it reaches `max_rows` and then increment the column index - // - /// `advice_alloc[context_id] = (index, offset)` where `index` contains the current column index corresponding to `context_id`, and `offset` contains the current row offset within column `index` - /// - /// This assumes the phase is `ctx.current_phase()` to enforce the design pattern that advice should be assigned one phase at a time. - pub advice_alloc: Vec<(usize, usize)>, // [Vec<(usize, usize)>; MAX_PHASE], + /// Single column of advice cells. + pub advice: Vec>, - #[cfg(feature = "display")] - pub total_advice: usize, + /// [Vec] tracking all cells that lookup is enabled for. + /// * When there is more than 1 advice column all `advice` cells will be copied to a single lookup enabled column to perform lookups. + pub cells_to_lookup: Vec>, + + /// Cell that represents the zero value as AssignedValue + pub zero_cell: Option>, // To save time from re-allocating new temporary vectors that get quickly dropped (e.g., for some range checks), we keep a vector with high capacity around that we `clear` before use each time + // This is NOT THREAD SAFE // Need to use RefCell to avoid borrow rules // Need to use Rc to borrow this and mutably borrow self at same time - preallocated_vec_to_assign: Rc>>>, - - // `assigned_constants` is a HashMap keeping track of all constants that we use throughout - // we assign them to fixed columns as we go, re-using a fixed cell if the constant value has been assigned previously - fixed_columns: Vec>, - fixed_col: usize, - fixed_offset: usize, - // fxhash is faster than normal HashMap: https://nnethercote.github.io/perf-book/hashing.html - #[cfg(feature = "halo2-axiom")] - pub assigned_constants: FxHashMap, - // PSE's halo2curves does not derive Hash - #[cfg(feature = "halo2-pse")] - pub assigned_constants: FxHashMap, Cell>, - - pub zero_cell: Option>, - - // `cells_to_lookup` is a vector keeping track of all cells that we want to enable lookup for. When there is more than 1 advice column we will copy_advice all of these cells to the single lookup enabled column and do lookups there - pub cells_to_lookup: Vec>, - - current_phase: usize, - - #[cfg(feature = "display")] - pub op_count: FxHashMap, - #[cfg(feature = "display")] - pub advice_alloc_cache: [Vec<(usize, usize)>; MAX_PHASE], - #[cfg(feature = "display")] - pub total_lookup_cells: [usize; MAX_PHASE], - #[cfg(feature = "display")] - pub total_fixed: usize, -} + // preallocated_vec_to_assign: Rc>>>, -//impl<'a, F: ScalarField> std::ops::Drop for Context<'a, F> { -// fn drop(&mut self) { -// assert!( -// self.cells_to_lookup.is_empty(), -// "THERE ARE STILL ADVICE CELLS THAT NEED TO BE LOOKED UP" -// ); -// } -//} - -impl<'a, F: ScalarField> std::fmt::Display for Context<'a, F> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:#?}") - } -} + // ======================================== + // General principle: we don't need to optimize anything specific to `witness_gen_only == false` because it is only done during keygen + // If `witness_gen_only == false`: + /// [Vec] representing the selector column of this [Context] accompanying each `advice` column + /// * Assumed to have the same length as `advice` + pub selector: Vec, -// a single struct to package any configuration parameters we will need for constructing a new `Context` -#[derive(Clone, Debug)] -pub struct ContextParams { - pub max_rows: usize, - /// `num_advice[context_id][phase]` contains the number of advice columns that context `context_id` keeps track of in phase `phase` - pub num_context_ids: usize, - pub fixed_columns: Vec>, -} + // TODO: gates that use fixed columns as selectors? + /// A [Vec] tracking equality constraints between pairs of [Context] `advice` cells. + /// + /// Assumes both `advice` cells are in the same [Context]. + pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, -impl<'a, F: ScalarField> Context<'a, F> { - pub fn new(region: Region<'a, F>, params: ContextParams) -> Self { - let advice_alloc = vec![(0, 0); params.num_context_ids]; + /// A [Vec] tracking pairs equality constraints between Fixed values and [Context] `advice` cells. + /// + /// Assumes the constant and `advice` cell are in the same [Context]. + pub constant_equality_constraints: Vec<(F, ContextCell)>, +} +impl Context { + /// Creates a new [Context] with the given `context_id` and witness generation enabled/disabled by the `witness_gen_only` flag. + /// * `witness_gen_only`: flag to determine whether public key generation or only witness generation is being performed. + /// * `context_id`: identifier to reference advice cells from this [Context] later. + pub fn new(witness_gen_only: bool, context_id: usize) -> Self { Self { - region, - max_rows: params.max_rows, - advice_alloc, - #[cfg(feature = "display")] - total_advice: 0, - preallocated_vec_to_assign: Rc::new(RefCell::new(Vec::with_capacity(256))), - fixed_columns: params.fixed_columns, - fixed_col: 0, - fixed_offset: 0, - assigned_constants: FxHashMap::default(), - zero_cell: None, + witness_gen_only, + context_id, + advice: Vec::new(), cells_to_lookup: Vec::new(), - current_phase: 0, - #[cfg(feature = "display")] - op_count: FxHashMap::default(), - #[cfg(feature = "display")] - advice_alloc_cache: [(); MAX_PHASE].map(|_| vec![]), - #[cfg(feature = "display")] - total_lookup_cells: [0; MAX_PHASE], - #[cfg(feature = "display")] - total_fixed: 0, + zero_cell: None, + selector: Vec::new(), + advice_equality_constraints: Vec::new(), + constant_equality_constraints: Vec::new(), } } - pub fn preallocated_vec_to_assign(&self) -> Rc>>> { - Rc::clone(&self.preallocated_vec_to_assign) + /// Returns the `witness_gen_only` flag of the [Context] + pub fn witness_gen_only(&self) -> bool { + self.witness_gen_only } - pub fn next_phase(&mut self) { - assert!( - self.cells_to_lookup.is_empty(), - "THERE ARE STILL ADVICE CELLS THAT NEED TO BE LOOKED UP" - ); - #[cfg(feature = "display")] - { - self.advice_alloc_cache[self.current_phase] = self.advice_alloc.clone(); - } - #[cfg(feature = "halo2-axiom")] - self.region.next_phase(); - self.current_phase += 1; - for advice_alloc in self.advice_alloc.iter_mut() { - *advice_alloc = (0, 0); + /// Pushes a [QuantumCell] to the end of the `advice` column ([Vec] of advice cells) in this [Context]. + /// * `input`: the cell to be assigned. + pub fn assign_cell(&mut self, input: impl Into>) { + // Determine the type of the cell and push it to the relevant vector + match input.into() { + QuantumCell::Existing(acell) => { + self.advice.push(acell.value); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + } + } + QuantumCell::Witness(val) => { + self.advice.push(Assigned::Trivial(val)); + } + QuantumCell::WitnessFraction(val) => { + self.advice.push(val); + } + QuantumCell::Constant(c) => { + self.advice.push(Assigned::Trivial(c)); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.constant_equality_constraints.push((c, new_cell)); + } + } } - assert!(self.current_phase < MAX_PHASE); - } - - pub fn current_phase(&self) -> usize { - self.current_phase } - #[cfg(feature = "display")] - /// Returns (number of fixed columns used, total fixed cells used) - pub fn fixed_stats(&self) -> (usize, usize) { - // heuristic, fixed cells don't need to worry about blinding factors - ((self.total_fixed + self.max_rows - 1) / self.max_rows, self.total_fixed) + /// Returns the [AssignedValue] of the last cell in the `advice` column of [Context] or [None] if `advice` is empty + pub fn last(&self) -> Option> { + self.advice.last().map(|v| { + let cell = (!self.witness_gen_only).then_some(ContextCell { + context_id: self.context_id, + offset: self.advice.len() - 1, + }); + AssignedValue { value: *v, cell } + }) } - #[cfg(feature = "halo2-axiom")] - pub fn assign_fixed(&mut self, c: F) -> Cell { - let fixed = self.assigned_constants.get(&c); - if let Some(cell) = fixed { - *cell + /// Returns the [AssignedValue] of the cell at the given `offset` in the `advice` column of [Context] + /// * `offset`: the offset of the cell to be fetched + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last cell) + /// * Assumes `offset` is a valid index in `advice`; + /// * `0` <= `offset` < `advice.len()` (or `advice.len() + offset >= 0` if `offset` is negative) + pub fn get(&self, offset: isize) -> AssignedValue { + let offset = if offset < 0 { + self.advice.len().wrapping_add_signed(offset) } else { - let cell = self.assign_fixed_without_caching(c); - self.assigned_constants.insert(c, cell); - cell - } - } - #[cfg(feature = "halo2-pse")] - pub fn assign_fixed(&mut self, c: F) -> Cell { - let fixed = self.assigned_constants.get(c.to_repr().as_ref()); - if let Some(cell) = fixed { - *cell - } else { - let cell = self.assign_fixed_without_caching(c); - self.assigned_constants.insert(c.to_repr().as_ref().to_vec(), cell); - cell - } + offset as usize + }; + assert!(offset < self.advice.len()); + let cell = + (!self.witness_gen_only).then_some(ContextCell { context_id: self.context_id, offset }); + AssignedValue { value: self.advice[offset], cell } } - /// Saving the assigned constant to the hashmap takes time. - /// - /// In situations where you don't expect to reuse the value, you can assign the fixed value directly using this function. - pub fn assign_fixed_without_caching(&mut self, c: F) -> Cell { - #[cfg(feature = "halo2-axiom")] - let cell = self.region.assign_fixed( - self.fixed_columns[self.fixed_col], - self.fixed_offset, - Assigned::Trivial(c), - ); - #[cfg(feature = "halo2-pse")] - let cell = self - .region - .assign_fixed( - || "", - self.fixed_columns[self.fixed_col], - self.fixed_offset, - || Value::known(c), - ) - .expect("assign fixed should not fail") - .cell(); - #[cfg(feature = "display")] - { - self.total_fixed += 1; - } - self.fixed_col += 1; - if self.fixed_col == self.fixed_columns.len() { - self.fixed_col = 0; - self.fixed_offset += 1; + /// Creates an equality constraint between two `advice` cells. + /// * `a`: the first `advice` cell to be constrained equal + /// * `b`: the second `advice` cell to be constrained equal + /// * Assumes both cells are `advice` cells + pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { + if !self.witness_gen_only { + self.advice_equality_constraints.push((a.cell.unwrap(), b.cell.unwrap())); } - cell } - /// Assuming that this is only called if ctx.region is not in shape mode! - #[cfg(feature = "halo2-axiom")] - pub fn assign_cell<'v>( + /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset`. + /// + /// * `inputs`: Iterator that specifies the cells to be assigned + /// * `gate_offsets`: specifies relative offset from current position to enable selector for the gate (e.g., `0` is inputs[0]). + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last previously assigned cell) + pub fn assign_region( &mut self, - input: QuantumCell<'_, 'v, F>, - column: Column, - #[cfg(feature = "display")] context_id: usize, - row_offset: usize, - ) -> AssignedValue<'v, F> { - match input { - QuantumCell::Existing(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - #[cfg(feature = "display")] - context_id, - } + inputs: impl IntoIterator, + gate_offsets: impl IntoIterator, + ) where + Q: Into>, + { + if self.witness_gen_only { + for input in inputs { + self.assign_cell(input); } - QuantumCell::ExistingOwned(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - #[cfg(feature = "display")] - context_id, - } + } else { + let row_offset = self.advice.len(); + // note: row_offset may not equal self.selector.len() at this point if we previously used `load_constant` or `load_witness` + for input in inputs { + self.assign_cell(input); } - QuantumCell::Witness(val) => AssignedValue { - cell: self - .region - .assign_advice(column, row_offset, val.map(Assigned::Trivial)) - .expect("assign advice should not fail"), - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::WitnessFraction(val) => AssignedValue { - cell: self - .region - .assign_advice(column, row_offset, val) - .expect("assign advice should not fail"), - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::Constant(c) => { - let acell = self - .region - .assign_advice(column, row_offset, Value::known(Assigned::Trivial(c))) - .expect("assign fixed advice should not fail"); - let c_cell = self.assign_fixed(c); - self.region.constrain_equal(acell.cell(), &c_cell); - AssignedValue { - cell: acell, - #[cfg(feature = "display")] - context_id, - } + self.selector.resize(self.advice.len(), false); + for offset in gate_offsets { + *self + .selector + .get_mut(row_offset.checked_add_signed(offset).expect("Invalid gate offset")) + .expect("Invalid selector offset") = true; } } } - #[cfg(feature = "halo2-pse")] - pub fn assign_cell<'v>( + /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset` and returns the last assigned cell. + /// + /// Assumes `gate_offsets` is the same length as `inputs` + /// + /// Returns the last assigned cell + /// * `inputs`: Iterator that specifies the cells to be assigned + /// * `gate_offsets`: specifies indices to enable selector for the gate; assume `gate_offsets` is sorted in increasing order + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last cell) + pub fn assign_region_last( + &mut self, + inputs: impl IntoIterator, + gate_offsets: impl IntoIterator, + ) -> AssignedValue + where + Q: Into>, + { + self.assign_region(inputs, gate_offsets); + self.last().unwrap() + } + + /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset`. + /// + /// Allows for the specification of equality constraints between cells at `equality_offsets` within the `advice` column and external advice cells specified in `external_equality` (e.g, Fixed column). + /// * `gate_offsets`: specifies indices to enable selector for the gate; + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last cell) + /// * `equality_offsets`: specifies pairs of indices to constrain equality + /// * `external_equality`: specifies an existing cell to constrain equality with the cell at a certain index + pub fn assign_region_smart( &mut self, - input: QuantumCell<'_, 'v, F>, - column: Column, - #[cfg(feature = "display")] context_id: usize, - row_offset: usize, - phase: u8, - ) -> AssignedValue<'v, F> { - match input { - QuantumCell::Existing(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - value: acell.value, - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - } + inputs: impl IntoIterator, + gate_offsets: impl IntoIterator, + equality_offsets: impl IntoIterator, + external_equality: impl IntoIterator, isize)>, + ) where + Q: Into>, + { + let row_offset = self.advice.len(); + self.assign_region(inputs, gate_offsets); + + // note: row_offset may not equal self.selector.len() at this point if we previously used `load_constant` or `load_witness` + // If not in witness generation mode, add equality constraints. + if !self.witness_gen_only { + // Add equality constraints between cells in the advice column. + for (offset1, offset2) in equality_offsets { + self.advice_equality_constraints.push(( + ContextCell { + context_id: self.context_id, + offset: row_offset.wrapping_add_signed(offset1), + }, + ContextCell { + context_id: self.context_id, + offset: row_offset.wrapping_add_signed(offset2), + }, + )); } - QuantumCell::ExistingOwned(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - value: acell.value, - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - } - } - QuantumCell::Witness(value) => AssignedValue { - cell: self - .region - .assign_advice(|| "", column, row_offset, || value) - .expect("assign advice should not fail") - .cell(), - value, - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::WitnessFraction(val) => AssignedValue { - cell: self - .region - .assign_advice(|| "", column, row_offset, || val) - .expect("assign advice should not fail") - .cell(), - value: Value::unknown(), - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::Constant(c) => { - let acell = self - .region - .assign_advice(|| "", column, row_offset, || Value::known(c)) - .expect("assign fixed advice should not fail") - .cell(); - let c_cell = self.assign_fixed(c); - self.region.constrain_equal(acell, c_cell).unwrap(); - AssignedValue { - cell: acell, - value: Value::known(c), - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - } + // Add equality constraints between cells in the advice column and external cells (Fixed column). + for (cell, offset) in external_equality { + self.advice_equality_constraints.push(( + cell.unwrap(), + ContextCell { + context_id: self.context_id, + offset: row_offset.wrapping_add_signed(offset), + }, + )); } } } - // convenience function to deal with rust warnings - pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { - #[cfg(feature = "halo2-axiom")] - self.region.constrain_equal(a.cell(), b.cell()); - #[cfg(not(feature = "halo2-axiom"))] - self.region.constrain_equal(a.cell(), b.cell()).unwrap(); + /// Assigns a region of witness cells in an iterator and returns a [Vec] of assigned cells. + /// * `witnesses`: Iterator that specifies the cells to be assigned + pub fn assign_witnesses( + &mut self, + witnesses: impl IntoIterator, + ) -> Vec> { + let row_offset = self.advice.len(); + self.assign_region(witnesses.into_iter().map(QuantumCell::Witness), []); + self.advice[row_offset..] + .iter() + .enumerate() + .map(|(i, v)| { + let cell = (!self.witness_gen_only) + .then_some(ContextCell { context_id: self.context_id, offset: row_offset + i }); + AssignedValue { value: *v, cell } + }) + .collect() } - /// Call this at the end of a phase - /// - /// assumes self.region is not in shape mode - pub fn copy_and_lookup_cells(&mut self, lookup_advice: Vec>) -> usize { - let total_cells = self.cells_to_lookup.len(); - let mut cells_to_lookup = self.cells_to_lookup.iter().peekable(); - for column in lookup_advice.into_iter() { - let mut offset = 0; - while offset < self.max_rows && cells_to_lookup.peek().is_some() { - let acell = cells_to_lookup.next().unwrap(); - acell.copy_advice(&mut self.region, column, offset); - offset += 1; - } - } - if cells_to_lookup.peek().is_some() { - panic!("NOT ENOUGH ADVICE COLUMNS WITH LOOKUP ENABLED"); - } - self.cells_to_lookup.clear(); - #[cfg(feature = "display")] - { - self.total_lookup_cells[self.current_phase] = total_cells; + /// Assigns a witness value and returns the corresponding assigned cell. + /// * `witness`: the witness value to be assigned + pub fn load_witness(&mut self, witness: F) -> AssignedValue { + self.assign_cell(QuantumCell::Witness(witness)); + if !self.witness_gen_only { + self.selector.resize(self.advice.len(), false); } - total_cells + self.last().unwrap() } - #[cfg(feature = "display")] - pub fn print_stats(&mut self, context_names: &[&str]) { - let curr_phase = self.current_phase(); - self.advice_alloc_cache[curr_phase] = self.advice_alloc.clone(); - for phase in 0..=curr_phase { - for (context_name, alloc) in - context_names.iter().zip(self.advice_alloc_cache[phase].iter()) - { - println!("Context \"{context_name}\" used {} advice columns and {} total advice cells in phase {phase}", alloc.0 + 1, alloc.0 * self.max_rows + alloc.1); - } - let num_lookup_advice_cells = self.total_lookup_cells[phase]; - println!("Special lookup advice cells: optimal columns: {}, total {num_lookup_advice_cells} cells used in phase {phase}.", (num_lookup_advice_cells + self.max_rows - 1)/self.max_rows); + /// Assigns a constant value and returns the corresponding assigned cell. + /// * `c`: the constant value to be assigned + pub fn load_constant(&mut self, c: F) -> AssignedValue { + self.assign_cell(QuantumCell::Constant(c)); + if !self.witness_gen_only { + self.selector.resize(self.advice.len(), false); } - let (fixed_cols, total_fixed) = self.fixed_stats(); - println!("Fixed columns: {fixed_cols}, Total fixed cells: {total_fixed}"); + self.last().unwrap() } -} -#[derive(Clone, Debug)] -pub struct AssignedPrimitive<'a, T: Into + Copy, F: ScalarField> { - pub value: Value, - - #[cfg(feature = "halo2-axiom")] - pub cell: AssignedCell<&'a Assigned, F>, - - #[cfg(feature = "halo2-pse")] - pub cell: Cell, - #[cfg(feature = "halo2-pse")] - row_offset: usize, - #[cfg(feature = "halo2-pse")] - _marker: PhantomData<&'a F>, + /// Assigns the 0 value to a new cell or returns a previously assigned zero cell from `zero_cell`. + pub fn load_zero(&mut self) -> AssignedValue { + if let Some(zcell) = &self.zero_cell { + return *zcell; + } + let zero_cell = self.load_constant(F::zero()); + self.zero_cell = Some(zero_cell); + zero_cell + } } diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils.rs index bb07150a..f722d8ce 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils.rs @@ -8,14 +8,21 @@ use num_bigint::Sign; use num_traits::Signed; use num_traits::{One, Zero}; +/// Helper trait to convert to and from a [BigPrimeField] by converting a list of [u64] digits #[cfg(feature = "halo2-axiom")] pub trait BigPrimeField: ScalarField { + /// Converts a slice of [u64] to [BigPrimeField] + /// * `val`: the slice of u64 + /// + /// # Assumptions + /// * `val` has the correct length for the implementation + /// * The integer value of `val` is already less than the modulus of `Self` fn from_u64_digits(val: &[u64]) -> Self; } #[cfg(feature = "halo2-axiom")] impl BigPrimeField for F where - F: FieldExt + Hash + Into<[u64; 4]> + From<[u64; 4]>, + F: ScalarField + From<[u64; 4]>, // Assume [u64; 4] is little-endian. We only implement ScalarField when this is true. { #[inline(always)] fn from_u64_digits(val: &[u64]) -> Self { @@ -26,67 +33,82 @@ where } } -#[cfg(feature = "halo2-axiom")] +/// Helper trait to represent a field element that can be converted into [u64] limbs. +/// +/// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. pub trait ScalarField: FieldExt + Hash { - /// Returns the base `2^bit_len` little endian representation of the prime field element - /// up to `num_limbs` number of limbs (truncates any extra limbs) - /// - /// Basically same as `to_repr` but does not go further into bytes + /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// - /// Undefined behavior if `bit_len > 64` + /// Assumes `bit_len < 64`. + /// * `num_limbs`: number of limbs to return + /// * `bit_len`: number of bits in each limb fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec; -} -#[cfg(feature = "halo2-axiom")] -impl ScalarField for F -where - F: FieldExt + Hash + Into<[u64; 4]>, -{ - #[inline(always)] - fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { - let tmp: [u64; 4] = self.into(); - decompose_u64_digits_to_limbs(tmp, num_limbs, bit_len) + + /// Returns the little endian byte representation of the element. + fn to_bytes_le(&self) -> Vec; + + /// Creates a field element from a little endian byte representation. + /// + /// The default implementation assumes that `PrimeField::from_repr` is implemented for little-endian. + /// It should be overriden if this is not the case. + fn from_bytes_le(bytes: &[u8]) -> Self { + let mut repr = Self::Repr::default(); + repr.as_mut()[..bytes.len()].copy_from_slice(bytes); + Self::from_repr(repr).unwrap() } } +// See below for implementations -// Later: will need to separate PrimeField from ScalarField when Goldilocks is introduced -#[cfg(feature = "halo2-axiom")] -pub trait PrimeField = BigPrimeField; -#[cfg(feature = "halo2-pse")] -pub trait PrimeField = FieldExt; +// Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced #[cfg(feature = "halo2-pse")] -pub trait ScalarField = FieldExt; +pub trait BigPrimeField = FieldExt + ScalarField; +/// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. +/// +/// Assumes: `bit_len < 64`. +/// * `e`: Iterator of [u64] digits +/// * `number_of_limbs`: number of limbs to return +/// * `bit_len`: number of bits in each limb #[inline(always)] pub(crate) fn decompose_u64_digits_to_limbs( e: impl IntoIterator, number_of_limbs: usize, bit_len: usize, ) -> Vec { - debug_assert!(bit_len <= 64); + debug_assert!(bit_len < 64); let mut e = e.into_iter(); + // Mask to extract the bits from each digit let mask: u64 = (1u64 << bit_len) - 1u64; let mut u64_digit = e.next().unwrap_or(0); let mut rem = 64; + + // For each digit, we extract its individual limbs by repeatedly masking and shifting the digit based on how many bits we have left to extract. (0..number_of_limbs) .map(|_| match rem.cmp(&bit_len) { + // If `rem` > `bit_len`, we mask the bits from the `u64_digit` to return the first limb. + // We shift the digit to the right by `bit_len` bits and subtract `bit_len` from `rem` core::cmp::Ordering::Greater => { let limb = u64_digit & mask; u64_digit >>= bit_len; rem -= bit_len; limb } + // If `rem` == `bit_len`, then we mask the bits from the `u64_digit` to return the first limb + // We retrieve the next digit and reset `rem` to 64 core::cmp::Ordering::Equal => { let limb = u64_digit & mask; u64_digit = e.next().unwrap_or(0); rem = 64; limb } + // If `rem` < `bit_len`, we retrieve the next digit, mask it, and shift left `rem` bits from the `u64_digit` to return the first limb. + // we shift the digit to the right by `bit_len` - `rem` bits to retrieve the start of the next limb and add 64 - bit_len to `rem` to get the remainder. core::cmp::Ordering::Less => { let mut limb = u64_digit; u64_digit = e.next().unwrap_or(0); - limb |= (u64_digit & ((1 << (bit_len - rem)) - 1)) << rem; + limb |= (u64_digit & ((1u64 << (bit_len - rem)) - 1u64)) << rem; u64_digit >>= bit_len - rem; rem += 64 - bit_len; limb @@ -95,24 +117,35 @@ pub(crate) fn decompose_u64_digits_to_limbs( .collect() } +/// Returns the number of bits needed to represent the value of `x`. pub fn bit_length(x: u64) -> usize { (u64::BITS - x.leading_zeros()) as usize } +/// Returns the ceiling of the base 2 logarithm of `x`. +/// +/// `log2_ceil(0)` returns 0. pub fn log2_ceil(x: u64) -> usize { - (u64::BITS - x.leading_zeros() - (x & (x - 1) == 0) as u32) as usize + (u64::BITS - x.leading_zeros()) as usize - usize::from(x.is_power_of_two()) } -pub fn modulus() -> BigUint { +/// Returns the modulus of [BigPrimeField]. +pub fn modulus() -> BigUint { fe_to_biguint(&-F::one()) + 1u64 } -pub fn power_of_two(n: usize) -> F { +/// Returns the [BigPrimeField] element of 2n. +/// * `n`: the desired power of 2. +pub fn power_of_two(n: usize) -> F { biguint_to_fe(&(BigUint::one() << n)) } -/// assume `e` less than modulus of F -pub fn biguint_to_fe(e: &BigUint) -> F { +/// Converts an immutable reference to [BigUint] to a [BigPrimeField]. +/// * `e`: immutable reference to [BigUint] +/// +/// # Assumptions: +/// * `e` is less than the modulus of `F` +pub fn biguint_to_fe(e: &BigUint) -> F { #[cfg(feature = "halo2-axiom")] { F::from_u64_digits(&e.to_u64_digits()) @@ -120,15 +153,17 @@ pub fn biguint_to_fe(e: &BigUint) -> F { #[cfg(feature = "halo2-pse")] { - let mut repr = F::Repr::default(); let bytes = e.to_bytes_le(); - repr.as_mut()[..bytes.len()].copy_from_slice(&bytes); - F::from_repr(repr).unwrap() + F::from_bytes_le(&bytes) } } -/// assume `|e|` less than modulus of F -pub fn bigint_to_fe(e: &BigInt) -> F { +/// Converts an immutable reference to [BigInt] to a [BigPrimeField]. +/// * `e`: immutable reference to [BigInt] +/// +/// # Assumptions: +/// * The absolute value of `e` is less than the modulus of `F` +pub fn bigint_to_fe(e: &BigInt) -> F { #[cfg(feature = "halo2-axiom")] { let (sign, digits) = e.to_u64_digits(); @@ -141,9 +176,7 @@ pub fn bigint_to_fe(e: &BigInt) -> F { #[cfg(feature = "halo2-pse")] { let (sign, bytes) = e.to_bytes_le(); - let mut repr = F::Repr::default(); - repr.as_mut()[..bytes.len()].copy_from_slice(&bytes); - let f_abs = F::from_repr(repr).unwrap(); + let f_abs = F::from_bytes_le(&bytes); if sign == Sign::Minus { -f_abs } else { @@ -152,11 +185,18 @@ pub fn bigint_to_fe(e: &BigInt) -> F { } } -pub fn fe_to_biguint(fe: &F) -> BigUint { - BigUint::from_bytes_le(fe.to_repr().as_ref()) +/// Converts an immutable reference to an PrimeField element into a [BigUint] element. +/// * `fe`: immutable reference to PrimeField element to convert +pub fn fe_to_biguint(fe: &F) -> BigUint { + BigUint::from_bytes_le(fe.to_bytes_le().as_ref()) } -pub fn fe_to_bigint(fe: &F) -> BigInt { +/// Converts a [BigPrimeField] element into a [BigInt] element by sending `fe` in `[0, F::modulus())` to +/// ```ignore +/// fe, if fe < F::modulus() / 2 +/// fe - F::modulus(), otherwise +/// ``` +pub fn fe_to_bigint(fe: &F) -> BigInt { // TODO: `F` should just have modulus as lazy_static or something let modulus = modulus::(); let e = fe_to_biguint(fe); @@ -167,7 +207,13 @@ pub fn fe_to_bigint(fe: &F) -> BigInt { } } -pub fn decompose(e: &F, number_of_limbs: usize, bit_len: usize) -> Vec { +/// Decomposes an immutable reference to a [BigPrimeField] element into `number_of_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. +/// +/// Assumes `bit_len < 128`. +/// * `e`: immutable reference to [BigPrimeField] element to decompose +/// * `number_of_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +pub fn decompose(e: &F, number_of_limbs: usize, bit_len: usize) -> Vec { if bit_len > 64 { decompose_biguint(&fe_to_biguint(e), number_of_limbs, bit_len) } else { @@ -175,7 +221,12 @@ pub fn decompose(e: &F, number_of_limbs: usize, bit_len: usize) - } } -/// Assumes `bit_len` <= 64 +/// Decomposes an immutable reference to a [ScalarField] element into `number_of_limbs` limbs of `bit_len` bits each and returns a [Vec] of [u64] represented by those limbs. +/// +/// Assumes `bit_len` < 64 +/// * `e`: immutable reference to [ScalarField] element to decompose +/// * `number_of_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb pub fn decompose_fe_to_u64_limbs( e: &F, number_of_limbs: usize, @@ -192,29 +243,45 @@ pub fn decompose_fe_to_u64_limbs( } } -pub fn decompose_biguint(e: &BigUint, num_limbs: usize, bit_len: usize) -> Vec { - debug_assert!(bit_len > 64 && bit_len <= 128); +/// Decomposes an immutable reference to a [BigUint] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. +/// +/// Assumes 64 <= `bit_len` < 128. +/// * `e`: immutable reference to [BigInt] to decompose +/// * `num_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +/// +/// Truncates to `num_limbs` limbs if `e` is too large. +pub fn decompose_biguint( + e: &BigUint, + num_limbs: usize, + bit_len: usize, +) -> Vec { + // bit_len must be between 64` and 128 + debug_assert!((64..128).contains(&bit_len)); let mut e = e.iter_u64_digits(); + // Grab first 128-bit limb from iterator let mut limb0 = e.next().unwrap_or(0) as u128; let mut rem = bit_len - 64; let mut u64_digit = e.next().unwrap_or(0); - limb0 |= ((u64_digit & ((1 << rem) - 1)) as u128) << 64; + // Extract second limb (bit length 64) from e + limb0 |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << 64u32; u64_digit >>= rem; rem = 64 - rem; + // Convert `limb0` into field element `F` and create an iterator by chaining `limb0` with the computing the remaining limbs core::iter::once(F::from_u128(limb0)) .chain((1..num_limbs).map(|_| { - let mut limb: u128 = u64_digit.into(); + let mut limb = u64_digit as u128; let mut bits = rem; u64_digit = e.next().unwrap_or(0); - if bit_len - bits >= 64 { + if bit_len >= 64 + bits { limb |= (u64_digit as u128) << bits; u64_digit = e.next().unwrap_or(0); bits += 64; } rem = bit_len - bits; - limb |= ((u64_digit & ((1 << rem) - 1)) as u128) << bits; + limb |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << bits; u64_digit >>= rem; rem = 64 - rem; F::from_u128(limb) @@ -222,7 +289,13 @@ pub fn decompose_biguint(e: &BigUint, num_limbs: usize, bit_len: .collect() } -pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: usize) -> Vec { +/// Decomposes an immutable reference to a [BigInt] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. +/// +/// Assumes `bit_len < 128`. +/// * `e`: immutable reference to `BigInt` to decompose +/// * `num_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: usize) -> Vec { if e.is_negative() { decompose_biguint::(e.magnitude(), num_limbs, bit_len).into_iter().map(|x| -x).collect() } else { @@ -230,7 +303,13 @@ pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: us } } -pub fn decompose_bigint_option( +/// Decomposes an immutable reference to a [BigInt] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs wrapped in [Value]. +/// +/// Assumes `bit_len` < 128. +/// * `e`: immutable reference to `BigInt` to decompose +/// * `num_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +pub fn decompose_bigint_option( value: Value<&BigInt>, number_of_limbs: usize, bit_len: usize, @@ -238,6 +317,9 @@ pub fn decompose_bigint_option( value.map(|e| decompose_bigint(e, number_of_limbs, bit_len)).transpose_vec(number_of_limbs) } +/// Wraps the internal value of `value` in an [Option]. +/// If the value is [None], then the function returns [None]. +/// * `value`: Value to convert. pub fn value_to_option(value: Value) -> Option { let mut v = None; value.map(|val| { @@ -246,28 +328,22 @@ pub fn value_to_option(value: Value) -> Option { v } -/// Compute the represented value by a vector of values and a bit length. +/// Computes the value of an integer by passing as `input` a [Vec] of its limb values and the `bit_len` (bit length) used. /// -/// This function is used to compute the value of an integer -/// passing as input its limb values and the bit length used. -/// Returns the sum of all limbs scaled by 2^(bit_len * i) +/// Returns the sum of all limbs scaled by 2(bit_len * i) where i is the index of the limb. +/// * `input`: Limb values of the integer. +/// * `bit_len`: Length of limb in bits pub fn compose(input: Vec, bit_len: usize) -> BigUint { input.iter().rev().fold(BigUint::zero(), |acc, val| (acc << bit_len) + val) } -#[cfg(test)] -#[test] -fn test_signed_roundtrip() { - use crate::halo2_proofs::halo2curves::bn256::Fr; - assert_eq!(fe_to_bigint(&bigint_to_fe::(&-BigInt::one())), -BigInt::one()); -} - #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom::halo2curves::CurveAffineExt; +/// Helper trait #[cfg(feature = "halo2-pse")] pub trait CurveAffineExt: CurveAffine { - /// Unlike the `Coordinates` trait, this just returns the raw affine coordinantes without checking `is_on_curve` + /// Unlike the `Coordinates` trait, this just returns the raw affine (X, Y) coordinantes without checking `is_on_curve` fn into_coordinates(self) -> (Self::Base, Self::Base) { let coordinates = self.coordinates().unwrap(); (*coordinates.x(), *coordinates.y()) @@ -276,6 +352,68 @@ pub trait CurveAffineExt: CurveAffine { #[cfg(feature = "halo2-pse")] impl CurveAffineExt for C {} +mod scalar_field_impls { + use super::{decompose_u64_digits_to_limbs, ScalarField}; + use crate::halo2_proofs::halo2curves::{ + bn256::{Fq as bn254Fq, Fr as bn254Fr}, + secp256k1::{Fp as secpFp, Fq as secpFq}, + }; + #[cfg(feature = "halo2-pse")] + use ff::PrimeField; + + /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro + /// to implement the trait for each field. + #[cfg(feature = "halo2-axiom")] + #[macro_export] + macro_rules! impl_scalar_field { + ($field:ident) => { + impl ScalarField for $field { + #[inline(always)] + fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { + // Basically same as `to_repr` but does not go further into bytes + let tmp: [u64; 4] = self.into(); + decompose_u64_digits_to_limbs(tmp, num_limbs, bit_len) + } + + #[inline(always)] + fn to_bytes_le(&self) -> Vec { + let tmp: [u64; 4] = (*self).into(); + tmp.iter().flat_map(|x| x.to_le_bytes()).collect() + } + } + }; + } + + /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro + /// to implement the trait for each field. + #[cfg(feature = "halo2-pse")] + #[macro_export] + macro_rules! impl_scalar_field { + ($field:ident) => { + impl ScalarField for $field { + #[inline(always)] + fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { + let bytes = self.to_repr(); + let digits = (0..4) + .map(|i| u64::from_le_bytes(bytes[i * 8..(i + 1) * 8].try_into().unwrap())); + decompose_u64_digits_to_limbs(digits, num_limbs, bit_len) + } + + #[inline(always)] + fn to_bytes_le(&self) -> Vec { + self.to_repr().to_vec() + } + } + }; + } + + impl_scalar_field!(bn254Fr); + impl_scalar_field!(bn254Fq); + impl_scalar_field!(secpFp); + impl_scalar_field!(secpFq); +} + +/// Module for reading parameters for Halo2 proving system from the file system. pub mod fs { use std::{ env::var, @@ -288,10 +426,15 @@ pub mod fs { bn256::{Bn256, G1Affine}, CurveAffine, }, - poly::{commitment::{Params, ParamsProver}, kzg::commitment::ParamsKZG}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::commitment::ParamsKZG, + }, }; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + /// Reads the srs from a file found in `./params/kzg_bn254_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified. + /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) pub fn read_params(k: u32) -> ParamsKZG { let dir = var("PARAMS_DIR").unwrap_or_else(|_| "./params".to_string()); ParamsKZG::::read(&mut BufReader::new( @@ -301,6 +444,9 @@ pub mod fs { .unwrap() } + /// Attempts to read the srs from a file found in `./params/kzg_bn254_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified, creates a file it if it does not exist. + /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) + /// * `setup`: a function that creates the srs pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( k: u32, setup: impl Fn(u32) -> P, @@ -325,9 +471,89 @@ pub mod fs { } } + /// Generates the SRS for the KZG scheme and writes it to a file found in "./params/kzg_bn2_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified, creates a file it if it does not exist" + /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) pub fn gen_srs(k: u32) -> ParamsKZG { read_or_create_srs::(k, |k| { ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())) }) } } + +#[cfg(test)] +mod tests { + use crate::halo2_proofs::halo2curves::bn256::Fr; + use num_bigint::RandomBits; + use rand::{rngs::OsRng, Rng}; + use std::ops::Shl; + + use super::*; + + #[test] + fn test_signed_roundtrip() { + use crate::halo2_proofs::halo2curves::bn256::Fr; + assert_eq!(fe_to_bigint(&bigint_to_fe::(&-BigInt::one())), -BigInt::one()); + } + + #[test] + fn test_decompose_biguint() { + let mut rng = OsRng; + const MAX_LIMBS: u64 = 5; + for bit_len in 64..128usize { + for num_limbs in 1..=MAX_LIMBS { + for _ in 0..10_000usize { + let mut e: BigUint = rng.sample(RandomBits::new(num_limbs * bit_len as u64)); + let limbs = decompose_biguint::(&e, num_limbs as usize, bit_len); + + let limbs2 = { + let mut limbs = vec![]; + let mask = BigUint::one().shl(bit_len) - 1usize; + for _ in 0..num_limbs { + let limb = &e & &mask; + let mut bytes_le = limb.to_bytes_le(); + bytes_le.resize(32, 0u8); + limbs.push(Fr::from_bytes(&bytes_le.try_into().unwrap()).unwrap()); + e >>= bit_len; + } + limbs + }; + assert_eq!(limbs, limbs2); + } + } + } + } + + #[test] + fn test_decompose_u64_digits_to_limbs() { + let mut rng = OsRng; + const MAX_LIMBS: u64 = 5; + for bit_len in 0..64usize { + for num_limbs in 1..=MAX_LIMBS { + for _ in 0..10_000usize { + let mut e: BigUint = rng.sample(RandomBits::new(num_limbs * bit_len as u64)); + let limbs = decompose_u64_digits_to_limbs( + e.to_u64_digits(), + num_limbs as usize, + bit_len, + ); + let limbs2 = { + let mut limbs = vec![]; + let mask = BigUint::one().shl(bit_len) - 1usize; + for _ in 0..num_limbs { + let limb = &e & &mask; + limbs.push(u64::try_from(limb).unwrap()); + e >>= bit_len; + } + limbs + }; + assert_eq!(limbs, limbs2); + } + } + } + } + + #[test] + fn test_log2_ceil_zero() { + assert_eq!(log2_ceil(0), 0); + } +} diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index a142200d..2b03e1cb 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-ecc" -version = "0.2.2" +version = "0.3.0" edition = "2021" [dependencies] @@ -13,6 +13,8 @@ rand = "0.8" rand_chacha = "0.3.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +rayon = "1.6.1" +test-case = "3.1.0" # arithmetic ff = "0.12" @@ -25,6 +27,7 @@ ark-std = { version = "0.3.0", features = ["print-trace"] } pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" +halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } [features] default = ["jemallocator", "halo2-axiom", "display"] diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index 0bdf7e12..b4f3df25 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,166 +1,93 @@ -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -#[allow(unused_imports)] -use ff::PrimeField as _; -use halo2_base::utils::modulus; -use pprof::criterion::{Output, PProfProfiler}; - use ark_std::{end_timer, start_timer}; -use halo2_base::SKIP_FIRST_PASS; -use rand_core::OsRng; -use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; - +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::{Layouter, SimpleFloorPlanner, Value}, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, poly::kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, multiopen::ProverSHPLONK, }, - transcript::TranscriptWriterBuffer, - transcript::{Blake2bWrite, Challenge255}, -}; -use halo2_base::{gates::GateInstructions, utils::PrimeField}; -use halo2_ecc::{ - ecc::EccChip, - fields::fp::{FpConfig, FpStrategy}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; +use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use rand::rngs::OsRng; -type FpChip = FpConfig; +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; -#[derive(Serialize, Deserialize, Debug)] +use pprof::criterion::{Output, PProfProfiler}; +// Thanks to the example provided by @jebbow in his article +// https://www.jibbow.com/posts/criterion-flamegraphs/ + +#[derive(Clone, Copy, Debug)] struct MSMCircuitParams { - strategy: FpStrategy, degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, lookup_bits: usize, limb_bits: usize, num_limbs: usize, batch_size: usize, - radix: usize, - clump_factor: usize, } -const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { - strategy: FpStrategy::Simple, - degree: 20, - num_advice: 10, - num_lookup_advice: 1, - num_fixed: 1, - lookup_bits: 19, - limb_bits: 88, - num_limbs: 3, - batch_size: 100, - radix: 0, - clump_factor: 4, -}; +const BEST_100_CONFIG: MSMCircuitParams = + MSMCircuitParams { degree: 20, lookup_bits: 19, limb_bits: 88, num_limbs: 3, batch_size: 100 }; const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - clump_factor: usize, -} - -impl MSMConfig { - #[allow(clippy::too_many_arguments)] - pub fn configure(meta: &mut ConstraintSystem, params: MSMCircuitParams) -> Self { - let fp_chip = FpChip::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - modulus::(), - 0, - params.degree as usize, - ); - MSMConfig { fp_chip, clump_factor: params.clump_factor } - } -} - -struct MSMCircuit { +fn fixed_base_msm_bench( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, bases: Vec, - scalars: Vec>, - _marker: PhantomData, + scalars: Vec, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let scalars_assigned = scalars + .iter() + .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) + .collect::>(); + + ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); } -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: self.bases.clone(), - scalars: vec![None; self.scalars.len()], - _marker: PhantomData, +fn fixed_base_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + bases: Vec, + scalars: Vec, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + fixed_base_msm_bench(&mut builder, params, bases, scalars); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let params = TEST_CONFIG; - - MSMConfig::::configure(meta, params) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "fixed base msm", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let witness_time = start_timer!(|| "Witness generation"); - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config - .fp_chip - .range - .gate - .assign_witnesses(ctx, vec![scalar.map_or(Value::unknown(), Value::known)]); - scalars_assigned.push(assignment); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - - let _msm = ecc_chip.fixed_base_msm::( - ctx, - &self.bases, - &scalars_assigned, - Fr::NUM_BITS as usize, - 0, - config.clump_factor, - ); - - config.fp_chip.finalize(ctx); - end_timer!(witness_time); - - Ok(()) - }, - ) - } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } fn bench(c: &mut Criterion) { @@ -168,39 +95,36 @@ fn bench(c: &mut Criterion) { let k = config.degree; let mut rng = OsRng; - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = G1Affine::random(&mut rng); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - let circuit = MSMCircuit:: { bases, scalars, _marker: PhantomData }; + let circuit = fixed_base_msm_circuit( + config, + CircuitBuilderStage::Keygen, + vec![G1Affine::generator(); config.batch_size], + vec![Fr::zero(); config.batch_size], + None, + ); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.0.break_points.take(); + drop(circuit); + let (bases, scalars): (Vec<_>, Vec<_>) = + (0..config.batch_size).map(|_| (G1Affine::random(&mut rng), Fr::random(&mut rng))).unzip(); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); group.bench_with_input( BenchmarkId::new("fixed base msm", k), - &(¶ms, &pk), - |b, &(params, pk)| { + &(¶ms, &pk, &bases, &scalars), + |b, &(params, pk, bases, scalars)| { b.iter(|| { - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = G1Affine::random(&mut rng); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = MSMCircuit:: { bases, scalars, _marker: PhantomData }; + let circuit = fixed_base_msm_circuit( + config, + CircuitBuilderStage::Prover, + bases.clone(), + scalars.clone(), + Some(break_points.clone()), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index d49162e0..48351c45 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -1,25 +1,28 @@ -use std::marker::PhantomData; - -use halo2_base::halo2_proofs::{ - arithmetic::Field, - circuit::*, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, - plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, +use ark_std::{end_timer, start_timer}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + halo2_proofs::{ + arithmetic::Field, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::*, + poly::kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::ProverSHPLONK, + }, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + }, + Context, }; +use halo2_ecc::fields::fp::FpChip; +use halo2_ecc::fields::{FieldChip, PrimeField}; use rand::rngs::OsRng; -use halo2_base::{ - utils::{fe_to_bigint, modulus, PrimeField}, - SKIP_FIRST_PASS, -}; -use halo2_ecc::fields::fp::{FpConfig, FpStrategy}; -use halo2_ecc::fields::FieldChip; - use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; @@ -29,106 +32,88 @@ use pprof::criterion::{Output, PProfProfiler}; const K: u32 = 19; -#[derive(Default)] -struct MyCircuit { - a: Value, - b: Value, - _marker: PhantomData, -} - -const NUM_ADVICE: usize = 2; -const NUM_FIXED: usize = 1; - -impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() +fn fp_mul_bench( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + _a: Fq, + _b: Fq, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let chip = FpChip::::new(&range, limb_bits, num_limbs); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + for _ in 0..2857 { + chip.mul(ctx, &a, &b); } +} - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - K as usize - 1, - 88, - 3, - modulus::(), - 0, - K as usize, - ) - } - - fn synthesize(&self, chip: Self::Config, mut layouter: impl Layouter) -> Result<(), Error> { - chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "fp", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = chip.new_context(region); - let ctx = &mut aux; - - let a_assigned = chip.load_private(ctx, self.a.as_ref().map(fe_to_bigint)); - let b_assigned = chip.load_private(ctx, self.b.as_ref().map(fe_to_bigint)); - - for _ in 0..2857 { - chip.mul(ctx, &a_assigned, &b_assigned); - } - - // IMPORTANT: this copies advice cells to enable lookup - // This is not optional. - chip.finalize(ctx); - - Ok(()) - }, - ) - } +fn fp_mul_circuit( + stage: CircuitBuilderStage, + a: Fq, + b: Fq, + break_points: Option, +) -> RangeCircuitBuilder { + let k = K as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + fp_mul_bench(builder.main(0), k - 1, 88, 3, a, b); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } fn bench(c: &mut Criterion) { - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); - - let circuit = MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; + let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.0.break_points.take(); + let a = Fq::random(OsRng); + let b = Fq::random(OsRng); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); - group.bench_with_input(BenchmarkId::new("fp mul", K), &(¶ms, &pk), |b, &(params, pk)| { - b.iter(|| { - let rng = OsRng; - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); - - let circuit = - MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], rng, &mut transcript) - .expect("prover should not fail"); - }) - }); + group.bench_with_input( + BenchmarkId::new("fp mul", K), + &(¶ms, &pk, a, b), + |bencher, &(params, pk, a, b)| { + bencher.iter(|| { + let circuit = + fp_mul_circuit(CircuitBuilderStage::Prover, a, b, Some(break_points.clone())); + + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + }) + }, + ); group.finish() } diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 22be806e..3a98ee38 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,224 +1,109 @@ -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -use halo2_base::utils::modulus; -use pprof::criterion::{Output, PProfProfiler}; - use ark_std::{end_timer, start_timer}; -use halo2_base::SKIP_FIRST_PASS; -use rand_core::OsRng; -use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; - +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::{Layouter, SimpleFloorPlanner, Value}, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, poly::kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, multiopen::ProverSHPLONK, }, - transcript::TranscriptWriterBuffer, - transcript::{Blake2bWrite, Challenge255}, -}; -use halo2_base::{ - gates::GateInstructions, - utils::{biguint_to_fe, fe_to_biguint, PrimeField}, - QuantumCell::Witness, -}; -use halo2_ecc::{ - ecc::EccChip, - fields::fp::{FpConfig, FpStrategy}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use num_bigint::BigUint; +use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use rand::rngs::OsRng; -type FpChip = FpConfig; +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; -#[derive(Serialize, Deserialize, Debug)] +use pprof::criterion::{Output, PProfProfiler}; +// Thanks to the example provided by @jebbow in his article +// https://www.jibbow.com/posts/criterion-flamegraphs/ + +#[derive(Clone, Copy, Debug)] struct MSMCircuitParams { - strategy: FpStrategy, degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, lookup_bits: usize, limb_bits: usize, num_limbs: usize, batch_size: usize, - window_bits: usize, + clump_factor: usize, } const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { - strategy: FpStrategy::Simple, degree: 19, - num_advice: 20, - num_lookup_advice: 3, - num_fixed: 1, lookup_bits: 18, limb_bits: 90, num_limbs: 3, batch_size: 100, - window_bits: 4, + clump_factor: 4, }; - const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - batch_size: usize, - window_bits: usize, -} - -impl MSMConfig { - #[allow(clippy::too_many_arguments)] - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - window_bits: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - MSMConfig { fp_chip, batch_size, window_bits } - } -} - -struct MSMCircuit { - bases: Vec>, - scalars: Vec>, - batch_size: usize, - _marker: PhantomData, +fn msm_bench( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + params.clump_factor, + 0, + ); } -impl Default for MSMCircuit { - fn default() -> Self { - Self { - bases: vec![None; 10], - scalars: vec![None; 10], - batch_size: 10, - _marker: PhantomData, +fn msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + bases: Vec, + scalars: Vec, + break_points: Option, +) -> RangeCircuitBuilder { + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + msm_bench(&mut builder, params, bases, scalars); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) } - } -} - -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: vec![None; self.batch_size], - scalars: vec![None; self.batch_size], - batch_size: self.batch_size, - _marker: PhantomData, + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let params: MSMCircuitParams = TEST_CONFIG; - - MSMConfig::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - modulus::(), - params.batch_size, - params.window_bits, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.scalars.len()); - assert_eq!(config.batch_size, self.bases.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "MSM", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let witness_time = start_timer!(|| "Witness Generation"); - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config.fp_chip.range.gate.assign_region_last( - ctx, - vec![Witness(scalar.map_or(Value::unknown(), Value::known))], - vec![], - ); - scalars_assigned.push(vec![assignment]); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - let mut bases_assigned = Vec::new(); - for base in &self.bases { - let base_assigned = ecc_chip.load_private( - ctx, - ( - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.x)))) - .unwrap_or(Value::unknown()), - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.y)))) - .unwrap_or(Value::unknown()), - ), - ); - bases_assigned.push(base_assigned); - } - - let _msm = ecc_chip.variable_base_msm::( - ctx, - &bases_assigned, - &scalars_assigned, - 254, - config.window_bits, - ); - - config.fp_chip.finalize(ctx); - end_timer!(witness_time); - - Ok(()) - }, - ) - } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } fn bench(c: &mut Criterion) { @@ -226,55 +111,50 @@ fn bench(c: &mut Criterion) { let k = config.degree; let mut rng = OsRng; - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - let circuit = - MSMCircuit:: { bases, scalars, batch_size: config.batch_size, _marker: PhantomData }; + let circuit = msm_circuit( + config, + CircuitBuilderStage::Keygen, + vec![G1Affine::generator(); config.batch_size], + vec![Fr::one(); config.batch_size], + None, + ); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.0.break_points.take(); + drop(circuit); + let (bases, scalars): (Vec<_>, Vec<_>) = + (0..config.batch_size).map(|_| (G1Affine::random(&mut rng), Fr::random(&mut rng))).unzip(); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); - group.bench_with_input(BenchmarkId::new("msm", k), &(¶ms, &pk), |b, &(params, pk)| { - b.iter(|| { - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = MSMCircuit:: { - bases, - scalars, - batch_size: config.batch_size, - _marker: PhantomData, - }; + group.bench_with_input( + BenchmarkId::new("msm", k), + &(¶ms, &pk, &bases, &scalars), + |b, &(params, pk, bases, scalars)| { + b.iter(|| { + let circuit = msm_circuit( + config, + CircuitBuilderStage::Prover, + bases.clone(), + scalars.clone(), + Some(break_points.clone()), + ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); - }) - }); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) + .expect("prover should not fail"); + }) + }, + ); group.finish() } diff --git a/halo2-ecc/src/bn254/configs/bench_ec_add.config b/halo2-ecc/configs/bn254/bench_ec_add.config similarity index 100% rename from halo2-ecc/src/bn254/configs/bench_ec_add.config rename to halo2-ecc/configs/bn254/bench_ec_add.config diff --git a/halo2-ecc/src/bn254/configs/bench_fixed_msm.config b/halo2-ecc/configs/bn254/bench_fixed_msm.config similarity index 100% rename from halo2-ecc/src/bn254/configs/bench_fixed_msm.config rename to halo2-ecc/configs/bn254/bench_fixed_msm.config diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config new file mode 100644 index 00000000..61db5d6d --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":5,"num_fixed":4,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":2,"num_fixed":2,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/bench_msm.config b/halo2-ecc/configs/bn254/bench_msm.config similarity index 92% rename from halo2-ecc/src/bn254/configs/bench_msm.config rename to halo2-ecc/configs/bn254/bench_msm.config index 1d1f769c..d665c0a8 100644 --- a/halo2-ecc/src/bn254/configs/bench_msm.config +++ b/halo2-ecc/configs/bn254/bench_msm.config @@ -1,3 +1,4 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} {"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} {"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":6,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} {"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config new file mode 100644 index 00000000..bd4c4318 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/bench_pairing.config b/halo2-ecc/configs/bn254/bench_pairing.config similarity index 100% rename from halo2-ecc/src/bn254/configs/bench_pairing.config rename to halo2-ecc/configs/bn254/bench_pairing.config diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config new file mode 100644 index 00000000..d76ebad1 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":15,"num_advice":105,"num_lookup_advice":14,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":18,"num_advice":13,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":20,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/ec_add_circuit.config b/halo2-ecc/configs/bn254/ec_add_circuit.config similarity index 100% rename from halo2-ecc/src/bn254/configs/ec_add_circuit.config rename to halo2-ecc/configs/bn254/ec_add_circuit.config diff --git a/halo2-ecc/src/bn254/configs/fixed_msm_circuit.config b/halo2-ecc/configs/bn254/fixed_msm_circuit.config similarity index 100% rename from halo2-ecc/src/bn254/configs/fixed_msm_circuit.config rename to halo2-ecc/configs/bn254/fixed_msm_circuit.config diff --git a/halo2-ecc/configs/bn254/msm_circuit.config b/halo2-ecc/configs/bn254/msm_circuit.config new file mode 100644 index 00000000..f66f6077 --- /dev/null +++ b/halo2-ecc/configs/bn254/msm_circuit.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/pairing_circuit.config b/halo2-ecc/configs/bn254/pairing_circuit.config similarity index 100% rename from halo2-ecc/src/bn254/configs/pairing_circuit.config rename to halo2-ecc/configs/bn254/pairing_circuit.config diff --git a/halo2-ecc/src/secp256k1/configs/bench_ecdsa.config b/halo2-ecc/configs/secp256k1/bench_ecdsa.config similarity index 100% rename from halo2-ecc/src/secp256k1/configs/bench_ecdsa.config rename to halo2-ecc/configs/secp256k1/bench_ecdsa.config diff --git a/halo2-ecc/src/secp256k1/configs/ecdsa_circuit.config b/halo2-ecc/configs/secp256k1/ecdsa_circuit.config similarity index 100% rename from halo2-ecc/src/secp256k1/configs/ecdsa_circuit.config rename to halo2-ecc/configs/secp256k1/ecdsa_circuit.config diff --git a/halo2-ecc/src/bigint/add_no_carry.rs b/halo2-ecc/src/bigint/add_no_carry.rs index 8cc687d4..19feb35d 100644 --- a/halo2-ecc/src/bigint/add_no_carry.rs +++ b/halo2-ecc/src/bigint/add_no_carry.rs @@ -1,34 +1,37 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; +use itertools::Itertools; use std::cmp::max; -pub fn assign<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); - + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| gate.add(ctx, Existing(a_limb), Existing(b_limb))) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.add(ctx, a_limb, b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) } -pub fn crt<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +// pass by reference to avoid cloning the BigInt in CRTInteger, unclear if this is optimal +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> CRTInteger<'v, F> { - assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.add(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a + b); - CRTInteger::construct(out_trunc, out_native, out_val) + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, +) -> CRTInteger { + let out_trunc = assign(gate, ctx, a.truncation, b.truncation); + let out_native = gate.add(ctx, a.native, b.native); + let out_val = a.value + b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/big_is_equal.rs b/halo2-ecc/src/bigint/big_is_equal.rs index f963937f..78626b22 100644 --- a/halo2-ecc/src/bigint/big_is_equal.rs +++ b/halo2-ecc/src/bigint/big_is_equal.rs @@ -1,47 +1,29 @@ -use super::{CRTInteger, OverflowInteger}; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use super::ProperUint; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; -// given OverflowInteger's `a` and `b` of the same shape, -// returns whether `a == b` -pub fn assign<'v, F: PrimeField>( +/// Given [`ProperUint`]s `a` and `b` with the same number of limbs, +/// returns whether `a == b`. +/// +/// # Assumptions: +/// * `a, b` have the same number of limbs. +/// * The number of limbs is nonzero. +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, -) -> AssignedValue<'v, F> { - let k = a.limbs.len(); - assert_eq!(k, b.limbs.len()); - assert_ne!(k, 0); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, +) -> AssignedValue { + let a = a.into(); + let b = b.into(); + debug_assert!(!a.0.is_empty()); - let mut a_limbs = a.limbs.iter(); - let mut b_limbs = b.limbs.iter(); - let mut partial = - gate.is_equal(ctx, Existing(a_limbs.next().unwrap()), Existing(b_limbs.next().unwrap())); - for (a_limb, b_limb) in a_limbs.zip(b_limbs) { - let eq_limb = gate.is_equal(ctx, Existing(a_limb), Existing(b_limb)); - partial = gate.and(ctx, Existing(&eq_limb), Existing(&partial)); + let mut a_limbs = a.0.into_iter(); + let mut b_limbs = b.0.into_iter(); + let mut partial = gate.is_equal(ctx, a_limbs.next().unwrap(), b_limbs.next().unwrap()); + for (a_limb, b_limb) in a_limbs.zip_eq(b_limbs) { + let eq_limb = gate.is_equal(ctx, a_limb, b_limb); + partial = gate.and(ctx, eq_limb, partial); } partial } - -pub fn wrapper<'v, F: PrimeField>( - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> AssignedValue<'v, F> { - assign(gate, ctx, &a.truncation, &b.truncation) -} - -pub fn crt<'v, F: PrimeField>( - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> AssignedValue<'v, F> { - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.is_equal(ctx, Existing(&a.native), Existing(&b.native)); - gate.and(ctx, Existing(&out_trunc), Existing(&out_native)) -} diff --git a/halo2-ecc/src/bigint/big_is_zero.rs b/halo2-ecc/src/bigint/big_is_zero.rs index 4ab84fa3..aa67c842 100644 --- a/halo2-ecc/src/bigint/big_is_zero.rs +++ b/halo2-ecc/src/bigint/big_is_zero.rs @@ -1,46 +1,53 @@ -use super::{CRTInteger, OverflowInteger}; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use super::{OverflowInteger, ProperCrtUint, ProperUint}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; -/// assume you know that the limbs of `a` are all in [0, 2^{a.max_limb_bits}) -pub fn positive<'v, F: PrimeField>( +/// # Assumptions +/// * `a` has nonzero number of limbs +/// * The limbs of `a` are all in [0, 2a.max_limb_bits) +/// * a.limbs.len() * 2a.max_limb_bits ` is less than modulus of `F` +pub fn positive( gate: &impl GateInstructions, - ctx: &mut Context<'v, F>, - a: &OverflowInteger<'v, F>, -) -> AssignedValue<'v, F> { + ctx: &mut Context, + a: OverflowInteger, +) -> AssignedValue { let k = a.limbs.len(); assert_ne!(k, 0); - debug_assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY); + assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY); - let sum = gate.sum(ctx, a.limbs.iter().map(Existing)); - gate.is_zero(ctx, &sum) + let sum = gate.sum(ctx, a.limbs); + gate.is_zero(ctx, sum) } -// given OverflowInteger `a`, returns whether `a == 0` -pub fn assign<'v, F: PrimeField>( +/// Given ProperUint `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. +/// +/// It is almost always more efficient to use [`positive`] instead. +/// +/// # Assumptions +/// * `a` has nonzero number of limbs +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, -) -> AssignedValue<'v, F> { - let k = a.limbs.len(); - assert_ne!(k, 0); + ctx: &mut Context, + a: ProperUint, +) -> AssignedValue { + assert!(!a.0.is_empty()); - let mut a_limbs = a.limbs.iter(); + let mut a_limbs = a.0.into_iter(); let mut partial = gate.is_zero(ctx, a_limbs.next().unwrap()); for a_limb in a_limbs { let limb_is_zero = gate.is_zero(ctx, a_limb); - partial = gate.and(ctx, Existing(&limb_is_zero), Existing(&partial)); + partial = gate.and(ctx, limb_is_zero, partial); } partial } -pub fn crt<'v, F: PrimeField>( +/// Returns 0 or 1. Returns 1 iff the limbs of `a` are identically zero. +/// This just calls [`assign`] on the limbs. +/// +/// It is almost always more efficient to use [`positive`] instead. +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, -) -> AssignedValue<'v, F> { - let out_trunc = assign::(gate, ctx, &a.truncation); - let out_native = gate.is_zero(ctx, &a.native); - gate.and(ctx, Existing(&out_trunc), Existing(&out_native)) + ctx: &mut Context, + a: ProperCrtUint, +) -> AssignedValue { + assign(gate, ctx, ProperUint(a.0.truncation.limbs)) } diff --git a/halo2-ecc/src/bigint/big_less_than.rs b/halo2-ecc/src/bigint/big_less_than.rs index 52528870..01fe1eae 100644 --- a/halo2-ecc/src/bigint/big_less_than.rs +++ b/halo2-ecc/src/bigint/big_less_than.rs @@ -1,17 +1,17 @@ -use super::OverflowInteger; -use halo2_base::{gates::RangeInstructions, utils::PrimeField, AssignedValue, Context}; +use super::ProperUint; +use halo2_base::{gates::RangeInstructions, utils::ScalarField, AssignedValue, Context}; // given OverflowInteger's `a` and `b` of the same shape, // returns whether `a < b` -pub fn assign<'a, F: PrimeField>( +pub fn assign( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &OverflowInteger<'a, F>, - b: &OverflowInteger<'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, limb_bits: usize, limb_base: F, -) -> AssignedValue<'a, F> { +) -> AssignedValue { // a < b iff a - b has underflow - let (_, underflow) = super::sub::assign::(range, ctx, a, b, limb_bits, limb_base); + let (_, underflow) = super::sub::assign(range, ctx, a, b, limb_bits, limb_base); underflow } diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index 111f31d5..a78fd32b 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -1,15 +1,16 @@ -use super::{check_carry_to_zero, CRTInteger, OverflowInteger}; -use crate::halo2_proofs::circuit::Value; +use std::{cmp::max, iter}; + use halo2_base::{ gates::{range::RangeStrategy, GateInstructions, RangeInstructions}, - utils::{biguint_to_fe, decompose_bigint_option, value_to_option, PrimeField}, + utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; -use num_bigint::{BigInt, BigUint}; +use num_bigint::BigInt; use num_integer::Integer; use num_traits::{One, Signed}; -use std::{cmp::max, iter}; + +use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint}; // Input `a` is `CRTInteger` with `a.truncation` of length `k` with "signed" limbs // Output is `out = a (mod modulus)` as CRTInteger with @@ -19,12 +20,18 @@ use std::{cmp::max, iter}; // `out.native = (a (mod modulus)) % (native_modulus::)` // We constrain `a = out + modulus * quotient` and range check `out` and `quotient` // -// Assumption: the leading two bits (in big endian) are 1, and `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}` (A weaker assumption is also enough, but this is good enough for forseeable use cases) -pub fn crt<'a, F: PrimeField>( +// Assumption: the leading two bits (in big endian) are 1, +/// # Assumptions +/// * abs(a) <= 2n * k - 1 + F::NUM_BITS - 2 (A weaker assumption is also enough, but this is good enough for forseeable use cases) +/// * `native_modulus::` requires *exactly* `k = a.limbs.len()` limbs to represent + +// This is currently optimized for limbs greater than 64 bits, so we need `F` to be a `BigPrimeField` +// In the future we'll need a slightly different implementation for limbs that fit in 32 or 64 bits (e.g., `F` is Goldilocks) +pub fn crt( range: &impl RangeInstructions, // chip: &BigIntConfig, - ctx: &mut Context<'a, F>, - a: &CRTInteger<'a, F>, + ctx: &mut Context, + a: CRTInteger, k_bits: usize, // = a.len().bits() modulus: &BigInt, mod_vec: &[F], @@ -32,22 +39,12 @@ pub fn crt<'a, F: PrimeField>( limb_bits: usize, limb_bases: &[F], limb_base_big: &BigInt, -) -> CRTInteger<'a, F> { +) -> ProperCrtUint { let n = limb_bits; let k = a.truncation.limbs.len(); let trunc_len = n * k; - #[cfg(feature = "display")] - { - let key = format!("carry_mod(crt) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - - // safety check: - a.value - .as_ref() - .map(|a| assert!(a.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2)); - } + debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2); // in order for CRT method to work, we need `abs(out + modulus * quotient - a) < 2^{trunc_len - 1} * native_modulus::` // this is ensured if `0 <= out < 2^{n*k}` and @@ -55,7 +52,7 @@ pub fn crt<'a, F: PrimeField>( // which is ensured if // `abs(modulus * quotient) < 2^{trunc_len - 1 + F::NUM_BITS - 1} <= 2^{trunc_len - 1} * native_modulus:: - abs(a)` given our assumption `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}` let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize); - assert!(quot_max_bits < trunc_len); + debug_assert!(quot_max_bits < trunc_len); // Let n' <= quot_max_bits - n(k-1) - 1 // If quot[i] <= 2^n for i < k - 1 and quot[k-1] <= 2^{n'} then // quot < 2^{n(k-1)+1} + 2^{n' + n(k-1)} = (2+2^{n'}) 2^{n(k-1)} < 2^{n'+1} * 2^{n(k-1)} <= 2^{quot_max_bits - n(k-1)} * 2^{n(k-1)} @@ -69,26 +66,17 @@ pub fn crt<'a, F: PrimeField>( // we need to find `out_vec` as a proper BigInt with k limbs // we need to find `quot_vec` as a proper BigInt with k limbs - // we need to constrain that `sum_i out_vec[i] * 2^{n*i} = out_native` in `F` - // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` - let (out_val, out_vec, quot_vec) = if let Some(a_big) = value_to_option(a.value.as_ref()) { - let (quot_val, out_val) = a_big.div_mod_floor(modulus); + let (quot_val, out_val) = a.value.div_mod_floor(modulus); - debug_assert!(out_val < (BigInt::one() << (n * k))); - debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); + debug_assert!(out_val < (BigInt::one() << (n * k))); + debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); - ( - Value::known(out_val.clone()), - // decompose_bigint_option just throws away signed limbs in index >= k - decompose_bigint_option::(Value::known(&out_val), k, n), - decompose_bigint_option::(Value::known("_val), k, n), - ) - } else { - (Value::unknown(), vec![Value::unknown(); k], vec![Value::unknown(); k]) - }; + // decompose_bigint just throws away signed limbs in index >= k + let out_vec = decompose_bigint::(&out_val, k, n); + let quot_vec = decompose_bigint::("_val, k, n); - // let out_native = out_val.as_ref().map(|a| bigint_to_fe::(a)); - // let quot_native = quot_val.map(|a| bigint_to_fe::(&a)); + // we need to constrain that `sum_i out_vec[i] * 2^{n*i} = out_native` in `F` + // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` // assert!(modulus < &(BigUint::one() << (n * k))); assert_eq!(mod_vec.len(), k); @@ -107,76 +95,46 @@ pub fn crt<'a, F: PrimeField>( let mut quot_assigned: Vec> = Vec::with_capacity(k); let mut out_assigned: Vec> = Vec::with_capacity(k); let mut check_assigned: Vec> = Vec::with_capacity(k); - let mut tmp_assigned: Vec> = Vec::with_capacity(k); - // match chip.strategy { // strategies where we carry out school-book multiplication in some form: // BigIntStrategy::Simple => { - for (i, (a_limb, (quot_v, out_v))) in - a.truncation.limbs.iter().zip(quot_vec.into_iter().zip(out_vec.into_iter())).enumerate() + for (i, ((a_limb, quot_v), out_v)) in + a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate() { - let (quot_cell, out_cell, check_cell) = { - let prod = range.gate().inner_product_left( - ctx, - quot_assigned.iter().map(|a| Existing(a)).chain(iter::once(Witness(quot_v))), - mod_vec[..=i].iter().rev().map(|c| Constant(*c)), - &mut tmp_assigned, - ); - // let gate_index = prod.column(); - - let quot_cell = tmp_assigned.pop().unwrap(); - let out_cell; - let check_cell; - // perform step 2: compute prod - a + out - let temp1 = prod.value().zip(a_limb.value()).map(|(prod, a)| *prod - a); - let check_val = temp1 + out_v; - - // This is to take care of edge case where we switch columns to handle overlap - let alloc = ctx.advice_alloc.get_mut(range.gate().context_id()).unwrap(); - if alloc.1 + 6 >= ctx.max_rows { - // edge case, we need to copy the last `prod` cell - // dbg!(*alloc); - alloc.1 = 0; - alloc.0 += 1; - range.gate().assign_region_last(ctx, [Existing(&prod)], []); + let (prod, new_quot_cell) = range.gate().inner_product_left_last( + ctx, + quot_assigned.iter().map(|a| Existing(*a)).chain(iter::once(Witness(quot_v))), + mod_vec[..=i].iter().rev().map(|c| Constant(*c)), + ); + // let gate_index = prod.column(); + + let out_cell; + let check_cell; + // perform step 2: compute prod - a + out + let temp1 = *prod.value() - a_limb.value(); + let check_val = temp1 + out_v; + + match range.strategy() { + RangeStrategy::Vertical => { + // transpose of: + // | prod | -1 | a | prod - a | 1 | out | prod - a + out + // where prod is at relative row `offset` + ctx.assign_region( + [ + Constant(-F::one()), + Existing(a_limb), + Witness(temp1), + Constant(F::one()), + Witness(out_v), + Witness(check_val), + ], + [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call + ); + check_cell = ctx.last().unwrap(); + out_cell = ctx.get(-2); } - match range.strategy() { - RangeStrategy::Vertical => { - // transpose of: - // | prod | -1 | a | prod - a | 1 | out | prod - a + out - // where prod is at relative row `offset` - let mut assignments = range.gate().assign_region( - ctx, - [ - Constant(-F::one()), - Existing(a_limb), - Witness(temp1), - Constant(F::one()), - Witness(out_v), - Witness(check_val), - ], - [(-1, None), (2, None)], - ); - check_cell = assignments.pop().unwrap(); - out_cell = assignments.pop().unwrap(); - } - RangeStrategy::PlonkPlus => { - // | prod | a | out | prod - a + out | - // selector columns: - // | 1 | 0 | 0 | - // | 0 | -1| 1 | - let mut assignments = range.gate().assign_region( - ctx, - [Existing(a_limb), Witness(out_v), Witness(check_val)], - [(-1, Some([F::zero(), -F::one(), F::one()]))], - ); - check_cell = assignments.pop().unwrap(); - out_cell = assignments.pop().unwrap(); - } - } - (quot_cell, out_cell, check_cell) - }; - quot_assigned.push(quot_cell); + } + quot_assigned.push(new_quot_cell); out_assigned.push(out_cell); check_assigned.push(check_cell); } @@ -186,32 +144,21 @@ pub fn crt<'a, F: PrimeField>( // range check limbs of `out` are in [0, 2^n) except last limb should be in [0, 2^out_last_limb_bits) for (out_index, out_cell) in out_assigned.iter().enumerate() { let limb_bits = if out_index == k - 1 { out_last_limb_bits } else { n }; - range.range_check(ctx, out_cell, limb_bits); + range.range_check(ctx, *out_cell, limb_bits); } // range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits) for (q_index, quot_cell) in quot_assigned.iter().enumerate() { let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n }; - let limb_base = if q_index == k - 1 { - biguint_to_fe(&(BigUint::one() << limb_bits)) - } else { - limb_bases[1] - }; + let limb_base = + if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] }; // compute quot_cell + 2^n and range check with n + 1 bits - let quot_shift = { - let out_val = quot_cell.value().map(|a| limb_base + a); - // | quot_cell | 2^n | 1 | quot_cell + 2^n | - range.gate().assign_region_last( - ctx, - [Existing(quot_cell), Constant(limb_base), Constant(F::one()), Witness(out_val)], - [(0, None)], - ) - }; - range.range_check(ctx, "_shift, limb_bits + 1); + let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base)); + range.range_check(ctx, quot_shift, limb_bits + 1); } - let check_overflow_int = &OverflowInteger::construct( + let check_overflow_int = OverflowInteger::new( check_assigned, max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits), ); @@ -226,40 +173,25 @@ pub fn crt<'a, F: PrimeField>( limb_base_big, ); - // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - let out_native_assigned = OverflowInteger::::evaluate( - range.gate(), - /*chip,*/ ctx, - &out_assigned, - limb_bases.iter().cloned(), - ); - // Constrain `quot_native = sum_i quot_assigned[i] * 2^{n*i}` in `F` - let quot_native_assigned = OverflowInteger::::evaluate( - range.gate(), - /*chip,*/ ctx, - "_assigned, - limb_bases.iter().cloned(), - ); + let quot_native = + OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases); - // TODO: we can save 1 cell by connecting `out_native_assigned` computation with the following: + // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` + let out_native = + OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases); + // We save 1 cell by connecting `out_native` computation with the following: // Check `out + modulus * quotient - a = 0` in native field // | out | modulus | quotient | a | - let _native_computation = range.gate().assign_region_last( - ctx, - [ - Existing(&out_native_assigned), - Constant(mod_native), - Existing("_native_assigned), - Existing(&a.native), - ], - [(0, None)], + ctx.assign_region( + [Constant(mod_native), Existing(quot_native), Existing(a.native)], + [-1], // negative index because -1 relative offset is `out_native` assigned value ); - CRTInteger::construct( - OverflowInteger::construct(out_assigned, limb_bits), - out_native_assigned, + ProperCrtUint(CRTInteger::new( + ProperUint(out_assigned).into_overflow(limb_bits), + out_native, out_val, - ) + )) } diff --git a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs index 38453da0..6232cbdf 100644 --- a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs @@ -1,12 +1,11 @@ use super::{check_carry_to_zero, CRTInteger, OverflowInteger}; -use crate::halo2_proofs::circuit::Value; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{biguint_to_fe, decompose_bigint_option, value_to_option, PrimeField}, + utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; -use num_bigint::{BigInt, BigUint}; +use num_bigint::BigInt; use num_integer::Integer; use num_traits::{One, Signed, Zero}; use std::{cmp::max, iter}; @@ -14,11 +13,10 @@ use std::{cmp::max, iter}; // same as carry_mod::crt but `out = 0` so no need to range check // // Assumption: the leading two bits (in big endian) are 1, and `a.max_size <= 2^{n * k - 1 + F::NUM_BITS - 2}` (A weaker assumption is also enough) -pub fn crt<'a, F: PrimeField>( +pub fn crt( range: &impl RangeInstructions, - // chip: &BigIntConfig, - ctx: &mut Context<'a, F>, - a: &CRTInteger<'a, F>, + ctx: &mut Context, + a: CRTInteger, k_bits: usize, // = a.len().bits() modulus: &BigInt, mod_vec: &[F], @@ -31,17 +29,7 @@ pub fn crt<'a, F: PrimeField>( let k = a.truncation.limbs.len(); let trunc_len = n * k; - #[cfg(feature = "display")] - { - let key = format!("check_carry_mod(crt) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - - // safety check: - a.value - .as_ref() - .map(|a| assert!(a.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2)); - } + debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2); // see carry_mod.rs for explanation let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize); @@ -53,19 +41,15 @@ pub fn crt<'a, F: PrimeField>( // we need to find `quot_native` as a native F element // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` - let quot_vec = if let Some(a_big) = value_to_option(a.value.as_ref()) { - let (quot_val, _out_val) = a_big.div_mod_floor(modulus); + let (quot_val, _out_val) = a.value.div_mod_floor(modulus); - // only perform safety checks in display mode so we can turn them off in production - debug_assert_eq!(_out_val, BigInt::zero()); - debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); + // only perform safety checks in debug mode + debug_assert_eq!(_out_val, BigInt::zero()); + debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); - decompose_bigint_option::(Value::known("_val), k, n) - } else { - vec![Value::unknown(); k] - }; + let quot_vec = decompose_bigint::("_val, k, n); - //assert!(modulus < &(BigUint::one() << (n * k))); + debug_assert!(modulus < &(BigInt::one() << (n * k))); // We need to show `modulus * quotient - a` is: // - congruent to `0 (mod 2^trunc_len)` @@ -81,43 +65,24 @@ pub fn crt<'a, F: PrimeField>( let mut quot_assigned: Vec> = Vec::with_capacity(k); let mut check_assigned: Vec> = Vec::with_capacity(k); - let mut tmp_assigned: Vec> = Vec::with_capacity(k); // match chip.strategy { // BigIntStrategy::Simple => { - for (i, (a_limb, quot_v)) in a.truncation.limbs.iter().zip(quot_vec.into_iter()).enumerate() { - let (quot_cell, check_cell) = { - let prod = range.gate().inner_product_left( - ctx, - quot_assigned.iter().map(Existing).chain(iter::once(Witness(quot_v))), - mod_vec[0..=i].iter().rev().map(|c| Constant(*c)), - &mut tmp_assigned, - ); - - let quot_cell = tmp_assigned.pop().unwrap(); - // perform step 2: compute prod - a + out - // transpose of: - // | prod | -1 | a | prod - a | - - // This is to take care of edge case where we switch columns to handle overlap - let alloc = ctx.advice_alloc.get_mut(range.gate().context_id()).unwrap(); - if alloc.1 + 3 >= ctx.max_rows { - // edge case, we need to copy the last `prod` cell - alloc.1 = 0; - alloc.0 += 1; - range.gate().assign_region_last(ctx, vec![Existing(&prod)], vec![]); - } - - let check_val = prod.value().zip(a_limb.value()).map(|(prod, a)| *prod - a); - let check_cell = range.gate().assign_region_last( - ctx, - vec![Constant(-F::one()), Existing(a_limb), Witness(check_val)], - vec![(-1, None)], - ); - - (quot_cell, check_cell) - }; - quot_assigned.push(quot_cell); + for (i, (a_limb, quot_v)) in a.truncation.limbs.into_iter().zip(quot_vec).enumerate() { + let (prod, new_quot_cell) = range.gate().inner_product_left_last( + ctx, + quot_assigned.iter().map(|x| Existing(*x)).chain(iter::once(Witness(quot_v))), + mod_vec[0..=i].iter().rev().map(|c| Constant(*c)), + ); + + // perform step 2: compute prod - a + out + // transpose of: + // | prod | -1 | a | prod - a | + let check_val = *prod.value() - a_limb.value(); + let check_cell = ctx + .assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]); + + quot_assigned.push(new_quot_cell); check_assigned.push(check_cell); } // } @@ -126,35 +91,16 @@ pub fn crt<'a, F: PrimeField>( // range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits) for (q_index, quot_cell) in quot_assigned.iter().enumerate() { let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n }; - let limb_base = if q_index == k - 1 { - biguint_to_fe(&(BigUint::one() << limb_bits)) - } else { - limb_bases[1] - }; + let limb_base = + if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] }; // compute quot_cell + 2^n and range check with n + 1 bits - let quot_shift = { - // TODO: unnecessary clone - let out_val = quot_cell.value().map(|a| limb_base + a); - // | quot_cell | 2^n | 1 | quot_cell + 2^n | - range.gate().assign_region_last( - ctx, - vec![ - Existing(quot_cell), - Constant(limb_base), - Constant(F::one()), - Witness(out_val), - ], - vec![(0, None)], - ) - }; - range.range_check(ctx, "_shift, limb_bits + 1); + let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base)); + range.range_check(ctx, quot_shift, limb_bits + 1); } - let check_overflow_int = &OverflowInteger::construct( - check_assigned, - max(a.truncation.max_limb_bits, 2 * n + k_bits), - ); + let check_overflow_int = + OverflowInteger::new(check_assigned, max(a.truncation.max_limb_bits, 2 * n + k_bits)); // check that `modulus * quotient - a == 0 mod 2^{trunc_len}` after carry check_carry_to_zero::truncate::( @@ -167,23 +113,13 @@ pub fn crt<'a, F: PrimeField>( ); // Constrain `quot_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - let quot_native_assigned = OverflowInteger::::evaluate( - range.gate(), - /*chip,*/ ctx, - "_assigned, - limb_bases.iter().cloned(), - ); + let quot_native = + OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases); // Check `0 + modulus * quotient - a = 0` in native field // | 0 | modulus | quotient | a | - let _native_computation = range.gate().assign_region( - ctx, - vec![ - Constant(F::zero()), - Constant(mod_native), - Existing("_native_assigned), - Existing(&a.native), - ], - vec![(0, None)], + ctx.assign_region( + [Constant(F::zero()), Constant(mod_native), Existing(quot_native), Existing(a.native)], + [0], ); } diff --git a/halo2-ecc/src/bigint/check_carry_to_zero.rs b/halo2-ecc/src/bigint/check_carry_to_zero.rs index e718b128..fa2f5648 100644 --- a/halo2-ecc/src/bigint/check_carry_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_to_zero.rs @@ -1,13 +1,11 @@ use super::OverflowInteger; -use crate::halo2_proofs::circuit::Value; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{bigint_to_fe, biguint_to_fe, fe_to_bigint, value_to_option, PrimeField}, + utils::{bigint_to_fe, fe_to_bigint, BigPrimeField}, Context, QuantumCell::{Constant, Existing, Witness}, }; -use num_bigint::{BigInt, BigUint}; -use num_traits::One; +use num_bigint::BigInt; // check that `a` carries to `0 mod 2^{a.limb_bits * a.limbs.len()}` // same as `assign` above except we need to provide `c_{k - 1}` witness as well @@ -26,10 +24,10 @@ use num_traits::One; // a_i * 2^{n*w} + a_{i - 1} * 2^{n*(w-1)} + ... + a_{i - w} + c_{i - w - 1} = c_i * 2^{n*(w+1)} // which is valid as long as `(m - n + EPSILON) + n * (w+1) < native_modulus::().bits() - 1` // so we only need to range check `c_i` every `w + 1` steps, starting with `i = w` -pub fn truncate<'a, F: PrimeField>( +pub fn truncate( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &OverflowInteger<'a, F>, + ctx: &mut Context, + a: OverflowInteger, limb_bits: usize, limb_base: F, limb_base_big: &BigInt, @@ -37,27 +35,16 @@ pub fn truncate<'a, F: PrimeField>( let k = a.limbs.len(); let max_limb_bits = a.max_limb_bits; - #[cfg(feature = "display")] - { - let key = format!("check_carry_to_zero(trunc) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - } - - let mut carries: Vec> = Vec::with_capacity(k); + let mut carries = Vec::with_capacity(k); for a_limb in a.limbs.iter() { - let a_val = a_limb.value(); - let carry = a_val.map(|a_fe| { - let a_val_big = fe_to_bigint(a_fe); - if carries.is_empty() { - // warning: using >> on negative integer produces undesired effect - a_val_big / limb_base_big - } else { - let carry_val = value_to_option(carries.last().unwrap().as_ref()).unwrap(); - (a_val_big + carry_val) / limb_base_big - } - }); + let a_val_big = fe_to_bigint(a_limb.value()); + let carry = if let Some(carry_val) = carries.last() { + (a_val_big + carry_val) / limb_base_big + } else { + // warning: using >> on negative integer produces undesired effect + a_val_big / limb_base_big + }; carries.push(carry); } @@ -69,44 +56,30 @@ pub fn truncate<'a, F: PrimeField>( // `window = w + 1` valid as long as `range_bits + n * (w+1) < native_modulus::().bits() - 1` // let window = (F::NUM_BITS as usize - 2 - range_bits) / limb_bits; // assert!(window > 0); + // In practice, we are currently always using window = 1 so the above is commented out - // TODO: maybe we can also cache these bigints - let shift_val = biguint_to_fe::(&(BigUint::one() << range_bits)); + let shift_val = range.gate().pow_of_two()[range_bits]; // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1; let mut previous = None; - for (a_limb, carry) in a.limbs.iter().zip(carries.iter()) { - let neg_carry_val = carry.as_ref().map(|c| bigint_to_fe::(&-c)); - let neg_carry = range - .gate() - .assign_region( - ctx, - vec![ - Existing(a_limb), - Witness(neg_carry_val), - Constant(limb_base), - previous.as_ref().map(Existing).unwrap_or_else(|| Constant(F::zero())), - ], - vec![(0, None)], - ) - .into_iter() - .nth(1) - .unwrap(); + for (a_limb, carry) in a.limbs.into_iter().zip(carries.into_iter()) { + let neg_carry_val = bigint_to_fe(&-carry); + ctx.assign_region( + [ + Existing(a_limb), + Witness(neg_carry_val), + Constant(limb_base), + previous.map(Existing).unwrap_or_else(|| Constant(F::zero())), + ], + [0], + ); + let neg_carry = ctx.get(-3); // i in 0..num_windows { // let idx = std::cmp::min(window * i + window - 1, k - 1); // let carry_cell = &neg_carry_assignments[idx]; - let shifted_carry = { - let shift_carry_val = Value::known(shift_val) + neg_carry.value(); - let cells = vec![ - Existing(&neg_carry), - Constant(F::one()), - Constant(shift_val), - Witness(shift_carry_val), - ]; - range.gate().assign_region_last(ctx, cells, vec![(0, None)]) - }; - range.range_check(ctx, &shifted_carry, range_bits + 1); + let shifted_carry = range.gate().add(ctx, neg_carry, Constant(shift_val)); + range.range_check(ctx, shifted_carry, range_bits + 1); previous = Some(neg_carry); } diff --git a/halo2-ecc/src/bigint/mod.rs b/halo2-ecc/src/bigint/mod.rs index 44b65a0b..ea14b127 100644 --- a/halo2-ecc/src/bigint/mod.rs +++ b/halo2-ecc/src/bigint/mod.rs @@ -1,17 +1,11 @@ -use crate::halo2_proofs::{ - circuit::{Cell, Value}, - plonk::ConstraintSystem, -}; use halo2_base::{ - gates::{flex_gate::FlexGateConfig, GateInstructions}, - utils::{biguint_to_fe, decompose_biguint, fe_to_biguint, PrimeField}, + gates::flex_gate::GateInstructions, + utils::{biguint_to_fe, decompose_biguint, fe_to_biguint, BigPrimeField, ScalarField}, AssignedValue, Context, - QuantumCell::{Constant, Existing, Witness}, + QuantumCell::Constant, }; -use itertools::Itertools; use num_bigint::{BigInt, BigUint}; use num_traits::Zero; -use std::{marker::PhantomData, rc::Rc}; pub mod add_no_carry; pub mod big_is_equal; @@ -29,8 +23,7 @@ pub mod select_by_indicator; pub mod sub; pub mod sub_no_carry; -#[derive(Clone, Debug, PartialEq)] -#[derive(Default)] +#[derive(Clone, Debug, PartialEq, Default)] pub enum BigIntStrategy { // use existing gates #[default] @@ -40,54 +33,91 @@ pub enum BigIntStrategy { // CustomVerticalShort, } - - #[derive(Clone, Debug)] -pub struct OverflowInteger<'v, F: PrimeField> { - pub limbs: Vec>, +pub struct OverflowInteger { + pub limbs: Vec>, // max bits of a limb, ignoring sign pub max_limb_bits: usize, // the standard limb bit that we use for pow of two limb base - to reduce overhead we just assume this is inferred from context (e.g., the chip stores it), so we stop storing it here // pub limb_bits: usize, } -impl<'v, F: PrimeField> OverflowInteger<'v, F> { - pub fn construct(limbs: Vec>, max_limb_bits: usize) -> Self { +impl OverflowInteger { + pub fn new(limbs: Vec>, max_limb_bits: usize) -> Self { Self { limbs, max_limb_bits } } // convenience function for testing #[cfg(test)] - pub fn to_bigint(&self, limb_bits: usize) -> Value { + pub fn to_bigint(&self, limb_bits: usize) -> BigInt + where + F: BigPrimeField, + { use halo2_base::utils::fe_to_bigint; - self.limbs.iter().rev().fold(Value::known(BigInt::zero()), |acc, acell| { - acc.zip(acell.value()).map(|(acc, x)| (acc << limb_bits) + fe_to_bigint(x)) - }) + self.limbs + .iter() + .rev() + .fold(BigInt::zero(), |acc, acell| (acc << limb_bits) + fe_to_bigint(acell.value())) + } + + /// Computes `sum_i limbs[i] * limb_bases[i]` in native field `F`. + /// In practice assumes `limb_bases[i] = 2^{limb_bits * i}`. + pub fn evaluate_native( + ctx: &mut Context, + gate: &impl GateInstructions, + limbs: impl IntoIterator>, + limb_bases: &[F], + ) -> AssignedValue { + // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` + gate.inner_product(ctx, limbs, limb_bases.iter().map(|c| Constant(*c))) + } +} + +/// Safe wrapper around a BigUint represented as a vector of limbs in **little endian**. +/// The underlying BigUint is represented by +/// sumi limbs\[i\] * 2limb_bits * i +/// +/// To save memory we do not store the `limb_bits` and it must be inferred from context. +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct ProperUint(pub(crate) Vec>); + +impl ProperUint { + pub fn limbs(&self) -> &[AssignedValue] { + self.0.as_slice() + } + + pub fn into_overflow(self, limb_bits: usize) -> OverflowInteger { + OverflowInteger::new(self.0, limb_bits) } - pub fn evaluate( + /// Computes `sum_i limbs[i] * limb_bases[i]` in native field `F`. + /// In practice assumes `limb_bases[i] = 2^{limb_bits * i}`. + /// + /// Assumes that `value` is the underlying BigUint value represented by `self`. + pub fn into_crt( + self, + ctx: &mut Context, gate: &impl GateInstructions, - // chip: &BigIntConfig, - ctx: &mut Context<'_, F>, - limbs: &[AssignedValue<'v, F>], - limb_bases: impl IntoIterator, - ) -> AssignedValue<'v, F> { + value: BigUint, + limb_bases: &[F], + limb_bits: usize, + ) -> ProperCrtUint { // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - gate.inner_product( - ctx, - limbs.iter().map(|a| Existing(a)), - limb_bases.into_iter().map(|c| Constant(c)), - ) + let native = + OverflowInteger::evaluate_native(ctx, gate, self.0.iter().copied(), limb_bases); + ProperCrtUint(CRTInteger::new(self.into_overflow(limb_bits), native, value.into())) } } +#[repr(transparent)] #[derive(Clone, Debug)] -pub struct FixedOverflowInteger { +pub struct FixedOverflowInteger { pub limbs: Vec, } -impl FixedOverflowInteger { +impl FixedOverflowInteger { pub fn construct(limbs: Vec) -> Self { Self { limbs } } @@ -107,42 +137,37 @@ impl FixedOverflowInteger { .fold(BigUint::zero(), |acc, x| (acc << limb_bits) + fe_to_biguint(x)) } - pub fn assign<'v>( - self, - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - limb_bits: usize, - ) -> OverflowInteger<'v, F> { - let assigned_limbs = gate.assign_region(ctx, self.limbs.into_iter().map(Constant), vec![]); - OverflowInteger::construct(assigned_limbs, limb_bits) + pub fn assign(self, ctx: &mut Context) -> ProperUint { + let assigned_limbs = self.limbs.into_iter().map(|limb| ctx.load_constant(limb)).collect(); + ProperUint(assigned_limbs) } /// only use case is when coeffs has only a single 1, rest are 0 - pub fn select_by_indicator<'v>( + pub fn select_by_indicator( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, + ctx: &mut Context, a: &[Self], - coeffs: &[AssignedValue<'v, F>], + coeffs: &[AssignedValue], limb_bits: usize, - ) -> OverflowInteger<'v, F> { + ) -> OverflowInteger { let k = a[0].limbs.len(); let out_limbs = (0..k) .map(|idx| { let int_limbs = a.iter().map(|a| Constant(a.limbs[idx])); - gate.select_by_indicator(ctx, int_limbs, coeffs.iter()) + gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); - OverflowInteger::construct(out_limbs, limb_bits) + OverflowInteger::new(out_limbs, limb_bits) } } #[derive(Clone, Debug)] -pub struct CRTInteger<'v, F: PrimeField> { +pub struct CRTInteger { // keep track of an integer `a` using CRT as `a mod 2^t` and `a mod n` // where `t = truncation.limbs.len() * truncation.limb_bits` - // `n = modulus::` + // `n = modulus::` // `value` is the actual integer value we want to keep track of // we allow `value` to be a signed BigInt @@ -151,31 +176,96 @@ pub struct CRTInteger<'v, F: PrimeField> { // the IMPLICIT ASSUMPTION: `value (mod 2^t) = truncation` && `value (mod n) = native` // this struct should only be used if the implicit assumption above is satisfied - pub truncation: OverflowInteger<'v, F>, - pub native: AssignedValue<'v, F>, - pub value: Value, + pub truncation: OverflowInteger, + pub native: AssignedValue, + pub value: BigInt, +} + +impl AsRef> for CRTInteger { + fn as_ref(&self) -> &CRTInteger { + self + } +} + +// Cloning all the time impacts readability so we'll just implement From<&T> for T +impl<'a, F: ScalarField> From<&'a CRTInteger> for CRTInteger { + fn from(x: &'a CRTInteger) -> Self { + x.clone() + } } -impl<'v, F: PrimeField> CRTInteger<'v, F> { - pub fn construct( - truncation: OverflowInteger<'v, F>, - native: AssignedValue<'v, F>, - value: Value, - ) -> Self { +impl CRTInteger { + pub fn new(truncation: OverflowInteger, native: AssignedValue, value: BigInt) -> Self { Self { truncation, native, value } } - pub fn native(&self) -> &AssignedValue<'v, F> { + pub fn native(&self) -> &AssignedValue { &self.native } - pub fn limbs(&self) -> &[AssignedValue<'v, F>] { + pub fn limbs(&self) -> &[AssignedValue] { self.truncation.limbs.as_slice() } } +/// Safe wrapper for representing a BigUint as a [`CRTInteger`] whose underlying BigUint value is in `[0, 2^t)` +/// where `t = truncation.limbs.len() * limb_bits`. This struct guarantees that +/// * each `truncation.limbs[i]` is ranged checked to be in `[0, 2^limb_bits)`, +/// * `native` is the evaluation of `sum_i truncation.limbs[i] * 2^{limb_bits * i} (mod modulus::)` in the native field `F` +/// * `value` is equal to `sum_i truncation.limbs[i] * 2^{limb_bits * i}` as integers +/// +/// Note this means `native` and `value` are completely determined by `truncation`. However, we still store them explicitly for convenience. +#[repr(transparent)] #[derive(Clone, Debug)] -pub struct FixedCRTInteger { +pub struct ProperCrtUint(pub(crate) CRTInteger); + +impl AsRef> for ProperCrtUint { + fn as_ref(&self) -> &CRTInteger { + &self.0 + } +} + +impl<'a, F: ScalarField> From<&'a ProperCrtUint> for ProperCrtUint { + fn from(x: &'a ProperCrtUint) -> Self { + x.clone() + } +} + +// cannot blanket implement From> for T because of Rust +impl From> for CRTInteger { + fn from(x: ProperCrtUint) -> Self { + x.0 + } +} + +impl<'a, F: ScalarField> From<&'a ProperCrtUint> for CRTInteger { + fn from(x: &'a ProperCrtUint) -> Self { + x.0.clone() + } +} + +impl From> for ProperUint { + fn from(x: ProperCrtUint) -> Self { + ProperUint(x.0.truncation.limbs) + } +} + +impl ProperCrtUint { + pub fn limbs(&self) -> &[AssignedValue] { + self.0.limbs() + } + + pub fn native(&self) -> &AssignedValue { + self.0.native() + } + + pub fn value(&self) -> BigUint { + self.0.value.to_biguint().expect("Value of proper uint should not be negative") + } +} + +#[derive(Clone, Debug)] +pub struct FixedCRTInteger { // keep track of an integer `a` using CRT as `a mod 2^t` and `a mod n` // where `t = truncation.limbs.len() * truncation.limb_bits` // `n = modulus::` @@ -191,15 +281,8 @@ pub struct FixedCRTInteger { pub value: BigUint, } -#[derive(Clone, Debug)] -pub struct FixedAssignedCRTInteger { - pub truncation: FixedOverflowInteger, - pub limb_fixed_cells: Vec, - pub value: BigUint, -} - -impl FixedCRTInteger { - pub fn construct(truncation: FixedOverflowInteger, value: BigUint) -> Self { +impl FixedCRTInteger { + pub fn new(truncation: FixedOverflowInteger, value: BigUint) -> Self { Self { truncation, value } } @@ -210,90 +293,14 @@ impl FixedCRTInteger { Self { truncation, value } } - pub fn assign<'a>( + pub fn assign( self, - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, + ctx: &mut Context, limb_bits: usize, native_modulus: &BigUint, - ) -> CRTInteger<'a, F> { - let assigned_truncation = self.truncation.assign(gate, ctx, limb_bits); - let assigned_native = { - let native_cells = vec![Constant(biguint_to_fe(&(&self.value % native_modulus)))]; - gate.assign_region_last(ctx, native_cells, vec![]) - }; - CRTInteger::construct(assigned_truncation, assigned_native, Value::known(self.value.into())) - } - - pub fn assign_without_caching<'a>( - self, - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - limb_bits: usize, - native_modulus: &BigUint, - ) -> CRTInteger<'a, F> { - let fixed_cells = self - .truncation - .limbs - .iter() - .map(|limb| ctx.assign_fixed_without_caching(*limb)) - .collect_vec(); - let assigned_limbs = gate.assign_region( - ctx, - self.truncation.limbs.into_iter().map(|v| Witness(Value::known(v))), - vec![], - ); - for (cell, acell) in fixed_cells.iter().zip(assigned_limbs.iter()) { - #[cfg(feature = "halo2-axiom")] - ctx.region.constrain_equal(cell, acell.cell()); - #[cfg(feature = "halo2-pse")] - ctx.region.constrain_equal(*cell, acell.cell()).unwrap(); - } - let assigned_native = { - let native_val = biguint_to_fe(&(&self.value % native_modulus)); - let cell = ctx.assign_fixed_without_caching(native_val); - let acell = - gate.assign_region_last(ctx, vec![Witness(Value::known(native_val))], vec![]); - - #[cfg(feature = "halo2-axiom")] - ctx.region.constrain_equal(&cell, acell.cell()); - #[cfg(feature = "halo2-pse")] - ctx.region.constrain_equal(cell, acell.cell()).unwrap(); - - acell - }; - CRTInteger::construct( - OverflowInteger::construct(assigned_limbs, limb_bits), - assigned_native, - Value::known(self.value.into()), - ) - } -} - -#[derive(Clone, Debug, Default)] -#[allow(dead_code)] -pub struct BigIntConfig { - // everything is empty if strategy is `Simple` or `SimplePlus` - strategy: BigIntStrategy, - context_id: Rc, - _marker: PhantomData, -} - -impl BigIntConfig { - pub fn configure( - _meta: &mut ConstraintSystem, - strategy: BigIntStrategy, - _limb_bits: usize, - _num_limbs: usize, - _gate: &FlexGateConfig, - context_id: String, - ) -> Self { - // let mut q_dot_constant = HashMap::new(); - /* - match strategy { - _ => {} - } - */ - Self { strategy, _marker: PhantomData, context_id: Rc::new(context_id) } + ) -> ProperCrtUint { + let assigned_truncation = self.truncation.assign(ctx).into_overflow(limb_bits); + let assigned_native = ctx.load_constant(biguint_to_fe(&(&self.value % native_modulus))); + ProperCrtUint(CRTInteger::new(assigned_truncation, assigned_native, self.value.into())) } } diff --git a/halo2-ecc/src/bigint/mul_no_carry.rs b/halo2-ecc/src/bigint/mul_no_carry.rs index 637c17e6..aa174c3d 100644 --- a/halo2-ecc/src/bigint/mul_no_carry.rs +++ b/halo2-ecc/src/bigint/mul_no_carry.rs @@ -1,53 +1,49 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context, QuantumCell::Existing}; -pub fn truncate<'v, F: PrimeField>( +/// # Assumptions +/// * `a` and `b` have the same number of limbs `k` +/// * `k` is nonzero +/// * `num_limbs_log2_ceil = log2_ceil(k)` +/// * `log2_ceil(k) + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2` +pub fn truncate( gate: &impl GateInstructions, - // _chip: &BigIntConfig, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, num_limbs_log2_ceil: usize, -) -> OverflowInteger<'v, F> { +) -> OverflowInteger { let k = a.limbs.len(); - assert!(k > 0); assert_eq!(k, b.limbs.len()); + debug_assert!(k > 0); - #[cfg(feature = "display")] - { - let key = format!("mul_no_carry(truncate) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - - assert!( - num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2 - ); - } + debug_assert!( + num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2 + ); let out_limbs = (0..k) .map(|i| { gate.inner_product( ctx, - a.limbs[..=i].iter().map(Existing), - b.limbs[..=i].iter().rev().map(Existing), + a.limbs[..=i].iter().copied(), + b.limbs[..=i].iter().rev().map(|x| Existing(*x)), ) }) .collect(); - OverflowInteger::construct(out_limbs, num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits) + OverflowInteger::new(out_limbs, num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - // chip: &BigIntConfig, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, num_limbs_log2_ceil: usize, -) -> CRTInteger<'v, F> { - let out_trunc = truncate::(gate, ctx, &a.truncation, &b.truncation, num_limbs_log2_ceil); - let out_native = gate.mul(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref() * b.value.as_ref(); +) -> CRTInteger { + let out_trunc = truncate::(gate, ctx, a.truncation, b.truncation, num_limbs_log2_ceil); + let out_native = gate.mul(ctx, a.native, b.native); + let out_val = a.value * b.value; - CRTInteger::construct(out_trunc, out_native, out_val) + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/negative.rs b/halo2-ecc/src/bigint/negative.rs index 60183c3f..74e61da1 100644 --- a/halo2-ecc/src/bigint/negative.rs +++ b/halo2-ecc/src/bigint/negative.rs @@ -1,11 +1,11 @@ use super::OverflowInteger; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, -) -> OverflowInteger<'v, F> { - let out_limbs = a.limbs.iter().map(|limb| gate.neg(ctx, Existing(limb))).collect(); - OverflowInteger::construct(out_limbs, a.max_limb_bits) + ctx: &mut Context, + a: OverflowInteger, +) -> OverflowInteger { + let out_limbs = a.limbs.into_iter().map(|limb| gate.neg(ctx, limb)).collect(); + OverflowInteger::new(out_limbs, a.max_limb_bits) } diff --git a/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs b/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs index 1c64e24f..5c818453 100644 --- a/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs +++ b/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs @@ -1,49 +1,47 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{ gates::GateInstructions, - utils::{log2_ceil, PrimeField}, + utils::{log2_ceil, ScalarField}, Context, - QuantumCell::{Constant, Existing, Witness}, + QuantumCell::Constant, }; +use itertools::Itertools; use std::cmp::max; /// compute a * c + b = b + a * c +/// +/// # Assumptions +/// * `a, b` have same number of limbs +/// * Number of limbs is nonzero +/// * `c_log2_ceil = log2_ceil(c)` where `c` is the BigUint value of `c_f` // this is uniquely suited for our simple gate -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, c_f: F, c_log2_ceil: usize, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); - +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| { - let out_val = a_limb.value().zip(b_limb.value()).map(|(a, b)| c_f * a + b); - gate.assign_region_last( - ctx, - vec![Existing(b_limb), Existing(a_limb), Constant(c_f), Witness(out_val)], - vec![(0, None)], - ) - }) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.mul_add(ctx, a_limb, Constant(c_f), b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits + c_log2_ceil, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits + c_log2_ceil, b.max_limb_bits) + 1) } -pub fn crt<'v, F: PrimeField>( +/// compute a * c + b = b + a * c +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, c: i64, -) -> CRTInteger<'v, F> { - assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); +) -> CRTInteger { + debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); let (c_f, c_abs) = if c >= 0 { let c_abs = u64::try_from(c).unwrap(); @@ -53,15 +51,8 @@ pub fn crt<'v, F: PrimeField>( (-F::from(c_abs), c_abs) }; - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation, c_f, log2_ceil(c_abs)); - let out_native = { - let out_val = b.native.value().zip(a.native.value()).map(|(b, a)| c_f * a + b); - gate.assign_region_last( - ctx, - vec![Existing(&b.native), Existing(&a.native), Constant(c_f), Witness(out_val)], - vec![(0, None)], - ) - }; - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a * c + b); - CRTInteger::construct(out_trunc, out_native, out_val) + let out_trunc = assign(gate, ctx, a.truncation, b.truncation, c_f, log2_ceil(c_abs)); + let out_native = gate.mul_add(ctx, a.native, Constant(c_f), b.native); + let out_val = a.value * c + b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/scalar_mul_no_carry.rs b/halo2-ecc/src/bigint/scalar_mul_no_carry.rs index 4aff4b0c..fdbc4058 100644 --- a/halo2-ecc/src/bigint/scalar_mul_no_carry.rs +++ b/halo2-ecc/src/bigint/scalar_mul_no_carry.rs @@ -1,29 +1,28 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{ gates::GateInstructions, - utils::{log2_ceil, PrimeField}, + utils::{log2_ceil, ScalarField}, Context, - QuantumCell::{Constant, Existing}, + QuantumCell::Constant, }; -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, + ctx: &mut Context, + a: OverflowInteger, c_f: F, c_log2_ceil: usize, -) -> OverflowInteger<'v, F> { - let out_limbs = - a.limbs.iter().map(|limb| gate.mul(ctx, Existing(limb), Constant(c_f))).collect(); - OverflowInteger::construct(out_limbs, a.max_limb_bits + c_log2_ceil) +) -> OverflowInteger { + let out_limbs = a.limbs.into_iter().map(|limb| gate.mul(ctx, limb, Constant(c_f))).collect(); + OverflowInteger::new(out_limbs, a.max_limb_bits + c_log2_ceil) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, + ctx: &mut Context, + a: CRTInteger, c: i64, -) -> CRTInteger<'v, F> { +) -> CRTInteger { let (c_f, c_abs) = if c >= 0 { let c_abs = u64::try_from(c).unwrap(); (F::from(c_abs), c_abs) @@ -32,19 +31,9 @@ pub fn crt<'v, F: PrimeField>( (-F::from(c_abs), c_abs) }; - let out_limbs = a - .truncation - .limbs - .iter() - .map(|limb| gate.mul(ctx, Existing(limb), Constant(c_f))) - .collect(); + let out_overflow = assign(gate, ctx, a.truncation, c_f, log2_ceil(c_abs)); + let out_native = gate.mul(ctx, a.native, Constant(c_f)); + let out_val = a.value * c; - let out_native = gate.mul(ctx, Existing(&a.native), Constant(c_f)); - let out_val = a.value.as_ref().map(|a| a * c); - - CRTInteger::construct( - OverflowInteger::construct(out_limbs, a.truncation.max_limb_bits + log2_ceil(c_abs)), - out_native, - out_val, - ) + CRTInteger::new(out_overflow, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/select.rs b/halo2-ecc/src/bigint/select.rs index aa296164..65fd7333 100644 --- a/halo2-ecc/src/bigint/select.rs +++ b/halo2-ecc/src/bigint/select.rs @@ -1,55 +1,50 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; use std::cmp::max; -pub fn assign<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +/// * Number of limbs is nonzero +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, - sel: &AssignedValue<'v, F>, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, + sel: AssignedValue, +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| gate.select(ctx, Existing(a_limb), Existing(b_limb), Existing(sel))) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits)) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits)) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - sel: &AssignedValue<'v, F>, -) -> CRTInteger<'v, F> { - assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, + sel: AssignedValue, +) -> CRTInteger { + debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); let out_limbs = a .truncation .limbs - .iter() - .zip(b.truncation.limbs.iter()) - .map(|(a_limb, b_limb)| gate.select(ctx, Existing(a_limb), Existing(b_limb), Existing(sel))) + .into_iter() + .zip_eq(b.truncation.limbs) + .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel)) .collect(); - let out_trunc = OverflowInteger::construct( + let out_trunc = OverflowInteger::new( out_limbs, max(a.truncation.max_limb_bits, b.truncation.max_limb_bits), ); - let out_native = gate.select(ctx, Existing(&a.native), Existing(&b.native), Existing(sel)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).zip(sel.value()).map(|((a, b), s)| { - if s.is_zero_vartime() { - b.clone() - } else { - a.clone() - } - }); - CRTInteger::construct(out_trunc, out_native, out_val) + let out_native = gate.select(ctx, a.native, b.native, sel); + let out_val = if sel.value().is_zero_vartime() { b.value } else { a.value }; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/select_by_indicator.rs b/halo2-ecc/src/bigint/select_by_indicator.rs index 87597804..d1658d04 100644 --- a/halo2-ecc/src/bigint/select_by_indicator.rs +++ b/halo2-ecc/src/bigint/select_by_indicator.rs @@ -1,69 +1,69 @@ use super::{CRTInteger, OverflowInteger}; -use crate::halo2_proofs::circuit::Value; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; use num_bigint::BigInt; use num_traits::Zero; use std::cmp::max; /// only use case is when coeffs has only a single 1, rest are 0 -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &[OverflowInteger<'v, F>], - coeffs: &[AssignedValue<'v, F>], -) -> OverflowInteger<'v, F> { + ctx: &mut Context, + a: &[OverflowInteger], + coeffs: &[AssignedValue], +) -> OverflowInteger { let k = a[0].limbs.len(); let out_limbs = (0..k) .map(|idx| { - let int_limbs = a.iter().map(|a| Existing(&a.limbs[idx])); - gate.select_by_indicator(ctx, int_limbs, coeffs.iter()) + let int_limbs = a.iter().map(|a| a.limbs[idx]); + gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.max_limb_bits)); - OverflowInteger::construct(out_limbs, max_limb_bits) + OverflowInteger::new(out_limbs, max_limb_bits) } /// only use case is when coeffs has only a single 1, rest are 0 -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &[CRTInteger<'v, F>], - coeffs: &[AssignedValue<'v, F>], + ctx: &mut Context, + a: &[impl AsRef>], + coeffs: &[AssignedValue], limb_bases: &[F], -) -> CRTInteger<'v, F> { +) -> CRTInteger { assert_eq!(a.len(), coeffs.len()); - let k = a[0].truncation.limbs.len(); + let k = a[0].as_ref().truncation.limbs.len(); let out_limbs = (0..k) .map(|idx| { - let int_limbs = a.iter().map(|a| Existing(&a.truncation.limbs[idx])); - gate.select_by_indicator(ctx, int_limbs, coeffs.iter()) + let int_limbs = a.iter().map(|a| a.as_ref().truncation.limbs[idx]); + gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); - let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.truncation.max_limb_bits)); + let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.as_ref().truncation.max_limb_bits)); - let out_trunc = OverflowInteger::construct(out_limbs, max_limb_bits); + let out_trunc = OverflowInteger::new(out_limbs, max_limb_bits); let out_native = if a.len() > k { - OverflowInteger::::evaluate(gate, ctx, &out_trunc.limbs, limb_bases[..k].iter().cloned()) + OverflowInteger::evaluate_native( + ctx, + gate, + out_trunc.limbs.iter().copied(), + &limb_bases[..k], + ) } else { - let a_native = a.iter().map(|x| Existing(&x.native)); - gate.select_by_indicator(ctx, a_native, coeffs.iter()) + let a_native = a.iter().map(|x| x.as_ref().native); + gate.select_by_indicator(ctx, a_native, coeffs.iter().copied()) }; - let out_val = a.iter().zip(coeffs.iter()).fold(Value::known(BigInt::zero()), |acc, (x, y)| { - acc.zip(x.value.as_ref()).zip(y.value()).map(|((a, x), y)| { - if y.is_zero_vartime() { - a - } else { - x.clone() - } - }) + let out_val = a.iter().zip(coeffs.iter()).fold(BigInt::zero(), |acc, (x, y)| { + if y.value().is_zero_vartime() { + acc + } else { + x.as_ref().value.clone() + } }); - CRTInteger::construct(out_trunc, out_native, out_val) + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/sub.rs b/halo2-ecc/src/bigint/sub.rs index 5e987f0c..8b2263f9 100644 --- a/halo2-ecc/src/bigint/sub.rs +++ b/halo2-ecc/src/bigint/sub.rs @@ -1,81 +1,79 @@ -use super::{CRTInteger, OverflowInteger}; +use super::{CRTInteger, OverflowInteger, ProperCrtUint, ProperUint}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::PrimeField, + utils::ScalarField, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; +use itertools::Itertools; -/// Should only be called on integers a, b in proper representation with all limbs having at most `limb_bits` number of bits -pub fn assign<'a, F: PrimeField>( +/// # Assumptions +/// * Should only be called on integers a, b in proper representation with all limbs having at most `limb_bits` number of bits +/// * `a, b` have same nonzero number of limbs +pub fn assign( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &OverflowInteger<'a, F>, - b: &OverflowInteger<'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, limb_bits: usize, limb_base: F, -) -> (OverflowInteger<'a, F>, AssignedValue<'a, F>) { - assert!(a.max_limb_bits <= limb_bits); - assert!(b.max_limb_bits <= limb_bits); - assert_eq!(a.limbs.len(), b.limbs.len()); - let k = a.limbs.len(); +) -> (OverflowInteger, AssignedValue) { + let a = a.into(); + let b = b.into(); + let k = a.0.len(); let mut out_limbs = Vec::with_capacity(k); let mut borrow: Option> = None; - for (a_limb, b_limb) in a.limbs.iter().zip(b.limbs.iter()) { + for (a_limb, b_limb) in a.0.into_iter().zip_eq(b.0) { let (bottom, lt) = match borrow { None => { - let lt = range.is_less_than(ctx, Existing(a_limb), Existing(b_limb), limb_bits); - (b_limb.clone(), lt) + let lt = range.is_less_than(ctx, a_limb, b_limb, limb_bits); + (b_limb, lt) } Some(borrow) => { - let b_plus_borrow = range.gate().add(ctx, Existing(b_limb), Existing(&borrow)); - let lt = range.is_less_than( - ctx, - Existing(a_limb), - Existing(&b_plus_borrow), - limb_bits + 1, - ); + let b_plus_borrow = range.gate().add(ctx, b_limb, borrow); + let lt = range.is_less_than(ctx, a_limb, b_plus_borrow, limb_bits + 1); (b_plus_borrow, lt) } }; let out_limb = { // | a | lt | 2^n | a + lt * 2^n | -1 | bottom | a + lt * 2^n - bottom - let a_with_borrow_val = - a_limb.value().zip(lt.value()).map(|(a, lt)| limb_base * lt + a); - let out_val = a_with_borrow_val.zip(bottom.value()).map(|(ac, b)| ac - b); - range.gate().assign_region_last( - ctx, - vec![ + let a_with_borrow_val = limb_base * lt.value() + a_limb.value(); + let out_val = a_with_borrow_val - bottom.value(); + ctx.assign_region_last( + [ Existing(a_limb), - Existing(<), + Existing(lt), Constant(limb_base), Witness(a_with_borrow_val), Constant(-F::one()), - Existing(&bottom), + Existing(bottom), Witness(out_val), ], - vec![(0, None), (3, None)], + [0, 3], ) }; out_limbs.push(out_limb); borrow = Some(lt); } - (OverflowInteger::construct(out_limbs, limb_bits), borrow.unwrap()) + (OverflowInteger::new(out_limbs, limb_bits), borrow.unwrap()) } // returns (a-b, underflow), where underflow is nonzero iff a < b -pub fn crt<'a, F: PrimeField>( +/// # Assumptions +/// * `a, b` are proper CRT representations of integers with the same number of limbs +pub fn crt( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &CRTInteger<'a, F>, - b: &CRTInteger<'a, F>, + ctx: &mut Context, + a: ProperCrtUint, + b: ProperCrtUint, limb_bits: usize, limb_base: F, -) -> (CRTInteger<'a, F>, AssignedValue<'a, F>) { - let (out_trunc, underflow) = - assign::(range, ctx, &a.truncation, &b.truncation, limb_bits, limb_base); - let out_native = range.gate().sub(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a - b); - (CRTInteger::construct(out_trunc, out_native, out_val), underflow) +) -> (CRTInteger, AssignedValue) { + let out_native = range.gate().sub(ctx, a.0.native, b.0.native); + let a_limbs = ProperUint(a.0.truncation.limbs); + let b_limbs = ProperUint(b.0.truncation.limbs); + let (out_trunc, underflow) = assign(range, ctx, a_limbs, b_limbs, limb_bits, limb_base); + let out_val = a.0.value - b.0.value; + (CRTInteger::new(out_trunc, out_native, out_val), underflow) } diff --git a/halo2-ecc/src/bigint/sub_no_carry.rs b/halo2-ecc/src/bigint/sub_no_carry.rs index 2226027d..4e8867c0 100644 --- a/halo2-ecc/src/bigint/sub_no_carry.rs +++ b/halo2-ecc/src/bigint/sub_no_carry.rs @@ -1,32 +1,34 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; +use itertools::Itertools; use std::cmp::max; -pub fn assign<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| gate.sub(ctx, Existing(a_limb), Existing(b_limb))) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.sub(ctx, a_limb, b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> CRTInteger<'v, F> { - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.sub(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a - b); - CRTInteger::construct(out_trunc, out_native, out_val) + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, +) -> CRTInteger { + let out_trunc = assign(gate, ctx, a.truncation, b.truncation); + let out_native = gate.sub(ctx, a.native, b.native); + let out_val = a.value - b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bn254/configs/msm_circuit.config b/halo2-ecc/src/bn254/configs/msm_circuit.config deleted file mode 100644 index 9246e19f..00000000 --- a/halo2-ecc/src/bn254/configs/msm_circuit.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":20,"num_advice":10,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/final_exp.rs b/halo2-ecc/src/bn254/final_exp.rs index e131f7d5..7959142e 100644 --- a/halo2-ecc/src/bn254/final_exp.rs +++ b/halo2-ecc/src/bn254/final_exp.rs @@ -1,79 +1,76 @@ -use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint}; +use super::{Fp12Chip, Fp2Chip, FpChip, FqPoint}; use crate::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Fq, Fq2, BN_X, FROBENIUS_COEFF_FQ12_C1}, }; use crate::{ ecc::get_naf, - fields::{fp12::mul_no_carry_w6, FieldChip, FieldExtPoint}, -}; -use halo2_base::{ - gates::GateInstructions, - utils::{fe_to_biguint, modulus, PrimeField}, - Context, - QuantumCell::{Constant, Existing}, + fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip, PrimeField}, }; +use halo2_base::{gates::GateInstructions, utils::modulus, Context, QuantumCell::Constant}; use num_bigint::BigUint; const XI_0: i64 = 9; -impl<'a, F: PrimeField> Fp12Chip<'a, F> { +impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { // computes a ** (p ** power) // only works for p = 3 (mod 4) and p = 1 (mod 6) - pub fn frobenius_map<'v>( + pub fn frobenius_map( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, + ctx: &mut Context, + a: &>::FieldPoint, power: usize, - ) -> >::FieldPoint<'v> { + ) -> >::FieldPoint { assert_eq!(modulus::() % 4u64, BigUint::from(3u64)); assert_eq!(modulus::() % 6u64, BigUint::from(1u64)); - assert_eq!(a.coeffs.len(), 12); + assert_eq!(a.0.len(), 12); let pow = power % 12; let mut out_fp2 = Vec::with_capacity(6); - let fp2_chip = Fp2Chip::::construct(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); for i in 0..6 { let frob_coeff = FROBENIUS_COEFF_FQ12_C1[pow].pow_vartime([i as u64]); // possible optimization (not implemented): load `frob_coeff` as we multiply instead of loading first // frobenius map is used infrequently so this is a small optimization - let mut a_fp2 = - FieldExtPoint::construct(vec![a.coeffs[i].clone(), a.coeffs[i + 6].clone()]); + let mut a_fp2 = FieldVector(vec![a[i].clone(), a[i + 6].clone()]); if pow % 2 != 0 { - a_fp2 = fp2_chip.conjugate(ctx, &a_fp2); + a_fp2 = fp2_chip.conjugate(ctx, a_fp2); } // if `frob_coeff` is in `Fp` and not just `Fp2`, then we can be more efficient in multiplication if frob_coeff == Fq2::one() { out_fp2.push(a_fp2); } else if frob_coeff.c1 == Fq::zero() { - let frob_fixed = fp2_chip.fp_chip.load_constant(ctx, fe_to_biguint(&frob_coeff.c0)); + let frob_fixed = fp_chip.load_constant(ctx, frob_coeff.c0); { - let out_nocarry = fp2_chip.fp_mul_no_carry(ctx, &a_fp2, &frob_fixed); - out_fp2.push(fp2_chip.carry_mod(ctx, &out_nocarry)); + let out_nocarry = fp2_chip.0.fp_mul_no_carry(ctx, a_fp2, frob_fixed); + out_fp2.push(fp2_chip.carry_mod(ctx, out_nocarry)); } } else { let frob_fixed = fp2_chip.load_constant(ctx, frob_coeff); - out_fp2.push(fp2_chip.mul(ctx, &a_fp2, &frob_fixed)); + out_fp2.push(fp2_chip.mul(ctx, a_fp2, frob_fixed)); } } let out_coeffs = out_fp2 .iter() - .map(|x| x.coeffs[0].clone()) - .chain(out_fp2.iter().map(|x| x.coeffs[1].clone())) + .map(|x| x[0].clone()) + .chain(out_fp2.iter().map(|x| x[1].clone())) .collect(); - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // exp is in little-endian - pub fn pow<'v>( + /// # Assumptions + /// * `a` is nonzero field point + pub fn pow( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, + ctx: &mut Context, + a: &>::FieldPoint, exp: Vec, - ) -> >::FieldPoint<'v> { + ) -> >::FieldPoint { let mut res = a.clone(); let mut is_started = false; let naf = get_naf(exp); @@ -86,7 +83,11 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { if z != 0 { assert!(z == 1 || z == -1); if is_started { - res = if z == 1 { self.mul(ctx, &res, a) } else { self.divide(ctx, &res, a) }; + res = if z == 1 { + self.mul(ctx, &res, a) + } else { + self.divide_unsafe(ctx, &res, a) + }; } else { assert_eq!(z, 1); is_started = true; @@ -106,14 +107,12 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { /// in = g0 + g2 w + g4 w^2 + g1 w^3 + g3 w^4 + g5 w^5 where g_i = g_i0 + g_i1 * u are elements of Fp2 /// out = Compress(in) = [ g2, g3, g4, g5 ] - pub fn cyclotomic_compress<'v>( - &self, - a: &FieldExtPoint>, - ) -> Vec>> { - let g2 = FieldExtPoint::construct(vec![a.coeffs[1].clone(), a.coeffs[1 + 6].clone()]); - let g3 = FieldExtPoint::construct(vec![a.coeffs[4].clone(), a.coeffs[4 + 6].clone()]); - let g4 = FieldExtPoint::construct(vec![a.coeffs[2].clone(), a.coeffs[2 + 6].clone()]); - let g5 = FieldExtPoint::construct(vec![a.coeffs[5].clone(), a.coeffs[5 + 6].clone()]); + pub fn cyclotomic_compress(&self, a: &FqPoint) -> Vec> { + let a = &a.0; + let g2 = FieldVector(vec![a[1].clone(), a[1 + 6].clone()]); + let g3 = FieldVector(vec![a[4].clone(), a[4 + 6].clone()]); + let g4 = FieldVector(vec![a[2].clone(), a[2 + 6].clone()]); + let g5 = FieldVector(vec![a[5].clone(), a[5 + 6].clone()]); vec![g2, g3, g4, g5] } @@ -129,16 +128,17 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { /// if g2 = 0: /// g1 = (2 g4 * g5)/g3 /// g0 = (2 g1^2 - 3 g3 * g4) * c + 1 - pub fn cyclotomic_decompress<'v>( + pub fn cyclotomic_decompress( &self, - ctx: &mut Context<'v, F>, - compression: Vec>>, - ) -> FieldExtPoint> { - let [g2, g3, g4, g5]: [FieldExtPoint>; 4] = compression.try_into().unwrap(); + ctx: &mut Context, + compression: Vec>, + ) -> FqPoint { + let [g2, g3, g4, g5]: [_; 4] = compression.try_into().unwrap(); - let fp2_chip = Fp2Chip::::construct(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); let g5_sq = fp2_chip.mul_no_carry(ctx, &g5, &g5); - let g5_sq_c = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &g5_sq); + let g5_sq_c = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, g5_sq); let g4_sq = fp2_chip.mul_no_carry(ctx, &g4, &g4); let g4_sq_3 = fp2_chip.scalar_mul_no_carry(ctx, &g4_sq, 3); @@ -148,15 +148,15 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { g1_num = fp2_chip.sub_no_carry(ctx, &g1_num, &g3_2); // can divide without carrying g1_num or g1_denom (I think) let g2_4 = fp2_chip.scalar_mul_no_carry(ctx, &g2, 4); - let g1_1 = fp2_chip.divide(ctx, &g1_num, &g2_4); + let g1_1 = fp2_chip.divide_unsafe(ctx, &g1_num, &g2_4); let g4_g5 = fp2_chip.mul_no_carry(ctx, &g4, &g5); let g1_num = fp2_chip.scalar_mul_no_carry(ctx, &g4_g5, 2); - let g1_0 = fp2_chip.divide(ctx, &g1_num, &g3); + let g1_0 = fp2_chip.divide_unsafe(ctx, &g1_num, &g3); let g2_is_zero = fp2_chip.is_zero(ctx, &g2); // resulting `g1` is already in "carried" format (witness is in `[0, p)`) - let g1 = fp2_chip.select(ctx, &g1_0, &g1_1, &g2_is_zero); + let g1 = fp2_chip.0.select(ctx, g1_0, g1_1, g2_is_zero); // share the computation of 2 g1^2 between the two cases let g1_sq = fp2_chip.mul_no_carry(ctx, &g1, &g1); @@ -166,30 +166,26 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { let g3_g4 = fp2_chip.mul_no_carry(ctx, &g3, &g4); let g3_g4_3 = fp2_chip.scalar_mul_no_carry(ctx, &g3_g4, 3); let temp = fp2_chip.add_no_carry(ctx, &g1_sq_2, &g2_g5); - let temp = fp2_chip.select(ctx, &g1_sq_2, &temp, &g2_is_zero); + let temp = fp2_chip.0.select(ctx, g1_sq_2, temp, g2_is_zero); let temp = fp2_chip.sub_no_carry(ctx, &temp, &g3_g4_3); - let mut g0 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &temp); + let mut g0 = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, temp); // compute `g0 + 1` - g0.coeffs[0].truncation.limbs[0] = fp2_chip.range().gate.add( - ctx, - Existing(&g0.coeffs[0].truncation.limbs[0]), - Constant(F::one()), - ); - g0.coeffs[0].native = - fp2_chip.range().gate.add(ctx, Existing(&g0.coeffs[0].native), Constant(F::one())); - g0.coeffs[0].truncation.max_limb_bits += 1; - g0.coeffs[0].value = g0.coeffs[0].value.as_ref().map(|v| v + 1usize); + g0[0].truncation.limbs[0] = + fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::one())); + g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::one())); + g0[0].truncation.max_limb_bits += 1; + g0[0].value += 1usize; // finally, carry g0 - g0 = fp2_chip.carry_mod(ctx, &g0); + let g0 = fp2_chip.carry_mod(ctx, g0); - let mut g0 = g0.coeffs.into_iter(); - let mut g1 = g1.coeffs.into_iter(); - let mut g2 = g2.coeffs.into_iter(); - let mut g3 = g3.coeffs.into_iter(); - let mut g4 = g4.coeffs.into_iter(); - let mut g5 = g5.coeffs.into_iter(); + let mut g0 = g0.into_iter(); + let mut g1 = g1.into_iter(); + let mut g2 = g2.into_iter(); + let mut g3 = g3.into_iter(); + let mut g4 = g4.into_iter(); + let mut g5 = g5.into_iter(); let mut out_coeffs = Vec::with_capacity(12); for _ in 0..2 { @@ -202,7 +198,7 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { g5.next().unwrap(), ]); } - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // input is [g2, g3, g4, g5] = C(g) in compressed format of `cyclotomic_compress` @@ -217,61 +213,59 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { // A_ij = (g_i + g_j)(g_i + c g_j) // B_ij = g_i g_j - pub fn cyclotomic_square<'v>( + pub fn cyclotomic_square( &self, - ctx: &mut Context<'v, F>, - compression: &[FieldExtPoint>], - ) -> Vec>> { + ctx: &mut Context, + compression: &[FqPoint], + ) -> Vec> { assert_eq!(compression.len(), 4); let g2 = &compression[0]; let g3 = &compression[1]; let g4 = &compression[2]; let g5 = &compression[3]; - let fp2_chip = Fp2Chip::::construct(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); let g2_plus_g3 = fp2_chip.add_no_carry(ctx, g2, g3); - let cg3 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, g3); + let cg3 = mul_no_carry_w6::, XI_0>(fp_chip, ctx, g3.into()); let g2_plus_cg3 = fp2_chip.add_no_carry(ctx, g2, &cg3); let a23 = fp2_chip.mul_no_carry(ctx, &g2_plus_g3, &g2_plus_cg3); let g4_plus_g5 = fp2_chip.add_no_carry(ctx, g4, g5); - let cg5 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, g5); + let cg5 = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, g5.into()); let g4_plus_cg5 = fp2_chip.add_no_carry(ctx, g4, &cg5); let a45 = fp2_chip.mul_no_carry(ctx, &g4_plus_g5, &g4_plus_cg5); let b23 = fp2_chip.mul_no_carry(ctx, g2, g3); let b45 = fp2_chip.mul_no_carry(ctx, g4, g5); - let b45_c = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &b45); + let b45_c = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, b45.clone()); let mut temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, &b45_c, g2, 3); let h2 = fp2_chip.scalar_mul_no_carry(ctx, &temp, 2); - temp = fp2_chip.add_no_carry(ctx, &b45_c, &b45); - temp = fp2_chip.sub_no_carry(ctx, &a45, &temp); - temp = fp2_chip.scalar_mul_no_carry(ctx, &temp, 3); - let h3 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g3, &temp, -2); + temp = fp2_chip.add_no_carry(ctx, b45_c, b45); + temp = fp2_chip.sub_no_carry(ctx, &a45, temp); + temp = fp2_chip.scalar_mul_no_carry(ctx, temp, 3); + let h3 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g3, temp, -2); const XI0_PLUS_1: i64 = XI_0 + 1; // (c + 1) = (XI_0 + 1) + u - temp = mul_no_carry_w6::, XI0_PLUS_1>(fp2_chip.fp_chip, ctx, &b23); - temp = fp2_chip.sub_no_carry(ctx, &a23, &temp); - temp = fp2_chip.scalar_mul_no_carry(ctx, &temp, 3); - let h4 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g4, &temp, -2); + temp = mul_no_carry_w6::, XI0_PLUS_1>(fp_chip, ctx, b23.clone()); + temp = fp2_chip.sub_no_carry(ctx, &a23, temp); + temp = fp2_chip.scalar_mul_no_carry(ctx, temp, 3); + let h4 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g4, temp, -2); - temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, &b23, g5, 3); - let h5 = fp2_chip.scalar_mul_no_carry(ctx, &temp, 2); + temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, b23, g5, 3); + let h5 = fp2_chip.scalar_mul_no_carry(ctx, temp, 2); - [h2, h3, h4, h5].iter().map(|h| fp2_chip.carry_mod(ctx, h)).collect() + [h2, h3, h4, h5].into_iter().map(|h| fp2_chip.carry_mod(ctx, h)).collect() } // exp is in little-endian - pub fn cyclotomic_pow<'v>( - &self, - ctx: &mut Context<'v, F>, - a: FieldExtPoint>, - exp: Vec, - ) -> FieldExtPoint> { + /// # Assumptions + /// * `a` is a nonzero element in the cyclotomic subgroup + pub fn cyclotomic_pow(&self, ctx: &mut Context, a: FqPoint, exp: Vec) -> FqPoint { let mut compression = self.cyclotomic_compress(&a); let mut out = None; let mut is_started = false; @@ -285,7 +279,11 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { assert!(z == 1 || z == -1); if is_started { let mut res = self.cyclotomic_decompress(ctx, compression); - res = if z == 1 { self.mul(ctx, &res, &a) } else { self.divide(ctx, &res, &a) }; + res = if z == 1 { + self.mul(ctx, &res, &a) + } else { + self.divide_unsafe(ctx, &res, &a) + }; // compression is free, so it doesn't hurt (except possibly witness generation runtime) to do it // TODO: alternatively we go from small bits to large to avoid this compression compression = self.cyclotomic_compress(&res); @@ -304,11 +302,11 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { #[allow(non_snake_case)] // use equation for (p^4 - p^2 + 1)/r in Section 5 of https://eprint.iacr.org/2008/490.pdf for BN curves - pub fn hard_part_BN<'v>( + pub fn hard_part_BN( &self, - ctx: &mut Context<'v, F>, - m: >::FieldPoint<'v>, - ) -> >::FieldPoint<'v> { + ctx: &mut Context, + m: >::FieldPoint, + ) -> >::FieldPoint { // x = BN_X // m^p @@ -322,7 +320,7 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { let mp2_mp3 = self.mul(ctx, &mp2, &mp3); let y0 = self.mul(ctx, &mp, &mp2_mp3); // y1 = 1/m, inverse = frob(6) = conjugation in cyclotomic subgroup - let y1 = self.conjugate(ctx, &m); + let y1 = self.conjugate(ctx, m.clone()); // m^x let mx = self.cyclotomic_pow(ctx, m, vec![BN_X]); @@ -337,20 +335,20 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { let y2 = self.frobenius_map(ctx, &mx2, 2); // m^{x^3} // y5 = 1/mx2 - let y5 = self.conjugate(ctx, &mx2); + let y5 = self.conjugate(ctx, mx2.clone()); let mx3 = self.cyclotomic_pow(ctx, mx2, vec![BN_X]); // (m^{x^3})^p let mx3p = self.frobenius_map(ctx, &mx3, 1); // y3 = 1/mxp - let y3 = self.conjugate(ctx, &mxp); + let y3 = self.conjugate(ctx, mxp); // y4 = 1/(mx * mx2p) let mx_mx2p = self.mul(ctx, &mx, &mx2p); - let y4 = self.conjugate(ctx, &mx_mx2p); + let y4 = self.conjugate(ctx, mx_mx2p); // y6 = 1/(mx3 * mx3p) let mx3_mx3p = self.mul(ctx, &mx3, &mx3p); - let y6 = self.conjugate(ctx, &mx3_mx3p); + let y6 = self.conjugate(ctx, mx3_mx3p); // out = y0 * y1^2 * y2^6 * y3^12 * y4^18 * y5^30 * y6^36 // we compute this using the vectorial addition chain from p. 6 of https://eprint.iacr.org/2008/490.pdf @@ -372,25 +370,26 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { } // out = in^{ (q^6 - 1)*(q^2 + 1) } - pub fn easy_part<'v>( + /// # Assumptions + /// * `a` is nonzero field point + pub fn easy_part( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, - ) -> >::FieldPoint<'v> { + ctx: &mut Context, + a: >::FieldPoint, + ) -> >::FieldPoint { // a^{q^6} = conjugate of a - let f1 = self.conjugate(ctx, a); - let f2 = self.divide(ctx, &f1, a); + let f1 = self.conjugate(ctx, a.clone()); + let f2 = self.divide_unsafe(ctx, &f1, a); let f3 = self.frobenius_map(ctx, &f2, 2); - let f = self.mul(ctx, &f3, &f2); - f + self.mul(ctx, &f3, &f2) } // out = in^{(q^12 - 1)/r} - pub fn final_exp<'v>( + pub fn final_exp( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, - ) -> >::FieldPoint<'v> { + ctx: &mut Context, + a: >::FieldPoint, + ) -> >::FieldPoint { let f0 = self.easy_part(ctx, a); let f = self.hard_part_BN(ctx, f0); f diff --git a/halo2-ecc/src/bn254/mod.rs b/halo2-ecc/src/bn254/mod.rs index 5f5db57b..deed3c4d 100644 --- a/halo2-ecc/src/bn254/mod.rs +++ b/halo2-ecc/src/bn254/mod.rs @@ -1,17 +1,16 @@ +use crate::bigint::ProperCrtUint; +use crate::fields::vector::FieldVector; +use crate::fields::{fp, fp12, fp2}; use crate::halo2_proofs::halo2curves::bn256::{Fq, Fq12, Fq2}; -use crate::{ - bigint::CRTInteger, - fields::{fp, fp12, fp2, FieldExtPoint}, -}; pub mod final_exp; pub mod pairing; -type FpChip = fp::FpConfig; -type FpPoint<'v, F> = CRTInteger<'v, F>; -type FqPoint<'v, F> = FieldExtPoint>; -type Fp2Chip<'a, F> = fp2::Fp2Chip<'a, F, FpChip, Fq2>; -type Fp12Chip<'a, F> = fp12::Fp12Chip<'a, F, FpChip, Fq12, 9>; +pub type FpChip<'range, F> = fp::FpChip<'range, F, Fq>; +pub type FpPoint = ProperCrtUint; +pub type FqPoint = FieldVector>; +pub type Fp2Chip<'chip, F> = fp2::Fp2Chip<'chip, F, FpChip<'chip, F>, Fq2>; +pub type Fp12Chip<'chip, F> = fp12::Fp12Chip<'chip, F, FpChip<'chip, F>, Fq12, 9>; #[cfg(test)] pub(crate) mod tests; diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index 2502ea48..e25f066a 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -1,21 +1,15 @@ #![allow(non_snake_case)] -use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, FqPoint}; -use crate::halo2_proofs::{ - circuit::Value, - halo2curves::bn256::{self, G1Affine, G2Affine, SIX_U_PLUS_2_NAF}, - halo2curves::bn256::{Fq, Fq2, FROBENIUS_COEFF_FQ12_C1}, - plonk::ConstraintSystem, +use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, Fq, FqPoint}; +use crate::fields::vector::FieldVector; +use crate::halo2_proofs::halo2curves::bn256::{ + G1Affine, G2Affine, FROBENIUS_COEFF_FQ12_C1, SIX_U_PLUS_2_NAF, }; use crate::{ ecc::{EcPoint, EccChip}, - fields::{fp::FpStrategy, fp12::mul_no_carry_w6}, - fields::{FieldChip, FieldExtPoint}, + fields::fp12::mul_no_carry_w6, + fields::{FieldChip, PrimeField}, }; -use halo2_base::{ - utils::{biguint_to_fe, fe_to_biguint, PrimeField}, - Context, -}; -use num_bigint::BigUint; +use halo2_base::Context; const XI_0: i64 = 9; @@ -27,34 +21,34 @@ const XI_0: i64 = 9; // line_{Psi(Q0), Psi(Q1)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals w^3 (y_1 - y_2) X + w^2 (x_2 - x_1) Y + w^5 (x_1 y_2 - x_2 y_1) =: out3 * w^3 + out2 * w^2 + out5 * w^5 where out2, out3, out5 are Fp2 points // Output is [None, None, out2, out3, None, out5] as vector of `Option`s -pub fn sparse_line_function_unequal<'a, F: PrimeField>( +pub fn sparse_line_function_unequal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - Q: (&EcPoint>, &EcPoint>), - P: &EcPoint>, -) -> Vec>> { + ctx: &mut Context, + Q: (&EcPoint>, &EcPoint>), + P: &EcPoint>, +) -> Vec>> { let (x_1, y_1) = (&Q.0.x, &Q.0.y); let (x_2, y_2) = (&Q.1.x, &Q.1.y); let (X, Y) = (&P.x, &P.y); - assert_eq!(x_1.coeffs.len(), 2); - assert_eq!(y_1.coeffs.len(), 2); - assert_eq!(x_2.coeffs.len(), 2); - assert_eq!(y_2.coeffs.len(), 2); + assert_eq!(x_1.0.len(), 2); + assert_eq!(y_1.0.len(), 2); + assert_eq!(x_2.0.len(), 2); + assert_eq!(y_2.0.len(), 2); let y1_minus_y2 = fp2_chip.sub_no_carry(ctx, y_1, y_2); let x2_minus_x1 = fp2_chip.sub_no_carry(ctx, x_2, x_1); let x1y2 = fp2_chip.mul_no_carry(ctx, x_1, y_2); let x2y1 = fp2_chip.mul_no_carry(ctx, x_2, y_1); - let out3 = fp2_chip.fp_mul_no_carry(ctx, &y1_minus_y2, X); - let out2 = fp2_chip.fp_mul_no_carry(ctx, &x2_minus_x1, Y); + let out3 = fp2_chip.0.fp_mul_no_carry(ctx, y1_minus_y2, X); + let out2 = fp2_chip.0.fp_mul_no_carry(ctx, x2_minus_x1, Y); let out5 = fp2_chip.sub_no_carry(ctx, &x1y2, &x2y1); // so far we have not "carried mod p" for any of the outputs // we do this below - vec![None, None, Some(out2), Some(out3), None, Some(out5)] - .iter() - .map(|option_nc| option_nc.as_ref().map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) + [None, None, Some(out2), Some(out3), None, Some(out5)] + .into_iter() + .map(|option_nc| option_nc.map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) .collect() } @@ -66,15 +60,15 @@ pub fn sparse_line_function_unequal<'a, F: PrimeField>( // line_{Psi(Q), Psi(Q)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals (3x^3 - 2y^2)(XI_0 + u) + w^4 (-3 x^2 * Q.x) + w^3 (2 y * Q.y) =: out0 + out4 * w^4 + out3 * w^3 where out0, out3, out4 are Fp2 points // Output is [out0, None, None, out3, out4, None] as vector of `Option`s -pub fn sparse_line_function_equal<'a, F: PrimeField>( +pub fn sparse_line_function_equal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - Q: &EcPoint>, - P: &EcPoint>, -) -> Vec>> { + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, +) -> Vec>> { let (x, y) = (&Q.x, &Q.y); - assert_eq!(x.coeffs.len(), 2); - assert_eq!(y.coeffs.len(), 2); + assert_eq!(x.0.len(), 2); + assert_eq!(y.0.len(), 2); let x_sq = fp2_chip.mul(ctx, x, x); @@ -83,38 +77,38 @@ pub fn sparse_line_function_equal<'a, F: PrimeField>( let y_sq = fp2_chip.mul_no_carry(ctx, y, y); let two_y_sq = fp2_chip.scalar_mul_no_carry(ctx, &y_sq, 2); let out0_left = fp2_chip.sub_no_carry(ctx, &three_x_cu, &two_y_sq); - let out0 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &out0_left); + let out0 = mul_no_carry_w6::<_, _, XI_0>(fp2_chip.fp_chip(), ctx, out0_left); - let x_sq_Px = fp2_chip.fp_mul_no_carry(ctx, &x_sq, &P.x); - let out4 = fp2_chip.scalar_mul_no_carry(ctx, &x_sq_Px, -3); + let x_sq_Px = fp2_chip.0.fp_mul_no_carry(ctx, x_sq, &P.x); + let out4 = fp2_chip.scalar_mul_no_carry(ctx, x_sq_Px, -3); - let y_Py = fp2_chip.fp_mul_no_carry(ctx, y, &P.y); + let y_Py = fp2_chip.0.fp_mul_no_carry(ctx, y.clone(), &P.y); let out3 = fp2_chip.scalar_mul_no_carry(ctx, &y_Py, 2); // so far we have not "carried mod p" for any of the outputs // we do this below - vec![Some(out0), None, None, Some(out3), Some(out4), None] - .iter() - .map(|option_nc| option_nc.as_ref().map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) + [Some(out0), None, None, Some(out3), Some(out4), None] + .into_iter() + .map(|option_nc| option_nc.map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) .collect() } // multiply Fp12 point `a` with Fp12 point `b` where `b` is len 6 vector of Fp2 points, where some are `None` to represent zero. // Assumes `b` is not vector of all `None`s -pub fn sparse_fp12_multiply<'a, F: PrimeField>( +pub fn sparse_fp12_multiply( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - a: &FqPoint<'a, F>, - b_fp2_coeffs: &Vec>>, -) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); + ctx: &mut Context, + a: &FqPoint, + b_fp2_coeffs: &[Option>], +) -> FqPoint { + assert_eq!(a.0.len(), 12); assert_eq!(b_fp2_coeffs.len(), 6); let mut a_fp2_coeffs = Vec::with_capacity(6); for i in 0..6 { - a_fp2_coeffs.push(FqPoint::construct(vec![a.coeffs[i].clone(), a.coeffs[i + 6].clone()])); + a_fp2_coeffs.push(FieldVector(vec![a[i].clone(), a[i + 6].clone()])); } // a * b as element of Fp2[w] without evaluating w^6 = (XI_0 + u) - let mut prod_2d: Vec>>> = vec![None; 11]; + let mut prod_2d = vec![None; 11]; for i in 0..6 { for j in 0..6 { prod_2d[i + j] = @@ -139,7 +133,7 @@ pub fn sparse_fp12_multiply<'a, F: PrimeField>( let prod_nocarry = if i != 5 { let eval_w6 = prod_2d[i + 6] .as_ref() - .map(|a| mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, a)); + .map(|a| mul_no_carry_w6::<_, _, XI_0>(fp2_chip.fp_chip(), ctx, a.clone())); match (prod_2d[i].as_ref(), eval_w6) { (None, b) => b.unwrap(), // Our current use cases of 235 and 034 sparse multiplication always result in non-None value (Some(a), None) => a.clone(), @@ -148,18 +142,18 @@ pub fn sparse_fp12_multiply<'a, F: PrimeField>( } else { prod_2d[i].clone().unwrap() }; - let prod = fp2_chip.carry_mod(ctx, &prod_nocarry); + let prod = fp2_chip.carry_mod(ctx, prod_nocarry); out_fp2.push(prod); } let mut out_coeffs = Vec::with_capacity(12); for fp2_coeff in &out_fp2 { - out_coeffs.push(fp2_coeff.coeffs[0].clone()); + out_coeffs.push(fp2_coeff[0].clone()); } for fp2_coeff in &out_fp2 { - out_coeffs.push(fp2_coeff.coeffs[1].clone()); + out_coeffs.push(fp2_coeff[1].clone()); } - FqPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // Input: @@ -168,13 +162,13 @@ pub fn sparse_fp12_multiply<'a, F: PrimeField>( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q0), Psi(Q1)}(P) as Fp12 point -pub fn fp12_multiply_with_line_unequal<'a, F: PrimeField>( +pub fn fp12_multiply_with_line_unequal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - g: &FqPoint<'a, F>, - Q: (&EcPoint>, &EcPoint>), - P: &EcPoint>, -) -> FqPoint<'a, F> { + ctx: &mut Context, + g: &FqPoint, + Q: (&EcPoint>, &EcPoint>), + P: &EcPoint>, +) -> FqPoint { let line = sparse_line_function_unequal::(fp2_chip, ctx, Q, P); sparse_fp12_multiply::(fp2_chip, ctx, g, &line) } @@ -185,13 +179,13 @@ pub fn fp12_multiply_with_line_unequal<'a, F: PrimeField>( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q), Psi(Q)}(P) as Fp12 point -pub fn fp12_multiply_with_line_equal<'a, F: PrimeField>( +pub fn fp12_multiply_with_line_equal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - g: &FqPoint<'a, F>, - Q: &EcPoint>, - P: &EcPoint>, -) -> FqPoint<'a, F> { + ctx: &mut Context, + g: &FqPoint, + Q: &EcPoint>, + P: &EcPoint>, +) -> FqPoint { let line = sparse_line_function_equal::(fp2_chip, ctx, Q, P); sparse_fp12_multiply::(fp2_chip, ctx, g, &line) } @@ -214,20 +208,20 @@ pub fn fp12_multiply_with_line_equal<'a, F: PrimeField>( // - `0 <= loop_count < r` and `loop_count < p` (to avoid [loop_count]Q' = Frob_p(Q')) // - x^3 + b = 0 has no solution in Fp2, i.e., the y-coordinate of Q cannot be 0. -pub fn miller_loop_BN<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - Q: &EcPoint>, - P: &EcPoint>, +pub fn miller_loop_BN( + ecc_chip: &EccChip>, + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, pseudo_binary_encoding: &[i8], -) -> FqPoint<'b, F> { +) -> FqPoint { let mut i = pseudo_binary_encoding.len() - 1; while pseudo_binary_encoding[i] == 0 { i -= 1; } let last_index = i; - let neg_Q = ecc_chip.negate(ctx, Q); + let neg_Q = ecc_chip.negate(ctx, Q.clone()); assert!(pseudo_binary_encoding[i] == 1 || pseudo_binary_encoding[i] == -1); let mut R = if pseudo_binary_encoding[i] == 1 { Q.clone() } else { neg_Q.clone() }; i -= 1; @@ -236,28 +230,29 @@ pub fn miller_loop_BN<'a, 'b, F: PrimeField>( let sparse_f = sparse_line_function_equal::(ecc_chip.field_chip(), ctx, &R, P); assert_eq!(sparse_f.len(), 6); - let zero_fp = ecc_chip.field_chip.fp_chip.load_constant(ctx, BigUint::from(0u64)); + let fp_chip = ecc_chip.field_chip.fp_chip(); + let zero_fp = fp_chip.load_constant(ctx, Fq::zero()); let mut f_coeffs = Vec::with_capacity(12); for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[0].clone()); + f_coeffs.push(fp2_point[0].clone()); } else { f_coeffs.push(zero_fp.clone()); } } for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[1].clone()); + f_coeffs.push(fp2_point[1].clone()); } else { f_coeffs.push(zero_fp.clone()); } } - let mut f = FqPoint::construct(f_coeffs); + let mut f = FieldVector(f_coeffs); + let fp12_chip = Fp12Chip::::new(fp_chip); loop { if i != last_index - 1 { - let fp12_chip = Fp12Chip::::construct(ecc_chip.field_chip.fp_chip); let f_sq = fp12_chip.mul(ctx, &f, &f); f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f_sq, &R, P); } @@ -299,12 +294,12 @@ pub fn miller_loop_BN<'a, 'b, F: PrimeField>( // let pairs = [(a_i, b_i)], a_i in G_1, b_i in G_2 // output is Prod_i e'(a_i, b_i), where e'(a_i, b_i) is the output of `miller_loop_BN(b_i, a_i)` -pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - pairs: Vec<(&EcPoint>, &EcPoint>)>, +pub fn multi_miller_loop_BN( + ecc_chip: &EccChip>, + ctx: &mut Context, + pairs: Vec<(&EcPoint>, &EcPoint>)>, pseudo_binary_encoding: &[i8], -) -> FqPoint<'b, F> { +) -> FqPoint { let mut i = pseudo_binary_encoding.len() - 1; while pseudo_binary_encoding[i] == 0 { i -= 1; @@ -314,29 +309,30 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( let neg_b = pairs.iter().map(|pair| ecc_chip.negate(ctx, pair.1)).collect::>(); + let fp_chip = ecc_chip.field_chip.fp_chip(); // initialize the first line function into Fq12 point let mut f = { let sparse_f = sparse_line_function_equal::(ecc_chip.field_chip(), ctx, pairs[0].1, pairs[0].0); assert_eq!(sparse_f.len(), 6); - let zero_fp = ecc_chip.field_chip.fp_chip.load_constant(ctx, BigUint::from(0u64)); + let zero_fp = fp_chip.load_constant(ctx, Fq::zero()); let mut f_coeffs = Vec::with_capacity(12); for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[0].clone()); + f_coeffs.push(fp2_point[0].clone()); } else { f_coeffs.push(zero_fp.clone()); } } for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[1].clone()); + f_coeffs.push(fp2_point[1].clone()); } else { f_coeffs.push(zero_fp.clone()); } } - FqPoint::construct(f_coeffs) + FieldVector(f_coeffs) }; for &(a, b) in pairs.iter().skip(1) { f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f, b, a); @@ -344,7 +340,7 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( i -= 1; let mut r = pairs.iter().map(|pair| pair.1.clone()).collect::>(); - let fp12_chip = Fp12Chip::::construct(ecc_chip.field_chip.fp_chip); + let fp12_chip = Fp12Chip::::new(fp_chip); loop { if i != last_index - 1 { f = fp12_chip.mul(ctx, &f, &f); @@ -353,7 +349,7 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( } } for r in r.iter_mut() { - *r = ecc_chip.double(ctx, r); + *r = ecc_chip.double(ctx, r.clone()); } assert!(pseudo_binary_encoding[i] <= 1 && pseudo_binary_encoding[i] >= -1); @@ -367,7 +363,7 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( (r, sign_b), a, ); - *r = ecc_chip.add_unequal(ctx, r, sign_b, false); + *r = ecc_chip.add_unequal(ctx, r.clone(), sign_b, false); } } if i == 0 { @@ -384,11 +380,11 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( let c3 = ecc_chip.field_chip.load_constant(ctx, c3); // finish multiplying remaining line functions outside the loop - for (r, &(a, b)) in r.iter_mut().zip(pairs.iter()) { - let b_1 = twisted_frobenius::(ecc_chip, ctx, b, &c2, &c3); - let neg_b_2 = neg_twisted_frobenius::(ecc_chip, ctx, &b_1, &c2, &c3); - f = fp12_multiply_with_line_unequal::(ecc_chip.field_chip(), ctx, &f, (r, &b_1), a); - *r = ecc_chip.add_unequal(ctx, r, &b_1, false); + for (r, (a, b)) in r.iter_mut().zip(pairs) { + let b_1 = twisted_frobenius(ecc_chip, ctx, b, &c2, &c3); + let neg_b_2 = neg_twisted_frobenius(ecc_chip, ctx, &b_1, &c2, &c3); + f = fp12_multiply_with_line_unequal(ecc_chip.field_chip(), ctx, &f, (r, &b_1), a); + *r = ecc_chip.add_unequal(ctx, r.clone(), b_1, false); f = fp12_multiply_with_line_unequal::(ecc_chip.field_chip(), ctx, &f, (r, &neg_b_2), a); } f @@ -401,21 +397,24 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( // - coeff[1][2], coeff[1][3] as assigned cells: this is an optimization to avoid loading new constants // Output: // - (coeff[1][2] * x^p, coeff[1][3] * y^p) point in E(Fp2) -pub fn twisted_frobenius<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - Q: &EcPoint>, - c2: &FqPoint<'b, F>, - c3: &FqPoint<'b, F>, -) -> EcPoint> { - assert_eq!(c2.coeffs.len(), 2); - assert_eq!(c3.coeffs.len(), 2); - - let frob_x = ecc_chip.field_chip.conjugate(ctx, &Q.x); - let frob_y = ecc_chip.field_chip.conjugate(ctx, &Q.y); - let out_x = ecc_chip.field_chip.mul(ctx, c2, &frob_x); - let out_y = ecc_chip.field_chip.mul(ctx, c3, &frob_y); - EcPoint::construct(out_x, out_y) +pub fn twisted_frobenius( + ecc_chip: &EccChip>, + ctx: &mut Context, + Q: impl Into>>, + c2: impl Into>, + c3: impl Into>, +) -> EcPoint> { + let Q = Q.into(); + let c2 = c2.into(); + let c3 = c3.into(); + assert_eq!(c2.0.len(), 2); + assert_eq!(c3.0.len(), 2); + + let frob_x = ecc_chip.field_chip.conjugate(ctx, Q.x); + let frob_y = ecc_chip.field_chip.conjugate(ctx, Q.y); + let out_x = ecc_chip.field_chip.mul(ctx, c2, frob_x); + let out_y = ecc_chip.field_chip.mul(ctx, c3, frob_y); + EcPoint::new(out_x, out_y) } // Frobenius coefficient coeff[1][j] = ((9+u)^{(p-1)/6})^j @@ -424,98 +423,63 @@ pub fn twisted_frobenius<'a, 'b, F: PrimeField>( // - Q = (x, y) point in E(Fp2) // Output: // - (coeff[1][2] * x^p, coeff[1][3] * -y^p) point in E(Fp2) -pub fn neg_twisted_frobenius<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - Q: &EcPoint>, - c2: &FqPoint<'b, F>, - c3: &FqPoint<'b, F>, -) -> EcPoint> { - assert_eq!(c2.coeffs.len(), 2); - assert_eq!(c3.coeffs.len(), 2); - - let frob_x = ecc_chip.field_chip.conjugate(ctx, &Q.x); - let neg_frob_y = ecc_chip.field_chip.neg_conjugate(ctx, &Q.y); - let out_x = ecc_chip.field_chip.mul(ctx, c2, &frob_x); - let out_y = ecc_chip.field_chip.mul(ctx, c3, &neg_frob_y); - EcPoint::construct(out_x, out_y) +pub fn neg_twisted_frobenius( + ecc_chip: &EccChip>, + ctx: &mut Context, + Q: impl Into>>, + c2: impl Into>, + c3: impl Into>, +) -> EcPoint> { + let Q = Q.into(); + let c2 = c2.into(); + let c3 = c3.into(); + assert_eq!(c2.0.len(), 2); + assert_eq!(c3.0.len(), 2); + + let frob_x = ecc_chip.field_chip.conjugate(ctx, Q.x); + let neg_frob_y = ecc_chip.field_chip.neg_conjugate(ctx, Q.y); + let out_x = ecc_chip.field_chip.mul(ctx, c2, frob_x); + let out_y = ecc_chip.field_chip.mul(ctx, c3, neg_frob_y); + EcPoint::new(out_x, out_y) } // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows -pub struct PairingChip<'a, F: PrimeField> { - pub fp_chip: &'a FpChip, +pub struct PairingChip<'chip, F: PrimeField> { + pub fp_chip: &'chip FpChip<'chip, F>, } -impl<'a, F: PrimeField> PairingChip<'a, F> { - pub fn construct(fp_chip: &'a FpChip) -> Self { +impl<'chip, F: PrimeField> PairingChip<'chip, F> { + pub fn new(fp_chip: &'chip FpChip) -> Self { Self { fp_chip } } - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - context_id: usize, - k: usize, - ) -> FpChip { - FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - halo2_base::utils::modulus::(), - context_id, - k, - ) - } - - pub fn load_private_g1<'v>( + pub fn load_private_g1_unchecked( &self, - ctx: &mut Context<'_, F>, - point: Value, - ) -> EcPoint> { - // go from pse/pairing::bn256::Fq to forked Fq - let convert_fp = |x: bn256::Fq| biguint_to_fe(&fe_to_biguint(&x)); - let g1_chip = EccChip::construct(self.fp_chip.clone()); - g1_chip - .load_private(ctx, (point.map(|pt| convert_fp(pt.x)), point.map(|pt| convert_fp(pt.y)))) + ctx: &mut Context, + point: G1Affine, + ) -> EcPoint> { + let g1_chip = EccChip::new(self.fp_chip); + g1_chip.load_private_unchecked(ctx, (point.x, point.y)) } - pub fn load_private_g2<'v>( + pub fn load_private_g2_unchecked( &self, - ctx: &mut Context<'_, F>, - point: Value, - ) -> EcPoint>> { - let fp2_chip = Fp2Chip::::construct(self.fp_chip); - let g2_chip = EccChip::construct(fp2_chip); - // go from pse/pairing::bn256::Fq2 to forked public Fq2 - let convert_fp2 = |c0: bn256::Fq, c1: bn256::Fq| Fq2 { - c0: biguint_to_fe(&fe_to_biguint(&c0)), - c1: biguint_to_fe(&fe_to_biguint(&c1)), - }; - let x = point.map(|pt| convert_fp2(pt.x.c0, pt.x.c1)); - let y = point.map(|pt| convert_fp2(pt.y.c0, pt.y.c1)); - - g2_chip.load_private(ctx, (x, y)) + ctx: &mut Context, + point: G2Affine, + ) -> EcPoint> { + let fp2_chip = Fp2Chip::new(self.fp_chip); + let g2_chip = EccChip::new(&fp2_chip); + g2_chip.load_private_unchecked(ctx, (point.x, point.y)) } - pub fn miller_loop<'v>( + pub fn miller_loop( &self, - ctx: &mut Context<'v, F>, - Q: &EcPoint>, - P: &EcPoint>, - ) -> FqPoint<'v, F> { - let fp2_chip = Fp2Chip::::construct(self.fp_chip); - let g2_chip = EccChip::construct(fp2_chip); + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, + ) -> FqPoint { + let fp2_chip = Fp2Chip::::new(self.fp_chip); + let g2_chip = EccChip::new(&fp2_chip); miller_loop_BN::( &g2_chip, ctx, @@ -525,13 +489,13 @@ impl<'a, F: PrimeField> PairingChip<'a, F> { ) } - pub fn multi_miller_loop<'v>( + pub fn multi_miller_loop( &self, - ctx: &mut Context<'v, F>, - pairs: Vec<(&EcPoint>, &EcPoint>)>, - ) -> FqPoint<'v, F> { - let fp2_chip = Fp2Chip::::construct(self.fp_chip); - let g2_chip = EccChip::construct(fp2_chip); + ctx: &mut Context, + pairs: Vec<(&EcPoint>, &EcPoint>)>, + ) -> FqPoint { + let fp2_chip = Fp2Chip::::new(self.fp_chip); + let g2_chip = EccChip::new(&fp2_chip); multi_miller_loop_BN::( &g2_chip, ctx, @@ -540,21 +504,21 @@ impl<'a, F: PrimeField> PairingChip<'a, F> { ) } - pub fn final_exp<'v>(&self, ctx: &mut Context<'v, F>, f: &FqPoint<'v, F>) -> FqPoint<'v, F> { - let fp12_chip = Fp12Chip::::construct(self.fp_chip); + pub fn final_exp(&self, ctx: &mut Context, f: FqPoint) -> FqPoint { + let fp12_chip = Fp12Chip::::new(self.fp_chip); fp12_chip.final_exp(ctx, f) } // optimal Ate pairing - pub fn pairing<'v>( + pub fn pairing( &self, - ctx: &mut Context<'v, F>, - Q: &EcPoint>, - P: &EcPoint>, - ) -> FqPoint<'v, F> { + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, + ) -> FqPoint { let f0 = self.miller_loop(ctx, Q, P); - let fp12_chip = Fp12Chip::::construct(self.fp_chip); + let fp12_chip = Fp12Chip::::new(self.fp_chip); // final_exp implemented in final_exp module - fp12_chip.final_exp(ctx, &f0) + fp12_chip.final_exp(ctx, f0) } } diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index 08dc9fb1..a902ce3c 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -1,15 +1,19 @@ -use std::env::set_var; use std::fs; -use std::{env::var, fs::File}; +use std::fs::File; +use std::io::{BufRead, BufReader}; use super::*; -use crate::fields::FieldChip; -use crate::halo2_proofs::halo2curves::{bn256::G2Affine, FieldExt}; +use crate::fields::{FieldChip, FpStrategy}; +use crate::halo2_proofs::halo2curves::bn256::G2Affine; use group::cofactor::CofactorCurveAffine; -use halo2_base::SKIP_FIRST_PASS; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::fs::gen_srs; +use halo2_base::Context; +use itertools::Itertools; use rand_core::OsRng; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { strategy: FpStrategy, degree: u32, @@ -22,270 +26,96 @@ struct CircuitParams { batch_size: usize, } -#[derive(Clone, Debug)] -struct Config { - fp_chip: FpChip, - batch_size: usize, -} +fn g2_add_test(ctx: &mut Context, params: CircuitParams, _points: Vec) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp2_chip = Fp2Chip::::new(&fp_chip); + let g2_chip = EccChip::new(&fp2_chip); -impl Config { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - Self { fp_chip, batch_size } - } -} + let points = + _points.iter().map(|pt| g2_chip.assign_point_unchecked(ctx, *pt)).collect::>(); -struct EcAddCircuit { - points: Vec>, - batch_size: usize, - _marker: PhantomData, -} + let acc = g2_chip.sum::(ctx, points); -impl Default for EcAddCircuit { - fn default() -> Self { - Self { points: vec![None; 100], batch_size: 100, _marker: PhantomData } - } -} - -impl Circuit for EcAddCircuit { - type Config = Config; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - points: vec![None; self.batch_size], - batch_size: self.batch_size, - _marker: PhantomData, - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("EC_ADD_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/ec_add_circuit.config".to_string()); - let params: CircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - Config::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - BigUint::from_str_radix(&Fq::MODULUS[2..], 16).unwrap(), - params.batch_size, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.points.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - let fp2_chip = Fp2Chip::::construct(&config.fp_chip); - let g2_chip = EccChip::construct(fp2_chip.clone()); - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "G2 add", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let display = self.points[0].is_some(); - let points = self - .points - .iter() - .cloned() - .map(|pt| { - g2_chip.assign_point(ctx, pt.map(Value::known).unwrap_or(Value::unknown())) - }) - .collect::>(); - - let acc = g2_chip.sum::(ctx, points.iter()); - - #[cfg(feature = "display")] - if display { - let answer = self - .points - .iter() - .fold(G2Affine::identity(), |a, b| (a + b.unwrap()).to_affine()); - let x = fp2_chip.get_assigned_value(&acc.x); - let y = fp2_chip.get_assigned_value(&acc.y); - x.map(|x| assert_eq!(answer.x, x)); - y.map(|y| assert_eq!(answer.y, y)); - } - - config.fp_chip.finalize(ctx); - - #[cfg(feature = "display")] - if display { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } + let answer = _points.iter().fold(G2Affine::identity(), |a, b| (a + b).to_affine()); + let x = fp2_chip.get_assigned_value(&acc.x.into()); + let y = fp2_chip.get_assigned_value(&acc.y.into()); + assert_eq!(answer.x, x); + assert_eq!(answer.y, y); } #[test] fn test_ec_add() { - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/ec_add_circuit.config"); - set_var("EC_ADD_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .unwrap_or_else(|_| panic!("{folder:?} file should exist")); - let params: CircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; + let path = "configs/bn254/ec_add_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); - let mut rng = OsRng; - - let mut points = Vec::new(); - for _ in 0..params.batch_size { - let new_pt = Some(G2Affine::random(&mut rng)); - points.push(new_pt); - } + let k = params.degree; + let points = (0..params.batch_size).map(|_| G2Affine::random(OsRng)).collect_vec(); - let circuit = - EcAddCircuit:: { points, batch_size: params.batch_size, _marker: PhantomData }; + let mut builder = GateThreadBuilder::::mock(); + g2_add_test(builder.main(0), params, points); - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + builder.config(k as usize, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); } #[test] fn bench_ec_add() -> Result<(), Box> { - use std::io::BufRead; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_ec_add.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); + let config_path = "configs/bn254/bench_ec_add.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); - folder.push("results/ec_add_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let results_path = "results/bn254/ec_add_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } + fs::create_dir_all("data").unwrap(); - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); let mut rng = OsRng; - { - folder.pop(); - folder.push("configs/ec_add_circuit.tmp.config"); - set_var("EC_ADD_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } let params_time = start_timer!(|| "Params construction"); - let params = { - params_folder.push(format!("kzg_bn254_{}.srs", bench_params.degree)); - let fd = std::fs::File::open(params_folder.as_path()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - println!("Creating new params file..."); - let mut f = std::fs::File::create(params_folder.as_path())?; - let params = ParamsKZG::::setup(bench_params.degree, &mut rng); - params.write(&mut f).unwrap(); - params - }; - params_folder.pop(); - params - }; + let params = gen_srs(k); end_timer!(params_time); - let circuit = EcAddCircuit:: { - points: vec![None; bench_params.batch_size], - batch_size: bench_params.batch_size, - _marker: PhantomData, + let start0 = start_timer!(|| "Witness generation for empty circuit"); + let circuit = { + let points = vec![G2Affine::generator(); bench_params.batch_size]; + let mut builder = GateThreadBuilder::::keygen(); + g2_add_test(builder.main(0), bench_params, points); + builder.config(k as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) }; + end_timer!(start0); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let mut points = Vec::new(); - for _ in 0..bench_params.batch_size { - let new_pt = Some(G2Affine::random(&mut rng)); - points.push(new_pt); - } - - let proof_circuit = EcAddCircuit:: { - points, - batch_size: bench_params.batch_size, - _marker: PhantomData, - }; + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof + let points = (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(); let proof_time = start_timer!(|| "Proving time"); + let proof_circuit = { + let mut builder = GateThreadBuilder::::prover(); + g2_add_test(builder.main(0), bench_params, points); + builder.config(k as usize, Some(20)); + RangeCircuitBuilder::prover(builder, break_points) + }; let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -299,8 +129,8 @@ fn bench_ec_add() -> Result<(), Box> { end_timer!(proof_time); let proof_size = { - folder.push(format!( - "ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -309,27 +139,27 @@ fn bench_ec_add() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("EC_ADD_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index c7239d9d..0283f672 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -1,13 +1,29 @@ -use std::{env::var, fs::File}; +use std::{ + fs::{self, File}, + io::{BufRead, BufReader}, +}; -#[allow(unused_imports)] -use crate::ecc::fixed_base::FixedEcPoint; +use crate::fields::{FpStrategy, PrimeField}; use super::*; -use halo2_base::{halo2_proofs::halo2curves::bn256::G1, SKIP_FIRST_PASS}; - -#[derive(Serialize, Deserialize, Debug)] -struct MSMCircuitParams { +#[allow(unused_imports)] +use ff::PrimeField as _; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::halo2curves::bn256::G1, + utils::fs::gen_srs, +}; +use itertools::Itertools; +use rand_core::OsRng; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct FixedMSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -21,274 +37,128 @@ struct MSMCircuitParams { clump_factor: usize, } -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - batch_size: usize, - _radix: usize, - _clump_factor: usize, -} - -impl MSMConfig { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - _radix: usize, - _clump_factor: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - MSMConfig { fp_chip, batch_size, _radix, _clump_factor } - } -} - -struct MSMCircuit { +fn fixed_base_msm_test( + builder: &mut GateThreadBuilder, + params: FixedMSMCircuitParams, bases: Vec, - scalars: Vec>, - _marker: PhantomData, -} - -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: self.bases.clone(), - scalars: vec![None; self.scalars.len()], - _marker: PhantomData, - } + scalars: Vec, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let scalars_assigned = scalars + .iter() + .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) + .collect::>(); + + let msm = ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + + let mut elts: Vec = Vec::new(); + for (base, scalar) in bases.iter().zip(scalars.iter()) { + elts.push(base * scalar); } + let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("FIXED_MSM_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/fixed_msm_circuit.config".to_string()); - let params: MSMCircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - MSMConfig::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - BigUint::from_str_radix(&Fq::MODULUS[2..], 16).unwrap(), - params.batch_size, - params.radix, - params.clump_factor, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.scalars.len()); - assert_eq!(config.batch_size, self.bases.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "fixed base msm", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let witness_time = start_timer!(|| "Witness generation"); - - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config - .fp_chip - .range - .gate - .assign_witnesses(ctx, vec![scalar.map_or(Value::unknown(), Value::known)]); - scalars_assigned.push(assignment); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - - // baseline - /* - let msm = { - let sm = self.bases.iter().zip(scalars_assigned.iter()).map(|(base, scalar)| - ecc_chip.fixed_base_scalar_mult(ctx, &FixedEcPoint::::from_g1(base, config.fp_chip.num_limbs, config.fp_chip.limb_bits), scalar, Fr::NUM_BITS as usize, 4)).collect::>(); - ecc_chip.sum::(ctx, sm.iter()) - }; - */ - - let msm = ecc_chip.fixed_base_msm::( - ctx, - &self.bases, - &scalars_assigned, - Fr::NUM_BITS as usize, - config._radix, - config._clump_factor, - ); - - config.fp_chip.finalize(ctx); - end_timer!(witness_time); + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} - #[cfg(feature = "display")] - if self.scalars[0].is_some() { - let mut elts: Vec = Vec::new(); - for (base, scalar) in self.bases.iter().zip(&self.scalars) { - elts.push(base * biguint_to_fe::(&fe_to_biguint(&scalar.unwrap()))); - } - let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); +fn random_fixed_base_msm_circuit( + params: FixedMSMCircuitParams, + bases: Vec, // bases are fixed in vkey so don't randomly generate + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; - let msm_x = value_to_option(msm.x.value).unwrap(); - let msm_y = value_to_option(msm.y.value).unwrap(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x).into()); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y).into()); - } + let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + fixed_base_msm_test(&mut builder, params, bases, scalars); - #[cfg(feature = "display")] - if self.scalars[0].is_some() { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } -#[cfg(test)] #[test] fn test_fixed_base_msm() { - use std::env::set_var; - - use crate::halo2_proofs::arithmetic::Field; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/fixed_msm_circuit.config"); - set_var("FIXED_MSM_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/bn254/configs/fixed_msm_circuit.config file should exist"); - let params: MSMCircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; - - let mut rng = rand::thread_rng(); - - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..params.batch_size { - bases.push(G1Affine::random(&mut rng)); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = MSMCircuit:: { bases, scalars, _marker: PhantomData }; + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); +#[test] +fn test_fixed_msm_minus_1() { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let base = G1Affine::random(OsRng); + let k = params.degree as usize; + let mut builder = GateThreadBuilder::mock(); + fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } -#[cfg(test)] #[test] fn bench_fixed_base_msm() -> Result<(), Box> { - use std::{ - env::{set_var, var}, - fs, - io::BufRead, - }; - - use halo2_base::utils::fs::gen_srs; - use rand_core::OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_fixed_msm.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); - - folder.push("results/fixed_msm_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let config_path = "configs/bn254/bench_fixed_msm.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); + fs::create_dir_all("data").unwrap(); + + let results_path = "results/bn254/fixed_msm_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } - - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { - let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); - let mut rng = OsRng; - - { - folder.pop(); - folder.push("configs/fixed_msm_circuit.tmp.config"); - set_var("FIXED_MSM_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params = gen_srs(bench_params.degree); + let bench_params: FixedMSMCircuitParams = + serde_json::from_str(line.unwrap().as_str()).unwrap(); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); + let rng = OsRng; + let params = gen_srs(k); println!("{bench_params:?}"); - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _idx in 0..bench_params.batch_size { - bases.push(G1Affine::random(&mut rng)); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - let circuit = - MSMCircuit:: { bases, scalars: vec![None; scalars.len()], _marker: PhantomData }; + let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit( + bench_params, + bases.clone(), + CircuitBuilderStage::Keygen, + None, + ); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -298,9 +168,16 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let circuit = MSMCircuit:: { scalars, ..circuit }; + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = random_fixed_base_msm_circuit( + bench_params, + bases, + CircuitBuilderStage::Prover, + Some(break_points), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -308,14 +185,15 @@ fn bench_fixed_base_msm() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - MSMCircuit, + _, >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/ + msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -324,27 +202,27 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("FIXED_MSM_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index 763bd127..172300a1 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,36 +1,46 @@ #![allow(non_snake_case)] -use ark_std::{end_timer, start_timer}; -use group::Curve; -use serde::{Deserialize, Serialize}; -use std::io::Write; -use std::marker::PhantomData; - use super::pairing::PairingChip; use super::*; -use crate::halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::{Params, ParamsProver}, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, +use crate::{ecc::EccChip, fields::PrimeField}; +use crate::{ + fields::FpStrategy, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, + plonk::*, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255}, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use crate::{ecc::EccChip, fields::fp::FpStrategy}; -use halo2_base::{ - gates::GateInstructions, - utils::{biguint_to_fe, fe_to_biguint, value_to_option, PrimeField}, - QuantumCell::Witness, }; -use num_bigint::BigUint; -use num_traits::Num; +use ark_std::{end_timer, start_timer}; +use group::Curve; +use halo2_base::utils::fe_to_biguint; +use serde::{Deserialize, Serialize}; +use std::io::Write; pub mod ec_add; pub mod fixed_base_msm; pub mod msm; +pub mod msm_sum_infinity; +pub mod msm_sum_infinity_fixed_base; pub mod pairing; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 4195c0f8..cfc7d40f 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,11 +1,24 @@ -use std::{env::var, fs::File}; - -use crate::halo2_proofs::arithmetic::FieldExt; -use halo2_base::SKIP_FIRST_PASS; +use crate::fields::FpStrategy; +use ff::{Field, PrimeField}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + utils::fs::gen_srs, +}; +use rand_core::OsRng; +use std::{ + fs::{self, File}, + io::{BufRead, BufReader}, +}; use super::*; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct MSMCircuitParams { strategy: FpStrategy, degree: u32, @@ -19,346 +32,131 @@ struct MSMCircuitParams { window_bits: usize, } -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - batch_size: usize, +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -impl MSMConfig { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - window_bits: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - MSMConfig { fp_chip, batch_size, window_bits } - } -} - -struct MSMCircuit { - bases: Vec>, - scalars: Vec>, - batch_size: usize, - _marker: PhantomData, -} - -impl Default for MSMCircuit { - fn default() -> Self { - Self { - bases: vec![None; 10], - scalars: vec![None; 10], - batch_size: 10, - _marker: PhantomData, +fn random_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let (bases, scalars): (Vec<_>, Vec<_>) = + (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) } - } -} - -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: vec![None; self.batch_size], - scalars: vec![None; self.batch_size], - batch_size: self.batch_size, - _marker: PhantomData, + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("MSM_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/msm_circuit.config".to_string()); - let params: MSMCircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - MSMConfig::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - BigUint::from_str_radix(&Fq::MODULUS[2..], 16).unwrap(), - params.batch_size, - params.window_bits, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.scalars.len()); - assert_eq!(config.batch_size, self.bases.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "MSM", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let witness_time = start_timer!(|| "Witness generation"); - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config.fp_chip.range.gate.assign_region_smart( - ctx, - vec![Witness(scalar.map_or(Value::unknown(), Value::known))], - vec![], - vec![], - vec![], - ); - scalars_assigned.push(vec![assignment.last().unwrap().clone()]); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - let mut bases_assigned = Vec::new(); - for base in &self.bases { - let base_assigned = ecc_chip.load_private( - ctx, - ( - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.x)))) - .unwrap_or(Value::unknown()), - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.y)))) - .unwrap_or(Value::unknown()), - ), - ); - bases_assigned.push(base_assigned); - } - - let msm = ecc_chip.variable_base_msm::( - ctx, - &bases_assigned, - &scalars_assigned, - 254, - config.window_bits, - ); - - ecc_chip.field_chip.finalize(ctx); - end_timer!(witness_time); - - if self.scalars[0].is_some() { - let mut elts = Vec::new(); - for (base, scalar) in self.bases.iter().zip(&self.scalars) { - elts.push(base.unwrap() * scalar.unwrap()); - } - let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); - - let msm_x = value_to_option(msm.x.value).unwrap(); - let msm_y = value_to_option(msm.y.value).unwrap(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x).into()); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y).into()); - } - - #[cfg(feature = "display")] - if self.bases[0].is_some() { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } -#[cfg(test)] #[test] fn test_msm() { - use std::env::set_var; - - use crate::halo2_proofs::arithmetic::Field; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/msm_circuit.config"); - set_var("MSM_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/bn254/configs/msm_circuit.config file should exist"); - let params: MSMCircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; - - let mut rng = rand::thread_rng(); - - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..params.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = - MSMCircuit:: { bases, scalars, batch_size: params.batch_size, _marker: PhantomData }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + let path = "configs/bn254/msm_circuit.config"; + let params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } -#[cfg(test)] #[test] fn bench_msm() -> Result<(), Box> { - use std::{env::set_var, fs, io::BufRead}; - - use rand_core::OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_msm.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); - - folder.push("results/msm_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let config_path = "configs/bn254/bench_msm.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); + fs::create_dir_all("data").unwrap(); + + let results_path = "results/bn254/msm_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,window_bits,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } - - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); - let mut rng = OsRng; + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); + let rng = OsRng; - { - folder.pop(); - folder.push("configs/msm_circuit.tmp.config"); - set_var("MSM_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params_time = start_timer!(|| "Params construction"); - let params = { - params_folder.push(format!("kzg_bn254_{}.srs", bench_params.degree)); - let fd = std::fs::File::open(params_folder.as_path()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - println!("Creating new params file..."); - let mut f = std::fs::File::create(params_folder.as_path())?; - let params = ParamsKZG::::setup(bench_params.degree, &mut rng); - params.write(&mut f).unwrap(); - params - }; - params_folder.pop(); - params - }; - end_timer!(params_time); + let params = gen_srs(k); + println!("{bench_params:?}"); - let circuit = MSMCircuit:: { - bases: vec![None; bench_params.batch_size], - scalars: vec![None; bench_params.batch_size], - batch_size: bench_params.batch_size, - _marker: PhantomData, - }; + let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - /* - let vk_size = { - folder.push(format!( - "msm_circuit_{}_{}_{}_{}_{}_{}_{}_{}_{}.vkey", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - bench_params.window_bits, - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - vk.write(&mut fd).unwrap(); - fd.metadata().unwrap().len() - }; - */ - let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _idx in 0..bench_params.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - println!("{bench_params:?}"); - let proof_circuit = MSMCircuit:: { - bases, - scalars, - batch_size: bench_params.batch_size, - _marker: PhantomData, - }; - + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = + random_msm_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -366,14 +164,14 @@ fn bench_msm() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - MSMCircuit, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; + _, + >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -383,29 +181,28 @@ fn bench_msm() -> Result<(), Box> { bench_params.num_limbs, bench_params.batch_size, bench_params.window_bits - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("MSM_CONFIG").unwrap())?; - writeln!( fs_results, "{},{},{},{},{},{},{},{},{},{:?},{},{:?}", diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs new file mode 100644 index 00000000..600a4931 --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs new file mode 100644 index 00000000..6cf96c7f --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases; + //.iter() + //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + //.collect::>(); + + let msm = ecc_chip.fixed_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases_assigned + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_fb_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index f71f6cdd..37f82684 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -1,14 +1,26 @@ use std::{ - env::{set_var, var}, fs::{self, File}, + io::{BufRead, BufReader}, }; use super::*; -use crate::halo2_proofs::halo2curves::bn256::G2Affine; -use halo2_base::SKIP_FIRST_PASS; +use crate::fields::FieldChip; +use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}, + utils::fs::gen_srs, + Context, +}; use rand_core::OsRng; -#[derive(Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct PairingCircuitParams { strategy: FpStrategy, degree: u32, @@ -20,257 +32,114 @@ struct PairingCircuitParams { num_limbs: usize, } -#[derive(Default)] -struct PairingCircuit { - P: Option, - Q: Option, - _marker: PhantomData, +fn pairing_test( + ctx: &mut Context, + params: PairingCircuitParams, + P: G1Affine, + Q: G2Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let chip = PairingChip::new(&fp_chip); + + let P_assigned = chip.load_private_g1_unchecked(ctx, P); + let Q_assigned = chip.load_private_g2_unchecked(ctx, Q); + + // test optimal ate pairing + let f = chip.pairing(ctx, &Q_assigned, &P_assigned); + + let actual_f = pairing(&P, &Q); + let fp12_chip = Fp12Chip::new(&fp_chip); + // cannot directly compare f and actual_f because `Gt` has private field `Fq12` + assert_eq!( + format!("Gt({:?})", fp12_chip.get_assigned_value(&f.into())), + format!("{actual_f:?}") + ); } -impl Circuit for PairingCircuit { - type Config = FpChip; - type FloorPlanner = SimpleFloorPlanner; // V1; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("PAIRING_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/pairing_circuit.config".to_string()); - let params: PairingCircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - PairingChip::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.range.load_lookup_table(&mut layouter)?; - let chip = PairingChip::::construct(&config); - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "pairing", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.new_context(region); - let ctx = &mut aux; - - let P_assigned = - chip.load_private_g1(ctx, self.P.map(Value::known).unwrap_or(Value::unknown())); - let Q_assigned = - chip.load_private_g2(ctx, self.Q.map(Value::known).unwrap_or(Value::unknown())); - - /* - // test miller loop without final exp - { - let f = chip.miller_loop(ctx, &Q_assigned, &P_assigned)?; - for fc in &f.coeffs { - assert_eq!(fc.value, fc.truncation.to_bigint()); - } - if self.P != None { - let actual_f = multi_miller_loop(&[( - &self.P.unwrap(), - &G2Prepared::from_affine(self.Q.unwrap()), - )]); - let f_val: Vec = - f.coeffs.iter().map(|x| x.value.clone().unwrap().to_str_radix(16)).collect(); - println!("single miller loop:"); - println!("actual f: {:#?}", actual_f); - println!("circuit f: {:#?}", f_val); - } - } - */ - - // test optimal ate pairing - { - let f = chip.pairing(ctx, &Q_assigned, &P_assigned); - #[cfg(feature = "display")] - for fc in &f.coeffs { - assert_eq!( - value_to_option(fc.value.clone()), - value_to_option(fc.truncation.to_bigint(chip.fp_chip.limb_bits)) - ); - } - #[cfg(feature = "display")] - if self.P.is_some() { - let actual_f = pairing(&self.P.unwrap(), &self.Q.unwrap()); - let f_val: Vec = f - .coeffs - .iter() - .map(|x| value_to_option(x.value.clone()).unwrap().to_str_radix(16)) - //.map(|x| x.to_bigint().clone().unwrap().to_str_radix(16)) - .collect(); - println!("optimal ate pairing:"); - println!("actual f: {actual_f:#?}"); - println!("circuit f: {f_val:#?}"); - } - } - - // IMPORTANT: this copies cells to the lookup advice column to perform range check lookups - // This is not optional. - config.finalize(ctx); - - #[cfg(feature = "display")] - if self.P.is_some() { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } +fn random_pairing_circuit( + params: PairingCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let P = G1Affine::random(OsRng); + let Q = G2Affine::random(OsRng); + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + pairing_test::(builder.main(0), params, P, Q); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } #[test] fn test_pairing() { - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/pairing_circuit.config"); - set_var("PAIRING_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/bn254/configs/pairing_circuit.config file should exist"); - let params: PairingCircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; - - let mut rng = OsRng; - - let P = Some(G1Affine::random(&mut rng)); - let Q = Some(G2Affine::random(&mut rng)); - - let circuit = PairingCircuit:: { P, Q, _marker: PhantomData }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + let path = "configs/bn254/pairing_circuit.config"; + let params: PairingCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } #[test] fn bench_pairing() -> Result<(), Box> { - use std::io::BufRead; - - use crate::halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}; - - let mut rng = OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_pairing.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); - - folder.push("results/pairing_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); - writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size(bytes),verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } - - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let rng = OsRng; + let config_path = "configs/bn254/bench_pairing.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); + fs::create_dir_all("data").unwrap(); + + let results_path = "results/bn254/pairing_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); + writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: PairingCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); - - { - folder.pop(); - folder.push("configs/pairing_circuit.tmp.config"); - set_var("PAIRING_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params_time = start_timer!(|| "Params construction"); - let params = { - params_folder.push(format!("kzg_bn254_{}.srs", bench_params.degree)); - let fd = std::fs::File::open(params_folder.as_path()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - println!("Creating new params file..."); - let mut f = std::fs::File::create(params_folder.as_path())?; - let params = ParamsKZG::::setup(bench_params.degree, &mut rng); - params.write(&mut f).unwrap(); - params - }; - params_folder.pop(); - params - }; + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); - let circuit = PairingCircuit::::default(); - end_timer!(params_time); + let params = gen_srs(k); + let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - /* - let vk_size = { - folder.push(format!( - "pairing_circuit_{}_{}_{}_{}_{}_{}_{}.vkey", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - vk.write(&mut fd).unwrap(); - fd.metadata().unwrap().len() - }; - */ - let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let mut rng = OsRng; - let P = Some(G1Affine::random(&mut rng)); - let Q = Some(G2Affine::random(&mut rng)); - let proof_circuit = PairingCircuit:: { P, Q, _marker: PhantomData }; - + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = + random_pairing_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -278,14 +147,14 @@ fn bench_pairing() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - PairingCircuit, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; + _, + >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -293,27 +162,27 @@ fn bench_pairing() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierGWC<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("PAIRING_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index 005f5c39..ca0b111b 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -1,107 +1,104 @@ -use crate::bigint::{big_less_than, CRTInteger}; -use crate::fields::{fp::FpConfig, FieldChip}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; +use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; -use super::fixed_base; -use super::{ec_add_unequal, scalar_multiply, EcPoint}; +use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; +use crate::fields::{fp::FpChip, FieldChip, PrimeField}; + +use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus // n = scalar field modulus // Only valid when p is very close to n in size (e.g. for Secp256k1) -pub fn ecdsa_verify_no_pubkey_check<'v, F: PrimeField, CF: PrimeField, SF: PrimeField, GA>( - base_chip: &FpConfig, - ctx: &mut Context<'v, F>, - pubkey: &EcPoint as FieldChip>::FieldPoint<'v>>, - r: &CRTInteger<'v, F>, - s: &CRTInteger<'v, F>, - msghash: &CRTInteger<'v, F>, +// Assumes `r, s` are proper CRT integers +/// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) +/// `pubkey` should not be the identity point +pub fn ecdsa_verify_no_pubkey_check( + chip: &EccChip>, + ctx: &mut Context, + pubkey: EcPoint as FieldChip>::FieldPoint>, + r: ProperCrtUint, + s: ProperCrtUint, + msghash: ProperCrtUint, var_window_bits: usize, fixed_window_bits: usize, -) -> AssignedValue<'v, F> +) -> AssignedValue where GA: CurveAffineExt, { - let scalar_chip = FpConfig::::construct( - base_chip.range.clone(), - base_chip.limb_bits, - base_chip.num_limbs, - modulus::(), - ); - let n = scalar_chip.load_constant(ctx, scalar_chip.p.to_biguint().unwrap()); + // Following https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm + let base_chip = chip.field_chip; + let scalar_chip = + FpChip::::new(base_chip.range, base_chip.limb_bits, base_chip.num_limbs); + let n = scalar_chip.p.to_biguint().unwrap(); + let n = FixedOverflowInteger::from_native(&n, scalar_chip.num_limbs, scalar_chip.limb_bits); + let n = n.assign(ctx); // check r,s are in [1, n - 1] - let r_valid = scalar_chip.is_soft_nonzero(ctx, r); - let s_valid = scalar_chip.is_soft_nonzero(ctx, s); + let r_valid = scalar_chip.is_soft_nonzero(ctx, &r); + let s_valid = scalar_chip.is_soft_nonzero(ctx, &s); // compute u1 = m s^{-1} mod n and u2 = r s^{-1} mod n - let u1 = scalar_chip.divide(ctx, msghash, s); - let u2 = scalar_chip.divide(ctx, r, s); - - //let r_crt = scalar_chip.to_crt(ctx, r)?; + let u1 = scalar_chip.divide_unsafe(ctx, msghash, &s); + let u2 = scalar_chip.divide_unsafe(ctx, &r, s); // compute u1 * G and u2 * pubkey - let u1_mul = fixed_base::scalar_multiply::( + let u1_mul = fixed_base::scalar_multiply( base_chip, ctx, &GA::generator(), - &u1.truncation.limbs, + u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, ); - let u2_mul = scalar_multiply::( + let u2_mul = scalar_multiply::<_, _, GA>( base_chip, ctx, pubkey, - &u2.truncation.limbs, + u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, ); - // check u1 * G and u2 * pubkey are not negatives and not equal - // TODO: Technically they could be equal for a valid signature, but this happens with vanishing probability - // for an ECDSA signature constructed in a standard way + // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey + // check (u1 * G).x != (u2 * pubkey).x or (u1 * G).y == (u2 * pubkey).y // coordinates of u1_mul and u2_mul are in proper bigint form, and lie in but are not constrained to [0, n) // we therefore need hard inequality here - let u1_u2_x_eq = base_chip.is_equal(ctx, &u1_mul.x, &u2_mul.x); - let u1_u2_not_neg = base_chip.range.gate().not(ctx, Existing(&u1_u2_x_eq)); + let x_eq = base_chip.is_equal(ctx, &u1_mul.x, &u2_mul.x); + let x_neq = base_chip.gate().not(ctx, x_eq); + let y_eq = base_chip.is_equal(ctx, &u1_mul.y, &u2_mul.y); + let u1g_u2pk_not_neg = base_chip.gate().or(ctx, x_neq, y_eq); // compute (x1, y1) = u1 * G + u2 * pubkey and check (r mod n) == x1 as integers + // because it is possible for u1 * G == u2 * pubkey, we must use `EccChip::sum` + let sum = chip.sum::(ctx, [u1_mul, u2_mul]); // WARNING: For optimization reasons, does not reduce x1 mod n, which is // invalid unless p is very close to n in size. - base_chip.enforce_less_than_p(ctx, u1_mul.x()); - base_chip.enforce_less_than_p(ctx, u2_mul.x()); - let sum = ec_add_unequal(base_chip, ctx, &u1_mul, &u2_mul, false); - let equal_check = base_chip.is_equal(ctx, &sum.x, r); + // enforce x1 < n + let x1 = scalar_chip.enforce_less_than(ctx, sum.x); + let equal_check = big_is_equal::assign(base_chip.gate(), ctx, x1.0, r); - // TODO: maybe the big_less_than is optional? - let u1_small = big_less_than::assign::( + let u1_small = big_less_than::assign( base_chip.range(), ctx, - &u1.truncation, - &n.truncation, + u1, + n.clone(), base_chip.limb_bits, base_chip.limb_bases[1], ); - let u2_small = big_less_than::assign::( + let u2_small = big_less_than::assign( base_chip.range(), ctx, - &u2.truncation, - &n.truncation, + u2, + n, base_chip.limb_bits, base_chip.limb_bases[1], ); - // check (r in [1, n - 1]) and (s in [1, n - 1]) and (u1_mul != - u2_mul) and (r == x1 mod n) - let res1 = base_chip.range.gate().and(ctx, Existing(&r_valid), Existing(&s_valid)); - let res2 = base_chip.range.gate().and(ctx, Existing(&res1), Existing(&u1_small)); - let res3 = base_chip.range.gate().and(ctx, Existing(&res2), Existing(&u2_small)); - let res4 = base_chip.range.gate().and(ctx, Existing(&res3), Existing(&u1_u2_not_neg)); - let res5 = base_chip.range.gate().and(ctx, Existing(&res4), Existing(&equal_check)); + // check (r in [1, n - 1]) and (s in [1, n - 1]) and (u1 * G != - u2 * pubkey) and (r == x1 mod n) + let res1 = base_chip.gate().and(ctx, r_valid, s_valid); + let res2 = base_chip.gate().and(ctx, res1, u1_small); + let res3 = base_chip.gate().and(ctx, res2, u2_small); + let res4 = base_chip.gate().and(ctx, res3, u1g_u2pk_not_neg); + let res5 = base_chip.gate().and(ctx, res4, equal_check); res5 } diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 64168c96..5dfba754 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,109 +1,39 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::halo2_proofs::arithmetic::CurveAffine; -use crate::{ - bigint::{CRTInteger, FixedCRTInteger}, - fields::{PrimeFieldChip, Selectable}, -}; +use crate::ecc::{ec_sub_strict, load_random_point}; +use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{fe_to_biguint, CurveAffineExt, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; +use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; -use num_bigint::BigUint; -use std::{cmp::min, marker::PhantomData}; - -// this only works for curves GA with base field of prime order -#[derive(Clone, Debug)] -pub struct FixedEcPoint { - pub x: FixedCRTInteger, // limbs in `F` and value in `BigUint` - pub y: FixedCRTInteger, - _marker: PhantomData, -} - -impl FixedEcPoint -where - C::Base: PrimeField, -{ - pub fn construct(x: FixedCRTInteger, y: FixedCRTInteger) -> Self { - Self { x, y, _marker: PhantomData } - } - - pub fn from_curve(point: C, num_limbs: usize, limb_bits: usize) -> Self { - let (x, y) = point.into_coordinates(); - let x = FixedCRTInteger::from_native(fe_to_biguint(&x), num_limbs, limb_bits); - let y = FixedCRTInteger::from_native(fe_to_biguint(&y), num_limbs, limb_bits); - Self::construct(x, y) - } - - pub fn assign<'v, FC>( - self, - chip: &FC, - ctx: &mut Context<'_, F>, - native_modulus: &BigUint, - ) -> EcPoint> - where - FC: PrimeFieldChip = CRTInteger<'v, F>>, - { - let assigned_x = self.x.assign(chip.range().gate(), ctx, chip.limb_bits(), native_modulus); - let assigned_y = self.y.assign(chip.range().gate(), ctx, chip.limb_bits(), native_modulus); - EcPoint::construct(assigned_x, assigned_y) - } - - pub fn assign_without_caching<'v, FC>( - self, - chip: &FC, - ctx: &mut Context<'_, F>, - native_modulus: &BigUint, - ) -> EcPoint> - where - FC: PrimeFieldChip = CRTInteger<'v, F>>, - { - let assigned_x = self.x.assign_without_caching( - chip.range().gate(), - ctx, - chip.limb_bits(), - native_modulus, - ); - let assigned_y = self.y.assign_without_caching( - chip.range().gate(), - ctx, - chip.limb_bits(), - native_modulus, - ); - EcPoint::construct(assigned_x, assigned_y) - } -} - -// computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) -// - `scalar` is represented as a reference array of `AssignedCell`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::.bits()` - -pub fn scalar_multiply<'v, F, FC, C>( +use rayon::prelude::*; +use std::cmp::min; + +/// Computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) +/// - `scalar` is represented as a non-empty reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// +/// # Assumptions +/// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) +/// - `scalar > 0` +/// - `max_bits <= modulus::.bits()` +pub fn scalar_multiply( chip: &FC, - ctx: &mut Context<'v, F>, + ctx: &mut Context, point: &C, - scalar: &[AssignedValue<'v, F>], + scalar: Vec>, max_bits: usize, window_bits: usize, -) -> EcPoint> +) -> EcPoint where F: PrimeField, C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { if point.is_identity().into() { - let point = FixedEcPoint::from_curve(*point, chip.num_limbs(), chip.limb_bits()); - return FixedEcPoint::assign(point, chip, ctx, chip.native_modulus()); + let zero = chip.load_constant(ctx, C::Base::zero()); + return EcPoint::new(zero.clone(), zero); } assert!(!scalar.is_empty()); assert!((max_bits as u32) <= F::NUM_BITS); @@ -141,66 +71,64 @@ where let cached_points = cached_points_affine .into_iter() .map(|point| { - let point = FixedEcPoint::from_curve(point, chip.num_limbs(), chip.limb_bits()); - FixedEcPoint::assign(point, chip, ctx, chip.native_modulus()) + let (x, y) = point.into_coordinates(); + let [x, y] = [x, y].map(|x| chip.load_constant(ctx, x)); + EcPoint::new(x, y) }) .collect_vec(); let bits = scalar - .iter() + .into_iter() .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits)) .collect::>(); let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = chip.gate().load_zero(ctx); + let any_point = load_random_point::(chip, ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { - let bit_sum = chip.gate().sum(ctx, bit_window.iter().map(Existing)); + let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip - let is_zero_window = chip.gate().is_zero(ctx, &bit_sum); - let add_point = ec_select_from_bits::(chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(chip, ctx, &curr_point, &sum, &is_zero_window); - Some(ec_select(chip, ctx, &zero_sum, &add_point, &is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = chip.gate().not(ctx, Existing(&is_zero_window)); - chip.gate().mul_add( - ctx, - Existing(&is_started), - Existing(&is_zero_window), - Existing(¬_zero_window), - ) + let is_zero_window = chip.gate().is_zero(ctx, bit_sum); + curr_point = { + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + ec_select(chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() + ec_sub_strict(chip, ctx, curr_point, any_point) } // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation // we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO) -pub fn msm<'v, F, FC, C>( + +/// # Assumptions +/// * `points.len() = scalars.len()` +/// * `scalars[i].len() = scalars[j].len()` for all `i,j` +/// * `points` are all on the curve +/// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand +/// * The integer value of `scalars[i]` is less than the order of `points[i]` +/// * Output may be point at infinity, in which case (0, 0) is returned +pub fn msm_par( chip: &EccChip, - ctx: &mut Context<'v, F>, + builder: &mut GateThreadBuilder, points: &[C], - scalars: &[Vec>], + scalars: Vec>>, max_scalar_bits_per_cell: usize, window_bits: usize, -) -> EcPoint> + phase: usize, +) -> EcPoint where F: PrimeField, C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { + if points.is_empty() { + return chip.assign_constant_point(builder.main(phase), C::identity()); + } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); + assert_eq!(points.len(), scalars.len()); + assert!(!points.is_empty(), "fixed_base::msm_par requires at least one point"); let scalar_len = scalars[0].len(); let total_bits = max_scalar_bits_per_cell * scalar_len; let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -208,10 +136,11 @@ where // `cached_points` is a flattened 2d vector // first we compute all cached points in Jacobian coordinates since it's fastest let cached_points_jacobian = points - .iter() - .flat_map(|point| { + .par_iter() + .flat_map(|point| -> Vec<_> { let base_pt = point.to_curve(); // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} + // EXCEPT cached_points[idx][0] = points[idx] let mut increment = base_pt; (0..num_windows) .flat_map(|i| { @@ -224,80 +153,67 @@ where prev }, )) - .collect_vec(); + .collect::>(); increment = curr; cache_vec }) - .collect_vec() + .collect() }) - .collect_vec(); + .collect::>(); // for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive // initialize to all 0s let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()]; C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); - let cached_points = cached_points_affine - .into_iter() - .map(|point| { - let point = - FixedEcPoint::from_curve(point, field_chip.num_limbs(), field_chip.limb_bits()); - point.assign_without_caching(field_chip, ctx, field_chip.native_modulus()) - }) - .collect_vec(); + let ctx = builder.main(phase); + let any_point = chip.load_random_point::(ctx); + + let scalar_mults = parallelize_in( + phase, + builder, + cached_points_affine + .chunks(cached_points_affine.len() / points.len()) + .zip_eq(scalars) + .collect(), + |ctx, (cached_points, scalar)| { + let cached_points = cached_points + .iter() + .map(|point| chip.assign_constant_point(ctx, *point)) + .collect_vec(); + let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); - let bits = scalars - .iter() - .flat_map(|scalar| { assert_eq!(scalar.len(), scalar_len); - scalar - .iter() + let bits = scalar + .into_iter() .flat_map(|scalar_chunk| { field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell) }) - .collect_vec() - }) - .collect_vec(); - - let sm = cached_points - .chunks(cached_points.len() / points.len()) - .zip(bits.chunks(total_bits)) - .map(|(cached_points, bits)| { - let cached_point_window_rev = - cached_points.chunks(1usize << window_bits).rev(); + .collect::>(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = field_chip.gate().load_zero(ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let is_zero_window = { - let sum = field_chip.gate().sum(ctx, bit_window.iter().map(Existing)); - field_chip.gate().is_zero(ctx, &sum) + let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); + field_chip.gate().is_zero(ctx, sum) }; - let add_point = - ec_select_from_bits::(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, &curr_point, &sum, &is_zero_window); - Some(ec_select(field_chip, ctx, &zero_sum, &add_point, &is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = - field_chip.range().gate().not(ctx, Existing(&is_zero_window)); - field_chip.range().gate().mul_add( - ctx, - Existing(&is_started), - Existing(&is_zero_window), - Existing(¬_zero_window), - ) + curr_point = { + let add_point = + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true); + ec_select(field_chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() - }) - .collect_vec(); - chip.sum::(ctx, sm.iter()) + curr_point + }, + ); + let ctx = builder.main(phase); + // sum `scalar_mults` but take into account possiblity of identity points + let any_point2 = chip.load_random_point::(ctx); + let mut acc = any_point2.clone(); + for point in scalar_mults { + let new_acc = chip.add_unequal(ctx, &acc, point, true); + acc = chip.sub_unequal(ctx, new_acc, &any_point, true); + } + ec_sub_strict(field_chip, ctx, acc, any_point2) } diff --git a/halo2-ecc/src/ecc/fixed_base_pippenger.rs b/halo2-ecc/src/ecc/fixed_base_pippenger.rs index 1e36bfd1..05d7cf3e 100644 --- a/halo2-ecc/src/ecc/fixed_base_pippenger.rs +++ b/halo2-ecc/src/ecc/fixed_base_pippenger.rs @@ -20,14 +20,14 @@ use rand_chacha::ChaCha20Rng; // Output: // * new_points: length `points.len() * radix` // * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix` -pub fn decompose<'v, F, C>( +pub fn decompose( gate: &impl GateInstructions, - ctx: &mut Context<'v, F>, + ctx: &mut Context, points: &[C], - scalars: &Vec>>, + scalars: &Vec>>, max_scalar_bits_per_cell: usize, radix: usize, -) -> (Vec, Vec>>) +) -> (Vec, Vec>>) where F: PrimeField, C: CurveAffine, @@ -66,15 +66,15 @@ where // Given points[i] and bool_scalars[j][i], // compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i] // output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point -pub fn multi_product<'v, F: PrimeField, FC, C>( +pub fn multi_product( chip: &FC, - ctx: &mut Context<'v, F>, + ctx: &mut Context, points: Vec, - bool_scalars: Vec>>, + bool_scalars: Vec>>, clumping_factor: usize, -) -> (Vec>>, EcPoint>) +) -> (Vec>, EcPoint) where - FC: PrimeFieldChip = CRTInteger<'v, F>>, + FC: PrimeFieldChip>, FC::FieldType: PrimeField, C: CurveAffine, { @@ -187,17 +187,17 @@ where (acc, rand_point) } -pub fn multi_exp<'v, F: PrimeField, FC, C>( +pub fn multi_exp( chip: &FC, - ctx: &mut Context<'v, F>, + ctx: &mut Context, points: &[C], - scalars: &Vec>>, + scalars: &Vec>>, max_scalar_bits_per_cell: usize, radix: usize, clump_factor: usize, -) -> EcPoint> +) -> EcPoint where - FC: PrimeFieldChip = CRTInteger<'v, F>>, + FC: PrimeFieldChip>, FC::FieldType: PrimeField, C: CurveAffine, { diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 2b9cedf6..4da01281 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -1,13 +1,13 @@ #![allow(non_snake_case)] -use crate::bigint::CRTInteger; -use crate::fields::{fp::FpConfig, FieldChip, PrimeFieldChip, Selectable}; -use crate::halo2_proofs::{arithmetic::CurveAffine, circuit::Value}; +use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; +use crate::halo2_proofs::arithmetic::CurveAffine; use group::{Curve, Group}; +use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::utils::modulus; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt, PrimeField}, + utils::CurveAffineExt, AssignedValue, Context, - QuantumCell::Existing, }; use itertools::Itertools; use rand::SeedableRng; @@ -21,7 +21,7 @@ pub mod pippenger; // EcPoint and EccChip take in a generic `FieldChip` to implement generic elliptic curve operations on arbitrary field extensions (provided chip exists) for short Weierstrass curves (currently further assuming a4 = 0 for optimization purposes) #[derive(Debug)] -pub struct EcPoint { +pub struct EcPoint { pub x: FieldPoint, pub y: FieldPoint, _marker: PhantomData, @@ -33,8 +33,17 @@ impl Clone for EcPoint { } } -impl EcPoint { - pub fn construct(x: FieldPoint, y: FieldPoint) -> Self { +// Improve readability by allowing `&EcPoint` to be converted to `EcPoint` via cloning +impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> + for EcPoint +{ + fn from(value: &'a EcPoint) -> Self { + value.clone() + } +} + +impl EcPoint { + pub fn new(x: FieldPoint, y: FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } @@ -47,6 +56,83 @@ impl EcPoint { } } +/// An elliptic curve point where it is easy to compare the x-coordinate of two points +#[derive(Clone, Debug)] +pub struct StrictEcPoint> { + pub x: FC::ReducedFieldPoint, + pub y: FC::FieldPoint, + _marker: PhantomData, +} + +impl> StrictEcPoint { + pub fn new(x: FC::ReducedFieldPoint, y: FC::FieldPoint) -> Self { + Self { x, y, _marker: PhantomData } + } +} + +impl> From> for EcPoint { + fn from(value: StrictEcPoint) -> Self { + Self::new(value.x.into(), value.y) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> + for EcPoint +{ + fn from(value: &'a StrictEcPoint) -> Self { + value.clone().into() + } +} + +/// An elliptic curve point where the x-coordinate has already been constrained to be reduced or not. +/// In the reduced case one can more optimally compare equality of x-coordinates. +#[derive(Clone, Debug)] +pub enum ComparableEcPoint> { + Strict(StrictEcPoint), + NonStrict(EcPoint), +} + +impl> From> for ComparableEcPoint { + fn from(pt: StrictEcPoint) -> Self { + Self::Strict(pt) + } +} + +impl> From> + for ComparableEcPoint +{ + fn from(pt: EcPoint) -> Self { + Self::NonStrict(pt) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> + for ComparableEcPoint +{ + fn from(pt: &'a StrictEcPoint) -> Self { + Self::Strict(pt.clone()) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> + for ComparableEcPoint +{ + fn from(pt: &'a EcPoint) -> Self { + Self::NonStrict(pt.clone()) + } +} + +impl> From> + for EcPoint +{ + fn from(pt: ComparableEcPoint) -> Self { + match pt { + ComparableEcPoint::Strict(pt) => Self::new(pt.x.into(), pt.y), + ComparableEcPoint::NonStrict(pt) => pt, + } + } +} + // Implements: // Given P = (x_1, y_1) and Q = (x_2, y_2), ecc points over the field F_p // assume x_1 != x_2 @@ -57,37 +143,61 @@ impl EcPoint { // x_3 = lambda^2 - x_1 - x_2 (mod p) // y_3 = lambda (x_1 - x_3) - y_1 mod p // -/// For optimization reasons, we assume that if you are using this with `is_strict = true`, then you have already called `chip.enforce_less_than_p` on both `P.x` and `P.y` -pub fn ec_add_unequal<'v, F: PrimeField, FC: FieldChip>( +/// If `is_strict = true`, then this function constrains that `P.x != Q.x`. +/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such +/// as a mathematical theorem). +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) +pub fn ec_add_unequal>( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, -) -> EcPoint> { - if is_strict { - // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, F::zero()); - } +) -> EcPoint { + let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.sub_no_carry(ctx, &Q.y, &P.y); - let lambda = chip.divide(ctx, &dy, &dx); + let dy = chip.sub_no_carry(ctx, Q.y, &P.y); + let lambda = chip.divide_unsafe(ctx, dy, dx); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); - let lambda_sq_minus_px = chip.sub_no_carry(ctx, &lambda_sq, &P.x); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq_minus_px, &Q.x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let lambda_sq_minus_px = chip.sub_no_carry(ctx, lambda_sq, &P.x); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq_minus_px, Q.x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x_1 - x_3) - y_1 mod p - let dx_13 = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx_13 = chip.mul_no_carry(ctx, &lambda, &dx_13); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx_13, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx_13 = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx_13 = chip.mul_no_carry(ctx, lambda, dx_13); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx_13, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); - EcPoint::construct(x_3, y_3) + EcPoint::new(x_3, y_3) +} + +/// If `do_check = true`, then this function constrains that `P.x != Q.x`. +/// Otherwise does nothing. +fn check_points_are_unequal>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, + do_check: bool, +) -> (EcPoint /*P */, EcPoint /*Q */) { + let P = P.into(); + let Q = Q.into(); + if do_check { + // constrains that P.x != Q.x + let [x1, x2] = [&P, &Q].map(|pt| match pt { + ComparableEcPoint::Strict(pt) => pt.x.clone(), + ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), + }); + let x_is_equal = chip.is_equal_unenforced(ctx, x1, x2); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + } + (EcPoint::from(P), EcPoint::from(Q)) } // Implements: @@ -99,43 +209,83 @@ pub fn ec_add_unequal<'v, F: PrimeField, FC: FieldChip>( // y_3 = lambda (x_1 - x_3) - y_1 mod p // Assumes that P !=Q and Q != (P - Q) // -/// For optimization reasons, we assume that if you are using this with `is_strict = true`, then you have already called `chip.enforce_less_than_p` on both `P.x` and `P.y` -pub fn ec_sub_unequal<'v, F: PrimeField, FC: FieldChip>( +/// If `is_strict = true`, then this function constrains that `P.x != Q.x`. +/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such +/// as a mathematical theorem). +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) +pub fn ec_sub_unequal>( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, -) -> EcPoint> { - if is_strict { - // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, F::zero()); - } +) -> EcPoint { + let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.add_no_carry(ctx, &Q.y, &P.y); + let dy = chip.add_no_carry(ctx, Q.y, &P.y); - let lambda = chip.neg_divide(ctx, &dy, &dx); + let lambda = chip.neg_divide_unsafe(ctx, &dy, &dx); // (x_2 - x_1) * lambda + y_2 + y_1 = 0 (mod p) - let lambda_dx = chip.mul_no_carry(ctx, &lambda, &dx); - let lambda_dx_plus_dy = chip.add_no_carry(ctx, &lambda_dx, &dy); - chip.check_carry_mod_to_zero(ctx, &lambda_dx_plus_dy); + let lambda_dx = chip.mul_no_carry(ctx, &lambda, dx); + let lambda_dx_plus_dy = chip.add_no_carry(ctx, lambda_dx, dy); + chip.check_carry_mod_to_zero(ctx, lambda_dx_plus_dy); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); - let lambda_sq_minus_px = chip.sub_no_carry(ctx, &lambda_sq, &P.x); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq_minus_px, &Q.x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let lambda_sq_minus_px = chip.sub_no_carry(ctx, lambda_sq, &P.x); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq_minus_px, Q.x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x_1 - x_3) - y_1 mod p - let dx_13 = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx_13 = chip.mul_no_carry(ctx, &lambda, &dx_13); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx_13, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx_13 = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx_13 = chip.mul_no_carry(ctx, lambda, dx_13); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx_13, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); + + EcPoint::new(x_3, y_3) +} - EcPoint::construct(x_3, y_3) +/// Constrains `P != -Q` but allows `P == Q`, in which case output is (0,0). +/// For Weierstrass curves only. +/// +/// Assumptions +/// # Neither P or Q is the point at infinity +pub fn ec_sub_strict>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, +) -> EcPoint +where + FC: Selectable, +{ + let mut P = P.into(); + let Q = Q.into(); + // Compute curr_point - start_point, allowing for output to be identity point + let x_is_eq = chip.is_equal(ctx, P.x(), Q.x()); + let y_is_eq = chip.is_equal(ctx, P.y(), Q.y()); + let is_identity = chip.gate().and(ctx, x_is_eq, y_is_eq); + // we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q + ctx.constrain_equal(&x_is_eq, &is_identity); + + // P.x = Q.x and P.y = Q.y + // in ec_sub_unequal it will try to do -(P.y + Q.y) / (P.x - Q.x) = -2P.y / 0 + // this will cause divide_unsafe to panic when P.y != 0 + // to avoid this, we load a random pair of points and replace P with it *only if* `is_identity == true` + // we don't even check (rand_x, rand_y) is on the curve, since we don't care about the output + let mut rng = ChaCha20Rng::from_entropy(); + let [rand_x, rand_y] = [(); 2].map(|_| FC::FieldType::random(&mut rng)); + let [rand_x, rand_y] = [rand_x, rand_y].map(|x| chip.load_private(ctx, x)); + let rand_pt = EcPoint::new(rand_x, rand_y); + P = ec_select(chip, ctx, rand_pt, P, is_identity); + + let out = ec_sub_unequal(chip, ctx, P, Q, false); + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) } // Implements: @@ -150,104 +300,212 @@ pub fn ec_sub_unequal<'v, F: PrimeField, FC: FieldChip>( // we precompute lambda and constrain (2y) * lambda = 3 x^2 (mod p) // then we compute x_3 = lambda^2 - 2 x (mod p) // y_3 = lambda (x - x_3) - y (mod p) -pub fn ec_double<'v, F: PrimeField, FC: FieldChip>( +/// # Assumptions +/// * `P.y != 0` +/// * `P` is not the point at infinity (undefined behavior otherwise) +pub fn ec_double>( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, -) -> EcPoint> { + ctx: &mut Context, + P: impl Into>, +) -> EcPoint { + let P = P.into(); // removed optimization that computes `2 * lambda` while assigning witness to `lambda` simultaneously, in favor of readability. The difference is just copying `lambda` once let two_y = chip.scalar_mul_no_carry(ctx, &P.y, 2); let three_x = chip.scalar_mul_no_carry(ctx, &P.x, 3); - let three_x_sq = chip.mul_no_carry(ctx, &three_x, &P.x); - let lambda = chip.divide(ctx, &three_x_sq, &two_y); + let three_x_sq = chip.mul_no_carry(ctx, three_x, &P.x); + let lambda = chip.divide_unsafe(ctx, three_x_sq, two_y); // x_3 = lambda^2 - 2 x % p let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); let two_x = chip.scalar_mul_no_carry(ctx, &P.x, 2); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq, &two_x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq, two_x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x - x_3) - y % p - let dx = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx = chip.mul_no_carry(ctx, &lambda, &dx); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx = chip.mul_no_carry(ctx, lambda, dx); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); + + EcPoint::new(x_3, y_3) +} - EcPoint::construct(x_3, y_3) +/// Implements: +/// computing 2P + Q = P + Q + P for P = (x0, y0), Q = (x1, y1) +// using Montgomery ladder(?) to skip intermediate y computation +// from halo2wrong: https://hackmd.io/ncuKqRXzR-Cw-Au2fGzsMg?view +// lambda_0 = (y_1 - y_0) / (x_1 - x_0) +// x_2 = lambda_0^2 - x_0 - x_1 +// lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) +// x_res = lambda_1^2 - x_0 - x_2 +// y_res = lambda_1 * (x_res - x_0) - y_0 +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) +pub fn ec_double_and_add_unequal>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, + is_strict: bool, +) -> EcPoint { + let P = P.into(); + let Q = Q.into(); + let mut x_0 = None; + if is_strict { + // constrains that P.x != Q.x + let [x0, x1] = [&P, &Q].map(|pt| match pt { + ComparableEcPoint::Strict(pt) => pt.x.clone(), + ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), + }); + let x_is_equal = chip.is_equal_unenforced(ctx, x0.clone(), x1); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + x_0 = Some(x0); + } + let P = EcPoint::from(P); + let Q = EcPoint::from(Q); + + let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); + let dy = chip.sub_no_carry(ctx, Q.y, &P.y); + let lambda_0 = chip.divide_unsafe(ctx, dy, dx); + + // x_2 = lambda_0^2 - x_0 - x_1 (mod p) + let lambda_0_sq = chip.mul_no_carry(ctx, &lambda_0, &lambda_0); + let lambda_0_sq_minus_x_0 = chip.sub_no_carry(ctx, lambda_0_sq, &P.x); + let x_2_no_carry = chip.sub_no_carry(ctx, lambda_0_sq_minus_x_0, Q.x); + let x_2 = chip.carry_mod(ctx, x_2_no_carry); + + if is_strict { + let x_2 = chip.enforce_less_than(ctx, x_2.clone()); + // TODO: when can we remove this check? + // constrains that x_2 != x_0 + let x_is_equal = chip.is_equal_unenforced(ctx, x_0.unwrap(), x_2); + chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + } + // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) + let two_y_0 = chip.scalar_mul_no_carry(ctx, &P.y, 2); + let x_2_minus_x_0 = chip.sub_no_carry(ctx, &x_2, &P.x); + let lambda_1_minus_lambda_0 = chip.divide_unsafe(ctx, two_y_0, x_2_minus_x_0); + let lambda_1_no_carry = chip.add_no_carry(ctx, lambda_0, lambda_1_minus_lambda_0); + + // x_res = lambda_1^2 - x_0 - x_2 + let lambda_1_sq_nc = chip.mul_no_carry(ctx, &lambda_1_no_carry, &lambda_1_no_carry); + let lambda_1_sq_minus_x_0 = chip.sub_no_carry(ctx, lambda_1_sq_nc, &P.x); + let x_res_no_carry = chip.sub_no_carry(ctx, lambda_1_sq_minus_x_0, x_2); + let x_res = chip.carry_mod(ctx, x_res_no_carry); + + // y_res = lambda_1 * (x_res - x_0) - y_0 + let x_res_minus_x_0 = chip.sub_no_carry(ctx, &x_res, P.x); + let lambda_1_x_res_minus_x_0 = chip.mul_no_carry(ctx, lambda_1_no_carry, x_res_minus_x_0); + let y_res_no_carry = chip.sub_no_carry(ctx, lambda_1_x_res_minus_x_0, P.y); + let y_res = chip.carry_mod(ctx, y_res_no_carry); + + EcPoint::new(x_res, y_res) } -pub fn ec_select<'v, F: PrimeField, FC>( +pub fn ec_select( chip: &FC, - ctx: &mut Context<'_, F>, - P: &EcPoint>, - Q: &EcPoint>, - sel: &AssignedValue<'v, F>, -) -> EcPoint> + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, + sel: AssignedValue, +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { - let Rx = chip.select(ctx, &P.x, &Q.x, sel); - let Ry = chip.select(ctx, &P.y, &Q.y, sel); - EcPoint::construct(Rx, Ry) + let Rx = chip.select(ctx, P.x, Q.x, sel); + let Ry = chip.select(ctx, P.y, Q.y, sel); + EcPoint::new(Rx, Ry) } // takes the dot product of points with sel, where each is intepreted as // a _vector_ -pub fn ec_select_by_indicator<'v, F: PrimeField, FC>( +pub fn ec_select_by_indicator( chip: &FC, - ctx: &mut Context<'_, F>, - points: &[EcPoint>], - coeffs: &[AssignedValue<'v, F>], -) -> EcPoint> + ctx: &mut Context, + points: &[Pt], + coeffs: &[AssignedValue], +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, + Pt: Into> + Clone, { - let x_coords = points.iter().map(|P| P.x.clone()).collect::>(); - let y_coords = points.iter().map(|P| P.y.clone()).collect::>(); - let Rx = chip.select_by_indicator(ctx, &x_coords, coeffs); - let Ry = chip.select_by_indicator(ctx, &y_coords, coeffs); - EcPoint::construct(Rx, Ry) + let (x, y): (Vec<_>, Vec<_>) = points + .iter() + .map(|P| { + let P: EcPoint<_, _> = P.clone().into(); + (P.x, P.y) + }) + .unzip(); + let Rx = chip.select_by_indicator(ctx, &x, coeffs); + let Ry = chip.select_by_indicator(ctx, &y, coeffs); + EcPoint::new(Rx, Ry) } // `sel` is little-endian binary -pub fn ec_select_from_bits<'v, F: PrimeField, FC>( +pub fn ec_select_from_bits( chip: &FC, - ctx: &mut Context<'_, F>, - points: &[EcPoint>], - sel: &[AssignedValue<'v, F>], -) -> EcPoint> + ctx: &mut Context, + points: &[Pt], + sel: &[AssignedValue], +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, + Pt: Into> + Clone, { let w = sel.len(); - let num_points = points.len(); - assert_eq!(1 << w, num_points); + assert_eq!(1 << w, points.len()); let coeffs = chip.range().gate().bits_to_indicator(ctx, sel); ec_select_by_indicator(chip, ctx, points, &coeffs) } -// computes [scalar] * P on y^2 = x^3 + b -// - `scalar` is represented as a reference array of `AssignedCell`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::.bits()` -// * P has order given by the scalar field modulus -pub fn scalar_multiply<'v, F: PrimeField, FC>( +// `sel` is little-endian binary +pub fn strict_ec_select_from_bits( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - scalar: &Vec>, + ctx: &mut Context, + points: &[StrictEcPoint], + sel: &[AssignedValue], +) -> StrictEcPoint +where + FC: FieldChip + Selectable + Selectable, +{ + let w = sel.len(); + assert_eq!(1 << w, points.len()); + let coeffs = chip.range().gate().bits_to_indicator(ctx, sel); + let (x, y): (Vec<_>, Vec<_>) = points.iter().map(|pt| (pt.x.clone(), pt.y.clone())).unzip(); + let x = chip.select_by_indicator(ctx, &x, &coeffs); + let y = chip.select_by_indicator(ctx, &y, &coeffs); + StrictEcPoint::new(x, y) +} + +/// Computes `[scalar] * P` on short Weierstrass curve `y^2 = x^3 + b` +/// - `scalar` is represented as a reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// +/// # Assumptions +/// - `window_bits != 0` +/// - The order of `P` is at least `2^{window_bits}` (in particular, `P` is not the point at infinity) +/// - The curve has no points of order 2. +/// - `scalar_i < 2^{max_bits} for all i` +/// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` +pub fn scalar_multiply( + chip: &FC, + ctx: &mut Context, + P: EcPoint, + scalar: Vec>, max_bits: usize, window_bits: usize, -) -> EcPoint> +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); - + assert!(window_bits != 0); + multi_scalar_multiply::(chip, ctx, &[P], vec![scalar], max_bits, window_bits) + /* let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; let rounded_bitlen = num_windows * window_bits; @@ -258,24 +516,15 @@ where bits.append(&mut new_bits); } let mut rounded_bits = bits; - let zero_cell = chip.gate().load_zero(ctx); - for _ in 0..(rounded_bitlen - total_bits) { - rounded_bits.push(zero_cell.clone()); - } + let zero_cell = ctx.load_zero(); + rounded_bits.resize(rounded_bitlen, zero_cell); // is_started[idx] holds whether there is a 1 in bits with index at least (rounded_bitlen - idx) let mut is_started = Vec::with_capacity(rounded_bitlen); - for _ in 0..(rounded_bitlen - total_bits) { - is_started.push(zero_cell.clone()); - } - is_started.push(zero_cell.clone()); - for idx in 1..total_bits { - let or = chip.gate().or( - ctx, - Existing(&is_started[rounded_bitlen - total_bits + idx - 1]), - Existing(&rounded_bits[total_bits - idx]), - ); - is_started.push(or.clone()); + is_started.resize(rounded_bitlen - total_bits + 1, zero_cell); + for idx in 1..=total_bits { + let or = chip.gate().or(ctx, *is_started.last().unwrap(), rounded_bits[total_bits - idx]); + is_started.push(or); } // is_zero_window[idx] is 0/1 depending on whether bits [rounded_bitlen - window_bits * (idx + 1), rounded_bitlen - window_bits * idx) are all 0 @@ -284,29 +533,30 @@ where let temp_bits = rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx] .iter() - .map(|x| Existing(x)); + .copied(); let bit_sum = chip.gate().sum(ctx, temp_bits); - let is_zero = chip.gate().is_zero(ctx, &bit_sum); - is_zero_window.push(is_zero.clone()); + let is_zero = chip.gate().is_zero(ctx, bit_sum); + is_zero_window.push(is_zero); } - // cached_points[idx] stores idx * P, with cached_points[0] = P + let any_point = load_random_point::(chip, ctx); + // cached_points[idx] stores idx * P, with cached_points[0] = any_point let cache_size = 1usize << window_bits; let mut cached_points = Vec::with_capacity(cache_size); - cached_points.push(P.clone()); + cached_points.push(any_point); cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { - let double = ec_double(chip, ctx, P /*, b*/); - cached_points.push(double.clone()); + let double = ec_double(chip, ctx, &P); + cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], P, false); - cached_points.push(new_point.clone()); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); + cached_points.push(new_point); } } - // if all the starting window bits are 0, get start_point = P - let mut curr_point = ec_select_from_bits::( + // if all the starting window bits are 0, get start_point = any_point + let mut curr_point = ec_select_from_bits( chip, ctx, &cached_points, @@ -316,48 +566,46 @@ where for idx in 1..num_windows { let mut mult_point = curr_point.clone(); for _ in 0..window_bits { - mult_point = ec_double(chip, ctx, &mult_point); + mult_point = ec_double(chip, ctx, mult_point); } - let add_point = ec_select_from_bits::( + let add_point = ec_select_from_bits( chip, ctx, &cached_points, &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, false); - let is_started_point = - ec_select(chip, ctx, &mult_point, &mult_and_add, &is_zero_window[idx]); + // if is_zero_window[idx] = true, add_point = any_point. We only need any_point to avoid divide by zero in add_unequal + // if is_zero_window = true and is_started = false, then mult_point = 2^window_bits * any_point. Since window_bits != 0, we have mult_point != +- any_point + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, true); + let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = - ec_select(chip, ctx, &is_started_point, &add_point, &is_started[window_bits * idx]); + ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } - curr_point + // if at the end, return identity point (0,0) if still not started + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) + */ } -pub fn is_on_curve<'v, F, FC, C>( - chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, -) where +/// Checks that `P` is indeed a point on the elliptic curve `C`. +pub fn check_is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) +where F: PrimeField, FC: FieldChip, C: CurveAffine, { let lhs = chip.mul_no_carry(ctx, &P.y, &P.y); - let mut rhs = chip.mul(ctx, &P.x, &P.x); - rhs = chip.mul_no_carry(ctx, &rhs, &P.x); + let mut rhs = chip.mul(ctx, &P.x, &P.x).into(); + rhs = chip.mul_no_carry(ctx, rhs, &P.x); - let b = FC::fe_to_constant(C::b()); - rhs = chip.add_constant_no_carry(ctx, &rhs, b); - let diff = chip.sub_no_carry(ctx, &lhs, &rhs); - chip.check_carry_mod_to_zero(ctx, &diff) + rhs = chip.add_constant_no_carry(ctx, rhs, C::b()); + let diff = chip.sub_no_carry(ctx, lhs, rhs); + chip.check_carry_mod_to_zero(ctx, diff) } -pub fn load_random_point<'v, F, FC, C>( - chip: &FC, - ctx: &mut Context<'v, F>, -) -> EcPoint> +pub fn load_random_point(chip: &FC, ctx: &mut Context) -> EcPoint where F: PrimeField, FC: FieldChip, @@ -365,34 +613,55 @@ where { let base_point: C = C::CurveExt::random(ChaCha20Rng::from_entropy()).to_affine(); let (x, y) = base_point.into_coordinates(); - let pt_x = FC::fe_to_witness(&Value::known(x)); - let pt_y = FC::fe_to_witness(&Value::known(y)); let base = { - let x_overflow = chip.load_private(ctx, pt_x); - let y_overflow = chip.load_private(ctx, pt_y); - EcPoint::construct(x_overflow, y_overflow) + let x_overflow = chip.load_private(ctx, x); + let y_overflow = chip.load_private(ctx, y); + EcPoint::new(x_overflow, y_overflow) }; // for above reason we still need to constrain that the witness is on the curve - is_on_curve::(chip, ctx, &base); + check_is_on_curve::(chip, ctx, &base); base } +pub fn into_strict_point( + chip: &FC, + ctx: &mut Context, + pt: EcPoint, +) -> StrictEcPoint +where + F: PrimeField, + FC: FieldChip, +{ + let x = chip.enforce_less_than(ctx, pt.x); + StrictEcPoint::new(x, pt.y) +} + // need to supply an extra generic `C` implementing `CurveAffine` trait in order to generate random witness points on the curve in question // Using Simultaneous 2^w-Ary Method, see https://www.bmoeller.de/pdf/multiexp-sac2001.pdf // Random Accumlation point trick learned from halo2wrong: https://hackmd.io/ncuKqRXzR-Cw-Au2fGzsMg?view // Input: // - `scalars` is vector of same length as `P` // - each `scalar` in `scalars` satisfies same assumptions as in `scalar_multiply` above -pub fn multi_scalar_multiply<'v, F: PrimeField, FC, C>( + +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +/// * `scalars[i]` is less than the order of `P` +/// * `scalars[i][j] < 2^{max_bits} for all j` +/// * `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` +/// * `points` are all on the curve or the point at infinity +/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) +/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point +pub fn multi_scalar_multiply( chip: &FC, - ctx: &mut Context<'v, F>, - P: &[EcPoint>], - scalars: &[Vec>], + ctx: &mut Context, + P: &[EcPoint], + scalars: Vec>>, max_bits: usize, window_bits: usize, -) -> EcPoint> +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, C: CurveAffineExt, { let k = P.len(); @@ -406,22 +675,20 @@ where let num_windows = (total_bits + window_bits - 1) / window_bits; let rounded_bitlen = num_windows * window_bits; - let zero_cell = chip.gate().load_zero(ctx); + let zero_cell = ctx.load_zero(); let rounded_bits = scalars - .iter() + .into_iter() .flat_map(|scalar| { - assert_eq!(scalar.len(), scalar_len); + debug_assert_eq!(scalar.len(), scalar_len); scalar - .iter() + .into_iter() .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits)) - .chain( - std::iter::repeat_with(|| zero_cell.clone()).take(rounded_bitlen - total_bits), - ) + .chain(std::iter::repeat(zero_cell).take(rounded_bitlen - total_bits)) .collect_vec() }) .collect_vec(); - // load random C point as witness + // load any sufficiently generic C point as witness // note that while we load a random point, an adversary would load a specifically chosen point, so we must carefully handle edge cases with constraints let base = load_random_point::(chip, ctx); // contains random base points [A, ..., 2^{w + k - 1} * A] @@ -446,19 +713,19 @@ where ctx, &rand_start_vec[idx], &rand_start_vec[idx + window_bits], - false, + true, // not necessary if we assume (2^w - 1) * A != +- A, but put in for safety ); - chip.enforce_less_than(ctx, point.x()); - chip.enforce_less_than(ctx, neg_mult_rand_start.x()); + let point = into_strict_point(chip, ctx, point.clone()); + let neg_mult_rand_start = into_strict_point(chip, ctx, neg_mult_rand_start); // cached_points[i][0..cache_size] stores (1 - 2^w) * 2^i * A + [0..cache_size] * P_i cached_points.push(neg_mult_rand_start); for _ in 0..(cache_size - 1) { - let prev = cached_points.last().unwrap(); + let prev = cached_points.last().unwrap().clone(); // adversary could pick `A` so add equal case occurs, so we must use strict add_unequal - let mut new_point = ec_add_unequal(chip, ctx, prev, point, true); + let mut new_point = ec_add_unequal(chip, ctx, &prev, &point, true); // special case for when P[idx] = O - new_point = ec_select(chip, ctx, prev, &new_point, &is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, prev.into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); cached_points.push(new_point); } } @@ -467,39 +734,35 @@ where // note k can be large (e.g., 800) so 2^{k+1} may be larger than the order of A // random fact: 2^{k + 1} - 1 can be prime: see Mersenne primes // TODO: I don't see a way to rule out 2^{k+1} A = +-A case in general, so will use strict sub_unequal - let start_point = if k < F::CAPACITY as usize { - ec_sub_unequal(chip, ctx, &rand_start_vec[k], &rand_start_vec[0], false) - } else { - chip.enforce_less_than(ctx, rand_start_vec[k].x()); - chip.enforce_less_than(ctx, rand_start_vec[0].x()); - ec_sub_unequal(chip, ctx, &rand_start_vec[k], &rand_start_vec[0], true) - }; + let start_point = ec_sub_unequal( + chip, + ctx, + &rand_start_vec[k], + &rand_start_vec[0], + true, // k >= F::CAPACITY as usize, // this assumed random points on `C` were of prime order equal to modulus of `F`. Since this is easily missed, we turn on strict mode always + ); let mut curr_point = start_point.clone(); // compute \sum_i x_i P_i + (2^{k + 1} - 1) * A for idx in 0..num_windows { for _ in 0..window_bits { - curr_point = ec_double(chip, ctx, &curr_point); + curr_point = ec_double(chip, ctx, curr_point); } - for (cached_points, rounded_bits) in cached_points - .chunks(cache_size) - .zip(rounded_bits.chunks(rounded_bitlen)) + for (cached_points, rounded_bits) in + cached_points.chunks(cache_size).zip(rounded_bits.chunks(rounded_bitlen)) { - let add_point = ec_select_from_bits::( + let add_point = ec_select_from_bits( chip, ctx, cached_points, &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - chip.enforce_less_than(ctx, curr_point.x()); // this all needs strict add_unequal since A can be non-randomly chosen by adversary - curr_point = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + curr_point = ec_add_unequal(chip, ctx, curr_point, add_point, true); } } - chip.enforce_less_than(ctx, start_point.x()); - chip.enforce_less_than(ctx, curr_point.x()); - ec_sub_unequal(chip, ctx, &curr_point, &start_point, true) + ec_sub_strict(chip, ctx, curr_point, start_point) } pub fn get_naf(mut exp: Vec) -> Vec { @@ -546,247 +809,278 @@ pub fn get_naf(mut exp: Vec) -> Vec { naf } -pub type BaseFieldEccChip = EccChip< +pub type BaseFieldEccChip<'chip, C> = EccChip< + 'chip, ::ScalarExt, - FpConfig<::ScalarExt, ::Base>, + FpChip<'chip, ::ScalarExt, ::Base>, >; #[derive(Clone, Debug)] -pub struct EccChip> { - pub field_chip: FC, +pub struct EccChip<'chip, F: PrimeField, FC: FieldChip> { + pub field_chip: &'chip FC, _marker: PhantomData, } -impl> EccChip { - pub fn construct(field_chip: FC) -> Self { +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { + pub fn new(field_chip: &'chip FC) -> Self { Self { field_chip, _marker: PhantomData } } pub fn field_chip(&self) -> &FC { - &self.field_chip + self.field_chip } - pub fn load_private<'v>( + /// Load affine point as private witness. Constrains witness to lie on curve. Does not allow (0, 0) point, + pub fn load_private( &self, - ctx: &mut Context<'_, F>, - point: (Value, Value), - ) -> EcPoint> { - let (x, y) = (FC::fe_to_witness(&point.0), FC::fe_to_witness(&point.1)); + ctx: &mut Context, + (x, y): (FC::FieldType, FC::FieldType), + ) -> EcPoint + where + C: CurveAffineExt, + { + let pt = self.load_private_unchecked(ctx, (x, y)); + self.assert_is_on_curve::(ctx, &pt); + pt + } + /// Does not constrain witness to lie on curve + pub fn load_private_unchecked( + &self, + ctx: &mut Context, + (x, y): (FC::FieldType, FC::FieldType), + ) -> EcPoint { let x_assigned = self.field_chip.load_private(ctx, x); let y_assigned = self.field_chip.load_private(ctx, y); - EcPoint::construct(x_assigned, y_assigned) + EcPoint::new(x_assigned, y_assigned) } - /// Does not constrain witness to lie on curve - pub fn assign_point<'v, C>( - &self, - ctx: &mut Context<'_, F>, - g: Value, - ) -> EcPoint> + /// Load affine point as private witness. Constrains witness to either lie on curve or be the point at infinity, + /// represented in affine coordinates as (0, 0). + pub fn assign_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, + C::Base: ff::PrimeField, { - let (x, y) = g.map(|g| g.into_coordinates()).unzip(); - self.load_private(ctx, (x, y)) + let pt = self.assign_point_unchecked(ctx, g); + let is_on_curve = self.is_on_curve_or_infinity::(ctx, &pt); + self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::one()); + pt } - pub fn assign_constant_point<'v, C>( + /// Does not constrain witness to lie on curve + pub fn assign_point_unchecked( &self, - ctx: &mut Context<'_, F>, + ctx: &mut Context, g: C, - ) -> EcPoint> + ) -> EcPoint + where + C: CurveAffineExt, + { + let (x, y) = g.into_coordinates(); + self.load_private_unchecked(ctx, (x, y)) + } + + pub fn assign_constant_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, { let (x, y) = g.into_coordinates(); - let [x, y] = [x, y].map(FC::fe_to_constant); let x = self.field_chip.load_constant(ctx, x); let y = self.field_chip.load_constant(ctx, y); - EcPoint::construct(x, y) + EcPoint::new(x, y) } - pub fn load_random_point<'v, C>( - &self, - ctx: &mut Context<'v, F>, - ) -> EcPoint> + pub fn load_random_point(&self, ctx: &mut Context) -> EcPoint where C: CurveAffineExt, { load_random_point::(self.field_chip(), ctx) } - pub fn assert_is_on_curve<'v, C>( - &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) where + pub fn assert_is_on_curve(&self, ctx: &mut Context, P: &EcPoint) + where C: CurveAffine, { - is_on_curve::(&self.field_chip, ctx, P) + check_is_on_curve::(self.field_chip, ctx, P) } - pub fn is_on_curve_or_infinity<'v, C>( + pub fn is_on_curve_or_infinity( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) -> AssignedValue<'v, F> + ctx: &mut Context, + P: &EcPoint, + ) -> AssignedValue where C: CurveAffine, - C::Base: ff::PrimeField, { let lhs = self.field_chip.mul_no_carry(ctx, &P.y, &P.y); - let mut rhs = self.field_chip.mul(ctx, &P.x, &P.x); - rhs = self.field_chip.mul_no_carry(ctx, &rhs, &P.x); + let mut rhs = self.field_chip.mul(ctx, &P.x, &P.x).into(); + rhs = self.field_chip.mul_no_carry(ctx, rhs, &P.x); - let b = FC::fe_to_constant(C::b()); - rhs = self.field_chip.add_constant_no_carry(ctx, &rhs, b); - let mut diff = self.field_chip.sub_no_carry(ctx, &lhs, &rhs); - diff = self.field_chip.carry_mod(ctx, &diff); + rhs = self.field_chip.add_constant_no_carry(ctx, rhs, C::b()); + let diff = self.field_chip.sub_no_carry(ctx, lhs, rhs); + let diff = self.field_chip.carry_mod(ctx, diff); - let is_on_curve = self.field_chip.is_zero(ctx, &diff); + let is_on_curve = self.field_chip.is_zero(ctx, diff); let x_is_zero = self.field_chip.is_zero(ctx, &P.x); let y_is_zero = self.field_chip.is_zero(ctx, &P.y); - self.field_chip.range().gate().or_and( - ctx, - Existing(&is_on_curve), - Existing(&x_is_zero), - Existing(&y_is_zero), - ) + self.field_chip.range().gate().or_and(ctx, is_on_curve, x_is_zero, y_is_zero) } - pub fn negate<'v>( + pub fn negate( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) -> EcPoint> { - EcPoint::construct(P.x.clone(), self.field_chip.negate(ctx, &P.y)) + ctx: &mut Context, + P: impl Into>, + ) -> EcPoint { + let P = P.into(); + EcPoint::new(P.x, self.field_chip.negate(ctx, P.y)) } /// Assumes that P.x != Q.x /// If `is_strict == true`, then actually constrains that `P.x != Q.x` - pub fn add_unequal<'v>( + pub fn add_unequal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, - ) -> EcPoint> { - ec_add_unequal(&self.field_chip, ctx, P, Q, is_strict) + ) -> EcPoint { + ec_add_unequal(self.field_chip, ctx, P, Q, is_strict) } /// Assumes that P.x != Q.x /// Otherwise will panic - pub fn sub_unequal<'v>( + pub fn sub_unequal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, - ) -> EcPoint> { - ec_sub_unequal(&self.field_chip, ctx, P, Q, is_strict) + ) -> EcPoint { + ec_sub_unequal(self.field_chip, ctx, P, Q, is_strict) } - pub fn double<'v>( + pub fn double( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) -> EcPoint> { - ec_double(&self.field_chip, ctx, P) + ctx: &mut Context, + P: impl Into>, + ) -> EcPoint { + ec_double(self.field_chip, ctx, P) } - pub fn is_equal<'v>( + pub fn is_equal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, - ) -> AssignedValue<'v, F> { + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, + ) -> AssignedValue { // TODO: optimize - let x_is_equal = self.field_chip.is_equal(ctx, &P.x, &Q.x); - let y_is_equal = self.field_chip.is_equal(ctx, &P.y, &Q.y); - self.field_chip.range().gate().and(ctx, Existing(&x_is_equal), Existing(&y_is_equal)) + let x_is_equal = self.field_chip.is_equal(ctx, P.x, Q.x); + let y_is_equal = self.field_chip.is_equal(ctx, P.y, Q.y); + self.field_chip.range().gate().and(ctx, x_is_equal, y_is_equal) } - pub fn assert_equal<'v>( + pub fn assert_equal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, ) { - self.field_chip.assert_equal(ctx, &P.x, &Q.x); - self.field_chip.assert_equal(ctx, &P.y, &Q.y); + self.field_chip.assert_equal(ctx, P.x, Q.x); + self.field_chip.assert_equal(ctx, P.y, Q.y); } - pub fn sum<'b, 'v: 'b, C>( + /// None of elements in `points` can be point at infinity. + pub fn sum( &self, - ctx: &mut Context<'v, F>, - points: impl Iterator>>, - ) -> EcPoint> + ctx: &mut Context, + points: impl IntoIterator>, + ) -> EcPoint where C: CurveAffineExt, - FC::FieldPoint<'v>: 'b, { let rand_point = self.load_random_point::(ctx); - self.field_chip.enforce_less_than(ctx, rand_point.x()); + let rand_point = into_strict_point(self.field_chip, ctx, rand_point); let mut acc = rand_point.clone(); for point in points { - self.field_chip.enforce_less_than(ctx, point.x()); - acc = self.add_unequal(ctx, &acc, point, true); - self.field_chip.enforce_less_than(ctx, acc.x()); + let _acc = self.add_unequal(ctx, acc, point, true); + acc = into_strict_point(self.field_chip, ctx, _acc); } - self.sub_unequal(ctx, &acc, &rand_point, true) + self.sub_unequal(ctx, acc, rand_point, true) } } -impl> EccChip +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> where - for<'v> FC: Selectable = FC::FieldPoint<'v>>, + FC: Selectable, { - pub fn select<'v>( + pub fn select( &self, - ctx: &mut Context<'_, F>, - P: &EcPoint>, - Q: &EcPoint>, - condition: &AssignedValue<'v, F>, - ) -> EcPoint> { - ec_select(&self.field_chip, ctx, P, Q, condition) + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, + condition: AssignedValue, + ) -> EcPoint { + ec_select(self.field_chip, ctx, P, Q, condition) } - pub fn scalar_mult<'v>( + /// See [`scalar_multiply`] for more details. + pub fn scalar_mult( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - scalar: &Vec>, + ctx: &mut Context, + P: EcPoint, + scalar: Vec>, max_bits: usize, window_bits: usize, - ) -> EcPoint> { - scalar_multiply::(&self.field_chip, ctx, P, scalar, max_bits, window_bits) + ) -> EcPoint + where + C: CurveAffineExt, + { + scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) } - // TODO: put a check in place that scalar is < modulus of C::Scalar - pub fn variable_base_msm<'v, C>( + // default for most purposes + /// See [`pippenger::multi_exp_par`] for more details. + pub fn variable_base_msm( + &self, + thread_pool: &mut GateThreadBuilder, + P: &[EcPoint], + scalars: Vec>>, + max_bits: usize, + ) -> EcPoint + where + C: CurveAffineExt, + FC: Selectable, + { + // window_bits = 4 is optimal from empirical observations + self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) + } + + // TODO: add asserts to validate input assumptions described in docs + pub fn variable_base_msm_in( &self, - ctx: &mut Context<'v, F>, - P: &[EcPoint>], - scalars: &[Vec>], + builder: &mut GateThreadBuilder, + P: &[EcPoint], + scalars: Vec>>, max_bits: usize, window_bits: usize, - ) -> EcPoint> + phase: usize, + ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, + FC: Selectable, { #[cfg(feature = "display")] println!("computing length {} MSM", P.len()); if P.len() <= 25 { multi_scalar_multiply::( - &self.field_chip, - ctx, + self.field_chip, + builder.main(phase), P, scalars, max_bits, @@ -800,40 +1094,37 @@ where if radix == 0 { radix = 1; }*/ - let radix = 1; - pippenger::multi_exp::( - &self.field_chip, - ctx, + // guessing that is is always better to use parallelism for >25 points + pippenger::multi_exp_par::( + self.field_chip, + builder, P, scalars, max_bits, - radix, - window_bits, + window_bits, // clump_factor := window_bits + phase, ) } } } -impl> EccChip -where - FC::FieldType: PrimeField, -{ +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { + /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar - pub fn fixed_base_scalar_mult<'v, C>( + pub fn fixed_base_scalar_mult( &self, - ctx: &mut Context<'v, F>, + ctx: &mut Context, point: &C, - scalar: &[AssignedValue<'v, F>], + scalar: Vec>, max_bits: usize, window_bits: usize, - ) -> EcPoint> + ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { fixed_base::scalar_multiply::( - &self.field_chip, + self.field_chip, ctx, point, scalar, @@ -842,30 +1133,52 @@ where ) } - /// `radix = 0` means auto-calculate - /// + // default for most purposes + pub fn fixed_base_msm( + &self, + builder: &mut GateThreadBuilder, + points: &[C], + scalars: Vec>>, + max_scalar_bits_per_cell: usize, + ) -> EcPoint + where + C: CurveAffineExt, + FC: FieldChip + Selectable, + { + self.fixed_base_msm_in::(builder, points, scalars, max_scalar_bits_per_cell, 4, 0) + } + + // `radix = 0` means auto-calculate + // /// `clump_factor = 0` means auto-calculate /// /// The user should filter out base points that are identity beforehand; we do not separately do this here - pub fn fixed_base_msm<'v, C>( + pub fn fixed_base_msm_in( &self, - ctx: &mut Context<'v, F>, + builder: &mut GateThreadBuilder, points: &[C], - scalars: &[Vec>], + scalars: Vec>>, max_scalar_bits_per_cell: usize, - _radix: usize, clump_factor: usize, - ) -> EcPoint> + phase: usize, + ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { assert_eq!(points.len(), scalars.len()); #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - fixed_base::msm(self, ctx, points, scalars, max_scalar_bits_per_cell, clump_factor) + fixed_base::msm_par( + self, + builder, + points, + scalars, + max_scalar_bits_per_cell, + clump_factor, + phase, + ) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 4598ab1a..934a7432 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -1,12 +1,18 @@ use super::{ - ec_add_unequal, ec_double, ec_select, ec_select_from_bits, ec_sub_unequal, load_random_point, - EcPoint, + ec_add_unequal, ec_double, ec_select, ec_sub_unequal, into_strict_point, load_random_point, + strict_ec_select_from_bits, EcPoint, +}; +use crate::{ + ecc::ec_sub_strict, + fields::{FieldChip, PrimeField, Selectable}, }; -use crate::fields::{FieldChip, Selectable}; use halo2_base::{ - gates::GateInstructions, - utils::{CurveAffineExt, PrimeField}, - AssignedValue, Context, + gates::{ + builder::{parallelize_in, GateThreadBuilder}, + GateInstructions, + }, + utils::CurveAffineExt, + AssignedValue, }; // Reference: https://jbootle.github.io/Misc/pippenger.pdf @@ -15,14 +21,17 @@ use halo2_base::{ // Output: // * new_points: length `points.len() * radix` // * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix` -pub fn decompose<'v, F, FC>( +// +// Empirically `radix = 1` is best, so we don't use this function for now +/* +pub fn decompose( chip: &FC, - ctx: &mut Context<'v, F>, - points: &[EcPoint>], - scalars: &[Vec>], + ctx: &mut Context, + points: &[EcPoint], + scalars: &[Vec>], max_scalar_bits_per_cell: usize, radix: usize, -) -> (Vec>>, Vec>>) +) -> (Vec>, Vec>>) where F: PrimeField, FC: FieldChip, @@ -34,7 +43,7 @@ where let mut new_points = Vec::with_capacity(radix * points.len()); let mut new_bool_scalars = vec![Vec::with_capacity(radix * points.len()); t]; - let zero_cell = chip.gate().load_zero(ctx); + let zero_cell = ctx.load_zero(); for (point, scalar) in points.iter().zip(scalars.iter()) { assert_eq!(scalars[0].len(), scalar.len()); let mut g = point.clone(); @@ -46,7 +55,7 @@ where } let mut bits = Vec::with_capacity(scalar_bits); for x in scalar { - let mut new_bits = chip.gate().num_to_bits(ctx, x, max_scalar_bits_per_cell); + let mut new_bits = chip.gate().num_to_bits(ctx, *x, max_scalar_bits_per_cell); bits.append(&mut new_bits); } for k in 0..t { @@ -58,19 +67,21 @@ where (new_points, new_bool_scalars) } +*/ +/* Left as reference; should always use msm_par // Given points[i] and bool_scalars[j][i], // compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i] // output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point -pub fn multi_product<'v, F: PrimeField, FC, C>( +pub fn multi_product( chip: &FC, - ctx: &mut Context<'v, F>, - points: &[EcPoint>], - bool_scalars: &[Vec>], + ctx: &mut Context, + points: &[EcPoint], + bool_scalars: &[Vec>], clumping_factor: usize, -) -> (Vec>>, EcPoint>) +) -> (Vec>, EcPoint) where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { let c = clumping_factor; // this is `b` in Section 3 of Bootle @@ -79,127 +90,252 @@ where // we use a trick from halo2wrong where we load a random C point as witness // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints // TODO: an alternate approach is to use Fiat-Shamir transform (with Poseidon) to hash all the inputs (points, bool_scalars, ...) to get the random point. This could be worth it for large MSMs as we get savings from `add_unequal` in "non-strict" mode. Perhaps not worth the trouble / security concern, though. - let rand_base = load_random_point::(chip, ctx); + let any_base = load_random_point::(chip, ctx); let mut acc = Vec::with_capacity(bool_scalars.len()); let mut bucket = Vec::with_capacity(1 << c); - let mut rand_point = rand_base.clone(); + let mut any_point = any_base.clone(); for (round, points_clump) in points.chunks(c).enumerate() { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] // for later addition collision-prevension, we need a different random point per round // we take 2^round * rand_base if round > 0 { - rand_point = ec_double(chip, ctx, &rand_point); + any_point = ec_double(chip, ctx, any_point); } // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... } // since rand_point is random, we can always use add_unequal (with strict constraint checking that the points are indeed unequal and not negative of each other) bucket.clear(); - chip.enforce_less_than(ctx, rand_point.x()); - bucket.push(rand_point.clone()); + let strict_any_point = into_strict_point(chip, ctx, any_point.clone()); + bucket.push(strict_any_point); for (i, point) in points_clump.iter().enumerate() { // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates // this can be checked by points[i].y == 0 iff points[i] == O let is_infinity = chip.is_zero(ctx, &point.y); - chip.enforce_less_than(ctx, point.x()); + let point = into_strict_point(chip, ctx, point.clone()); for j in 0..(1 << i) { - let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], point, true); + let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true); // if points[i] is point at infinity, do nothing - new_point = ec_select(chip, ctx, &bucket[j], &new_point, &is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); bucket.push(new_point); } } // for each j, select using clump in e[j][i=...] for (j, bits) in bool_scalars.iter().enumerate() { - let multi_prod = ec_select_from_bits::( + let multi_prod = strict_ec_select_from_bits( chip, ctx, &bucket, &bits[round * c..round * c + points_clump.len()], ); + // since `bucket` is all `StrictEcPoint` and we are selecting from it, we know `multi_prod` is StrictEcPoint // everything in bucket has already been enforced if round == 0 { acc.push(multi_prod); } else { - acc[j] = ec_add_unequal(chip, ctx, &acc[j], &multi_prod, true); - chip.enforce_less_than(ctx, acc[j].x()); + let _acc = ec_add_unequal(chip, ctx, &acc[j], multi_prod, true); + acc[j] = into_strict_point(chip, ctx, _acc); } } } // we have acc[j] = G'[j] + (2^num_rounds - 1) * rand_base - rand_point = ec_double(chip, ctx, &rand_point); - rand_point = ec_sub_unequal(chip, ctx, &rand_point, &rand_base, false); + any_point = ec_double(chip, ctx, any_point); + any_point = ec_sub_unequal(chip, ctx, any_point, any_base, false); - (acc, rand_point) + (acc, any_point) } -pub fn multi_exp<'v, F: PrimeField, FC, C>( +/// Currently does not support if the final answer is actually the point at infinity (meaning constraints will fail in that case) +/// +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +pub fn multi_exp( chip: &FC, - ctx: &mut Context<'v, F>, - points: &[EcPoint>], - scalars: &[Vec>], + ctx: &mut Context, + points: &[EcPoint], + scalars: Vec>>, max_scalar_bits_per_cell: usize, - radix: usize, + // radix: usize, // specialize to radix = 1 clump_factor: usize, -) -> EcPoint> +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { - let (points, bool_scalars) = - decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); - - /* - let t = bool_scalars.len(); - let c = { - let m = points.len(); - let cost = |b: usize| -> usize { (m + b - 1) / b * ((1 << b) + t) }; - let c_max: usize = f64::from(points.len() as u32).log2().ceil() as usize; - let mut c_best = c_max; - for b in 1..c_max { - if cost(b) <= cost(c_best) { - c_best = b; + // let (points, bool_scalars) = decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); + + debug_assert_eq!(points.len(), scalars.len()); + let scalar_bits = max_scalar_bits_per_cell * scalars[0].len(); + // bool_scalars: 2d array `scalar_bits` by `points.len()` + let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; + for scalar in scalars { + for (scalar_chunk, bool_chunk) in + scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell)) + { + let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell); + for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) { + bool_bit.push(bit); } } - c_best - }; - #[cfg(feature = "display")] - dbg!(clump_factor); - */ - - let (mut agg, rand_point) = - multi_product::(chip, ctx, &points, &bool_scalars, clump_factor); + } + + let (mut agg, any_point) = + multi_product::(chip, ctx, points, &bool_scalars, clump_factor); // everything in agg has been enforced // compute sum_{k=0..t} agg[k] * 2^{radix * k} - (sum_k 2^{radix * k}) * rand_point - // (sum_{k=0..t} 2^{radix * k}) * rand_point = (2^{radix * t} - 1)/(2^radix - 1) - let mut sum = agg.pop().unwrap(); - let mut rand_sum = rand_point.clone(); + // (sum_{k=0..t} 2^{radix * k}) = (2^{radix * t} - 1)/(2^radix - 1) + let mut sum = agg.pop().unwrap().into(); + let mut any_sum = any_point.clone(); for g in agg.iter().rev() { - for _ in 0..radix { - sum = ec_double(chip, ctx, &sum); - rand_sum = ec_double(chip, ctx, &rand_sum); - } - sum = ec_add_unequal(chip, ctx, &sum, g, true); - chip.enforce_less_than(ctx, sum.x()); + any_sum = ec_double(chip, ctx, any_sum); + // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g` + sum = ec_double(chip, ctx, sum); + sum = ec_add_unequal(chip, ctx, sum, g, true); + } + + any_sum = ec_double(chip, ctx, any_sum); + // assume 2^scalar_bits != +-1 mod modulus::() + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false); + + ec_sub_unequal(chip, ctx, sum, any_sum, true) +} +*/ + +/// Multi-thread witness generation for multi-scalar multiplication. +/// +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +/// * `points` are all on the curve or the point at infinity +/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) +/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point +pub fn multi_exp_par( + chip: &FC, + // these are the "threads" within a single Phase + builder: &mut GateThreadBuilder, + points: &[EcPoint], + scalars: Vec>>, + max_scalar_bits_per_cell: usize, + // radix: usize, // specialize to radix = 1 + clump_factor: usize, + phase: usize, +) -> EcPoint +where + FC: FieldChip + Selectable + Selectable, + C: CurveAffineExt, +{ + // let (points, bool_scalars) = decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); - if radix != 1 { - // Can use non-strict as long as some property of the prime is true? - rand_sum = ec_add_unequal(chip, ctx, &rand_sum, &rand_point, false); + assert_eq!(points.len(), scalars.len()); + let scalar_bits = max_scalar_bits_per_cell * scalars[0].len(); + // bool_scalars: 2d array `scalar_bits` by `points.len()` + let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; + + // get a main thread + let ctx = builder.main(phase); + // single-threaded computation: + for scalar in scalars { + for (scalar_chunk, bool_chunk) in + scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell)) + { + let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell); + for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) { + bool_bit.push(bit); + } } } - if radix == 1 { - rand_sum = ec_double(chip, ctx, &rand_sum); - // assume 2^t != +-1 mod modulus::() - rand_sum = ec_sub_unequal(chip, ctx, &rand_sum, &rand_point, false); + let c = clump_factor; + let num_rounds = (points.len() + c - 1) / c; + // to avoid adding two points that are equal or negative of each other, + // we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness + // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints + // we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do + let any_base = load_random_point::(chip, ctx); + let mut any_points = Vec::with_capacity(num_rounds); + any_points.push(any_base); + for _ in 1..num_rounds { + any_points.push(ec_double(chip, ctx, any_points.last().unwrap())); } - chip.enforce_less_than(ctx, rand_sum.x()); - ec_sub_unequal(chip, ctx, &sum, &rand_sum, true) + // now begins multi-threading + // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` + let multi_prods = parallelize_in( + phase, + builder, + points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + |ctx, (round, (points_clump, any_point))| { + // compute all possible multi-products of elements in points[round * c .. round * (c+1)] + // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... } + let mut bucket = Vec::with_capacity(1 << c); + let any_point = into_strict_point(chip, ctx, any_point.clone()); + bucket.push(any_point); + for (i, point) in points_clump.iter().enumerate() { + // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates + // this can be checked by points[i].y == 0 iff points[i] == O + let is_infinity = chip.is_zero(ctx, &point.y); + let point = into_strict_point(chip, ctx, point.clone()); + + for j in 0..(1 << i) { + let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true); + // if points[i] is point at infinity, do nothing + new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); + bucket.push(new_point); + } + } + bool_scalars + .iter() + .map(|bits| { + strict_ec_select_from_bits( + chip, + ctx, + &bucket, + &bits[round * c..round * c + points_clump.len()], + ) + }) + .collect::>() + }, + ); + + // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits + let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut acc = multi_prods[0][i].clone(); + for multi_prod in multi_prods.iter().skip(1) { + let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); + acc = into_strict_point(chip, ctx, _acc); + } + acc + }); + + // gets the LAST thread for single threaded work + let ctx = builder.main(phase); + // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base + // let any_point = (2^num_rounds - 1) * any_base + // TODO: can we remove all these random point operations somehow? + let mut any_point = ec_double(chip, ctx, any_points.last().unwrap()); + any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true); + + // compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point + // (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1) + let mut sum = agg.pop().unwrap().into(); + let mut any_sum = any_point.clone(); + for g in agg.iter().rev() { + any_sum = ec_double(chip, ctx, any_sum); + // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g` + sum = ec_double(chip, ctx, sum); + sum = ec_add_unequal(chip, ctx, sum, g, true); + } + + any_sum = ec_double(chip, ctx, any_sum); + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true); + + ec_sub_strict(chip, ctx, sum, any_sum) } diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index fa9d6ed5..5bbc612e 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -1,6 +1,5 @@ #![allow(unused_assignments, unused_imports, unused_variables)] use super::*; -use crate::fields::fp::{FpConfig, FpStrategy}; use crate::fields::fp2::Fp2Chip; use crate::halo2_proofs::{ circuit::*, @@ -9,158 +8,73 @@ use crate::halo2_proofs::{ plonk::*, }; use group::Group; +use halo2_base::gates::builder::RangeCircuitBuilder; +use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; use halo2_base::SKIP_FIRST_PASS; -use halo2_base::{ - gates::range::RangeStrategy, utils::value_to_option, utils::PrimeField, ContextParams, -}; +use halo2_base::{gates::range::RangeStrategy, utils::value_to_option}; use num_bigint::{BigInt, RandBigInt}; +use rand_core::OsRng; use std::marker::PhantomData; use std::ops::Neg; -#[derive(Default)] -pub struct MyCircuit { - pub P: Option, - pub Q: Option, - pub _marker: PhantomData, -} - -const NUM_ADVICE: usize = 2; -const NUM_FIXED: usize = 2; - -impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { P: None, Q: None, _marker: PhantomData } +fn basic_g1_tests( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + P: G1Affine, + Q: G1Affine, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = EccChip::new(&fp_chip); + + let P_assigned = chip.load_private_unchecked(ctx, (P.x, P.y)); + let Q_assigned = chip.load_private_unchecked(ctx, (Q.x, Q.y)); + + // test add_unequal + chip.field_chip.enforce_less_than(ctx, P_assigned.x().clone()); + chip.field_chip.enforce_less_than(ctx, Q_assigned.x().clone()); + let sum = chip.add_unequal(ctx, &P_assigned, &Q_assigned, false); + assert_eq!(sum.x.0.truncation.to_bigint(limb_bits), sum.x.0.value); + assert_eq!(sum.y.0.truncation.to_bigint(limb_bits), sum.y.0.value); + { + let actual_sum = G1Affine::from(P + Q); + assert_eq!(bigint_to_fe::(&sum.x.0.value), actual_sum.x); + assert_eq!(bigint_to_fe::(&sum.y.0.value), actual_sum.y); } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - 22, - 88, - 3, - modulus::(), - 0, - 23, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_lookup_table(&mut layouter)?; - let chip = EccChip::construct(config.clone()); - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "ecc", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = chip.field_chip().new_context(region); - let ctx = &mut aux; - - let P_assigned = chip.load_private( - ctx, - match self.P { - Some(P) => (Value::known(P.x), Value::known(P.y)), - None => (Value::unknown(), Value::unknown()), - }, - ); - let Q_assigned = chip.load_private( - ctx, - match self.Q { - Some(Q) => (Value::known(Q.x), Value::known(Q.y)), - None => (Value::unknown(), Value::unknown()), - }, - ); - - // test add_unequal - { - chip.field_chip.enforce_less_than(ctx, P_assigned.x()); - chip.field_chip.enforce_less_than(ctx, Q_assigned.x()); - let sum = chip.add_unequal(ctx, &P_assigned, &Q_assigned, false); - assert_eq!( - value_to_option(sum.x.truncation.to_bigint(config.limb_bits)), - value_to_option(sum.x.value.clone()) - ); - assert_eq!( - value_to_option(sum.y.truncation.to_bigint(config.limb_bits)), - value_to_option(sum.y.value.clone()) - ); - if self.P.is_some() { - let actual_sum = G1Affine::from(self.P.unwrap() + self.Q.unwrap()); - sum.x.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_sum.x)); - sum.y.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_sum.y)); - } - println!("add unequal witness OK"); - } - - // test double - { - let doub = chip.double(ctx, &P_assigned); - assert_eq!( - value_to_option(doub.x.truncation.to_bigint(config.limb_bits)), - value_to_option(doub.x.value.clone()) - ); - assert_eq!( - value_to_option(doub.y.truncation.to_bigint(config.limb_bits)), - value_to_option(doub.y.value.clone()) - ); - if self.P.is_some() { - let actual_doub = G1Affine::from(self.P.unwrap() * Fr::from(2u64)); - doub.x.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_doub.x)); - doub.y.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_doub.y)); - } - println!("double witness OK"); - } - - chip.field_chip.finalize(ctx); - - #[cfg(feature = "display")] - { - println!("Using {NUM_ADVICE} advice columns and {NUM_FIXED} fixed columns"); - println!("total advice cells: {}", ctx.total_advice); - let (const_rows, _) = ctx.fixed_stats(); - println!("maximum rows used by a fixed column: {const_rows}"); - } - - Ok(()) - }, - ) + println!("add unequal witness OK"); + + // test double + let doub = chip.double(ctx, &P_assigned); + assert_eq!(doub.x.0.truncation.to_bigint(limb_bits), doub.x.0.value); + assert_eq!(doub.y.0.truncation.to_bigint(limb_bits), doub.y.0.value); + { + let actual_doub = G1Affine::from(P * Fr::from(2u64)); + assert_eq!(bigint_to_fe::(&doub.x.0.value), actual_doub.x); + assert_eq!(bigint_to_fe::(&doub.y.0.value), actual_doub.y); } + println!("double witness OK"); } -#[cfg(test)] #[test] fn test_ecc() { let k = 23; - let mut rng = rand::thread_rng(); + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); - let P = Some(G1Affine::random(&mut rng)); - let Q = Some(G1Affine::random(&mut rng)); + let mut builder = GateThreadBuilder::::mock(); + basic_g1_tests(builder.main(0), k - 1, 88, 3, P, Q); - let circuit = MyCircuit:: { P, Q, _marker: PhantomData }; + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } #[cfg(feature = "dev-graph")] -#[cfg(test)] #[test] fn plot_ecc() { let k = 10; @@ -170,7 +84,14 @@ fn plot_ecc() { root.fill(&WHITE).unwrap(); let root = root.titled("Ecc Layout", ("sans-serif", 60)).unwrap(); - let circuit = MyCircuit::::default(); + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); + + let mut builder = GateThreadBuilder::::keygen(); + basic_g1_tests(builder.main(0), 22, 88, 3, P, Q); + + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::mock(builder); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 1329726a..97bfd8b3 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -1,43 +1,55 @@ -use super::{FieldChip, PrimeFieldChip, Selectable}; +use super::{FieldChip, PrimeField, PrimeFieldChip, Selectable}; use crate::bigint::{ add_no_carry, big_is_equal, big_is_zero, carry_mod, check_carry_mod_to_zero, mul_no_carry, scalar_mul_and_add_no_carry, scalar_mul_no_carry, select, select_by_indicator, sub, - sub_no_carry, CRTInteger, FixedCRTInteger, OverflowInteger, -}; -use crate::halo2_proofs::{ - circuit::{Layouter, Region, Value}, - halo2curves::CurveAffine, - plonk::{ConstraintSystem, Error}, + sub_no_carry, CRTInteger, FixedCRTInteger, OverflowInteger, ProperCrtUint, ProperUint, }; +use crate::halo2_proofs::halo2curves::CurveAffine; +use halo2_base::gates::RangeChip; +use halo2_base::utils::ScalarField; use halo2_base::{ - gates::{ - range::{RangeConfig, RangeStrategy}, - GateInstructions, RangeInstructions, - }, - utils::{ - bigint_to_fe, biguint_to_fe, bit_length, decompose_bigint_option, decompose_biguint, - fe_to_biguint, modulus, PrimeField, - }, - AssignedValue, Context, ContextParams, + gates::{range::RangeConfig, GateInstructions, RangeInstructions}, + utils::{bigint_to_fe, biguint_to_fe, bit_length, decompose_biguint, fe_to_biguint, modulus}, + AssignedValue, Context, QuantumCell::{Constant, Existing}, }; use num_bigint::{BigInt, BigUint}; use num_traits::One; -use serde::{Deserialize, Serialize}; use std::{cmp::max, marker::PhantomData}; -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub enum FpStrategy { - Simple, - SimplePlus, +pub type BaseFieldChip<'range, C> = + FpChip<'range, ::ScalarExt, ::Base>; + +pub type FpConfig = RangeConfig; + +/// Wrapper around `FieldPoint` to guarantee this is a "reduced" representation of an `Fp` field element. +/// A reduced representation guarantees that there is a *unique* representation of each field element. +/// Typically this means Uints that are less than the modulus. +#[derive(Clone, Debug)] +pub struct Reduced(pub(crate) FieldPoint, PhantomData); + +impl Reduced { + pub fn as_ref(&self) -> Reduced<&FieldPoint, Fp> { + Reduced(&self.0, PhantomData) + } + + pub fn inner(&self) -> &FieldPoint { + &self.0 + } +} + +impl From, Fp>> for ProperCrtUint { + fn from(x: Reduced, Fp>) -> Self { + x.0 + } } -pub type BaseFieldChip = FpConfig<::ScalarExt, ::Base>; +// `Fp` always needs to be `BigPrimeField`, we may later want support for `F` being just `ScalarField` but for optimization reasons we'll assume it's also `BigPrimeField` for now #[derive(Clone, Debug)] -pub struct FpConfig { - pub range: RangeConfig, - // pub bigint_chip: BigIntConfig, +pub struct FpChip<'range, F: PrimeField, Fp: PrimeField> { + pub range: &'range RangeChip, + pub limb_bits: usize, pub num_limbs: usize, @@ -55,45 +67,13 @@ pub struct FpConfig { _marker: PhantomData, } -impl FpConfig { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - gate_context_id: usize, - k: usize, - ) -> Self { - let range = RangeConfig::::configure( - meta, - match strategy { - FpStrategy::Simple => RangeStrategy::Vertical, - FpStrategy::SimplePlus => RangeStrategy::PlonkPlus, - }, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - gate_context_id, - k, - ); - - Self::construct(range, limb_bits, num_limbs, p) - } - - pub fn construct( - range: RangeConfig, - // bigint_chip: BigIntConfig, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - ) -> Self { +impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { + pub fn new(range: &'range RangeChip, limb_bits: usize, num_limbs: usize) -> Self { + assert!(limb_bits > 0); + assert!(num_limbs > 0); + assert!(limb_bits <= F::CAPACITY as usize); let limb_mask = (BigUint::from(1u64) << limb_bits) - 1usize; + let p = modulus::(); let p_limbs = decompose_biguint(&p, num_limbs, limb_bits); let native_modulus = modulus::(); let p_native = biguint_to_fe(&(&p % &native_modulus)); @@ -105,9 +85,8 @@ impl FpConfig { limb_bases.push(limb_base * limb_bases.last().unwrap()); } - FpConfig { + Self { range, - // bigint_chip, limb_bits, num_limbs, num_limbs_bits: bit_length(num_limbs as u64), @@ -123,54 +102,37 @@ impl FpConfig { } } - pub fn new_context<'a, 'b>(&'b self, region: Region<'a, F>) -> Context<'a, F> { - Context::new( - region, - ContextParams { - max_rows: self.range.gate.max_rows, - num_context_ids: 1, - fixed_columns: self.range.gate.constants.clone(), - }, - ) - } - - pub fn load_lookup_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - self.range.load_lookup_table(layouter) - } - - pub fn enforce_less_than_p<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) { + pub fn enforce_less_than_p(&self, ctx: &mut Context, a: ProperCrtUint) { // a < p iff a - p has underflow let mut borrow: Option> = None; - for (p_limb, a_limb) in self.p_limbs.iter().zip(a.truncation.limbs.iter()) { + for (&p_limb, a_limb) in self.p_limbs.iter().zip(a.0.truncation.limbs) { let lt = match borrow { - None => self.range.is_less_than( - ctx, - Existing(a_limb), - Constant(*p_limb), - self.limb_bits, - ), + None => self.range.is_less_than(ctx, a_limb, Constant(p_limb), self.limb_bits), Some(borrow) => { - let plus_borrow = - self.range.gate.add(ctx, Constant(*p_limb), Existing(&borrow)); + let plus_borrow = self.gate().add(ctx, Constant(p_limb), borrow); self.range.is_less_than( ctx, Existing(a_limb), - Existing(&plus_borrow), + Existing(plus_borrow), self.limb_bits, ) } }; borrow = Some(lt); } - self.range.gate.assert_is_const(ctx, &borrow.unwrap(), F::one()) + self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::one()); } - pub fn finalize(&self, ctx: &mut Context<'_, F>) -> usize { - self.range.finalize(ctx) + pub fn load_constant_uint(&self, ctx: &mut Context, a: BigUint) -> ProperCrtUint { + FixedCRTInteger::from_native(a, self.num_limbs, self.limb_bits).assign( + ctx, + self.limb_bits, + self.native_modulus(), + ) } } -impl PrimeFieldChip for FpConfig { +impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { fn num_limbs(&self) -> usize { self.num_limbs } @@ -182,163 +144,132 @@ impl PrimeFieldChip for FpConfig { } } -impl FieldChip for FpConfig { +impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, Fp> { const PRIME_FIELD_NUM_BITS: u32 = Fp::NUM_BITS; - type ConstantType = BigUint; - type WitnessType = Value; - type FieldPoint<'v> = CRTInteger<'v, F>; + type UnsafeFieldPoint = CRTInteger; + type FieldPoint = ProperCrtUint; + type ReducedFieldPoint = Reduced, Fp>; type FieldType = Fp; - type RangeChip = RangeConfig; + type RangeChip = RangeChip; fn native_modulus(&self) -> &BigUint { &self.native_modulus } - fn range(&self) -> &Self::RangeChip { - &self.range + fn range(&self) -> &'range Self::RangeChip { + self.range } fn limb_bits(&self) -> usize { self.limb_bits } - fn get_assigned_value(&self, x: &CRTInteger) -> Value { - x.value.as_ref().map(|x| bigint_to_fe::(&(x % &self.p))) + fn get_assigned_value(&self, x: &CRTInteger) -> Fp { + bigint_to_fe(&(&x.value % &self.p)) } - fn fe_to_constant(x: Fp) -> BigUint { - fe_to_biguint(&x) - } - - fn fe_to_witness(x: &Value) -> Value { - x.map(|x| BigInt::from(fe_to_biguint(&x))) - } - - fn load_private<'v>(&self, ctx: &mut Context<'_, F>, a: Value) -> CRTInteger<'v, F> { - let a_vec = decompose_bigint_option::(a.as_ref(), self.num_limbs, self.limb_bits); - let limbs = self.range.gate().assign_witnesses(ctx, a_vec); - - let a_native = OverflowInteger::::evaluate( - self.range.gate(), - //&self.bigint_chip, - ctx, - &limbs, - self.limb_bases.iter().cloned(), - ); + fn load_private(&self, ctx: &mut Context, a: Fp) -> ProperCrtUint { + let a = fe_to_biguint(&a); + let a_vec = decompose_biguint::(&a, self.num_limbs, self.limb_bits); + let limbs = ctx.assign_witnesses(a_vec); let a_loaded = - CRTInteger::construct(OverflowInteger::construct(limbs, self.limb_bits), a_native, a); + ProperUint(limbs).into_crt(ctx, self.gate(), a, &self.limb_bases, self.limb_bits); - // TODO: this range check prevents loading witnesses that are not in "proper" representation form, is that ok? - self.range_check(ctx, &a_loaded, Self::PRIME_FIELD_NUM_BITS as usize); + self.range_check(ctx, a_loaded.clone(), Self::PRIME_FIELD_NUM_BITS as usize); a_loaded } - fn load_constant<'v>(&self, ctx: &mut Context<'_, F>, a: BigUint) -> CRTInteger<'v, F> { - let a_native = self.range.gate.assign_region_last( - ctx, - vec![Constant(biguint_to_fe(&(&a % modulus::())))], - vec![], - ); - let a_limbs = self.range.gate().assign_region( - ctx, - decompose_biguint::(&a, self.num_limbs, self.limb_bits).into_iter().map(Constant), - vec![], - ); - - CRTInteger::construct( - OverflowInteger::construct(a_limbs, self.limb_bits), - a_native, - Value::known(BigInt::from(a)), - ) + fn load_constant(&self, ctx: &mut Context, a: Fp) -> ProperCrtUint { + self.load_constant_uint(ctx, fe_to_biguint(&a)) } // signed overflow BigInt functions - fn add_no_carry<'v>( + fn add_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - ) -> CRTInteger<'v, F> { - add_no_carry::crt::(self.range.gate(), ctx, a, b) + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> CRTInteger { + add_no_carry::crt(self.gate(), ctx, a.into(), b.into()) } - fn add_constant_no_carry<'v>( + fn add_constant_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - c: BigUint, - ) -> CRTInteger<'v, F> { - let c = FixedCRTInteger::from_native(c, self.num_limbs, self.limb_bits); + ctx: &mut Context, + a: impl Into>, + c: Fp, + ) -> CRTInteger { + let c = FixedCRTInteger::from_native(fe_to_biguint(&c), self.num_limbs, self.limb_bits); let c_native = biguint_to_fe::(&(&c.value % modulus::())); + let a = a.into(); let mut limbs = Vec::with_capacity(a.truncation.limbs.len()); - for (a_limb, c_limb) in a.truncation.limbs.iter().zip(c.truncation.limbs.into_iter()) { - let limb = self.range.gate.add(ctx, Existing(a_limb), Constant(c_limb)); + for (a_limb, c_limb) in a.truncation.limbs.into_iter().zip(c.truncation.limbs) { + let limb = self.gate().add(ctx, a_limb, Constant(c_limb)); limbs.push(limb); } - let native = self.range.gate.add(ctx, Existing(&a.native), Constant(c_native)); + let native = self.gate().add(ctx, a.native, Constant(c_native)); let trunc = - OverflowInteger::construct(limbs, max(a.truncation.max_limb_bits, self.limb_bits) + 1); - let value = a.value.as_ref().map(|a| a + BigInt::from(c.value)); + OverflowInteger::new(limbs, max(a.truncation.max_limb_bits, self.limb_bits) + 1); + let value = a.value + BigInt::from(c.value); - CRTInteger::construct(trunc, native, value) + CRTInteger::new(trunc, native, value) } - fn sub_no_carry<'v>( + fn sub_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - ) -> CRTInteger<'v, F> { - sub_no_carry::crt::(self.range.gate(), ctx, a, b) + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> CRTInteger { + sub_no_carry::crt::(self.gate(), ctx, a.into(), b.into()) } // Input: a // Output: p - a if a != 0, else a // Assume the actual value of `a` equals `a.truncation` // Constrains a.truncation <= p using subtraction with carries - fn negate<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) -> CRTInteger<'v, F> { + fn negate(&self, ctx: &mut Context, a: ProperCrtUint) -> ProperCrtUint { // Compute p - a.truncation using carries - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); + let p = self.load_constant_uint(ctx, self.p.to_biguint().unwrap()); let (out_or_p, underflow) = - sub::crt::(self.range(), ctx, &p, a, self.limb_bits, self.limb_bases[1]); + sub::crt(self.range(), ctx, p, a.clone(), self.limb_bits, self.limb_bases[1]); // constrain underflow to equal 0 - self.range.gate.assert_is_const(ctx, &underflow, F::zero()); + self.gate().assert_is_const(ctx, &underflow, &F::zero()); - let a_is_zero = big_is_zero::assign::(self.gate(), ctx, &a.truncation); - select::crt::(self.range.gate(), ctx, a, &out_or_p, &a_is_zero) + let a_is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); + ProperCrtUint(select::crt(self.gate(), ctx, a.0, out_or_p, a_is_zero)) } - fn scalar_mul_no_carry<'v>( + fn scalar_mul_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, + ctx: &mut Context, + a: impl Into>, c: i64, - ) -> CRTInteger<'v, F> { - scalar_mul_no_carry::crt::(self.range.gate(), ctx, a, c) + ) -> CRTInteger { + scalar_mul_no_carry::crt(self.gate(), ctx, a.into(), c) } - fn scalar_mul_and_add_no_carry<'v>( + fn scalar_mul_and_add_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, c: i64, - ) -> CRTInteger<'v, F> { - scalar_mul_and_add_no_carry::crt::(self.range.gate(), ctx, a, b, c) + ) -> CRTInteger { + scalar_mul_and_add_no_carry::crt(self.gate(), ctx, a.into(), b.into(), c) } - fn mul_no_carry<'v>( + fn mul_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - ) -> CRTInteger<'v, F> { - mul_no_carry::crt::(self.range.gate(), ctx, a, b, self.num_limbs_log2_ceil) + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> CRTInteger { + mul_no_carry::crt(self.gate(), ctx, a.into(), b.into(), self.num_limbs_log2_ceil) } - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) { + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: CRTInteger) { check_carry_mod_to_zero::crt::( self.range(), - // &self.bigint_chip, ctx, a, self.num_limbs_bits, @@ -351,10 +282,9 @@ impl FieldChip for FpConfig { ) } - fn carry_mod<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) -> CRTInteger<'v, F> { + fn carry_mod(&self, ctx: &mut Context, a: CRTInteger) -> ProperCrtUint { carry_mod::crt::( self.range(), - // &self.bigint_chip, ctx, a, self.num_limbs_bits, @@ -367,123 +297,177 @@ impl FieldChip for FpConfig { ) } - fn range_check<'v>( + /// # Assumptions + /// * `max_bits` in `(n * (k - 1), n * k]` + fn range_check( &self, - ctx: &mut Context<'v, F>, - a: &CRTInteger<'v, F>, + ctx: &mut Context, + a: impl Into>, max_bits: usize, // the maximum bits that a.value could take ) { let n = self.limb_bits; + let a = a.into(); let k = a.truncation.limbs.len(); debug_assert!(max_bits > n * (k - 1) && max_bits <= n * k); let last_limb_bits = max_bits - n * (k - 1); - #[cfg(debug_assertions)] - a.value.as_ref().map(|v| { - debug_assert!(v.bits() as usize <= max_bits); - }); + debug_assert!(a.value.bits() as usize <= max_bits); // range check limbs of `a` are in [0, 2^n) except last limb should be in [0, 2^last_limb_bits) - for (i, cell) in a.truncation.limbs.iter().enumerate() { + for (i, cell) in a.truncation.limbs.into_iter().enumerate() { let limb_bits = if i == k - 1 { last_limb_bits } else { n }; self.range.range_check(ctx, cell, limb_bits); } } - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - self.enforce_less_than_p(ctx, a) - } - - fn is_soft_zero<'v>( + fn enforce_less_than( &self, - ctx: &mut Context<'v, F>, - a: &CRTInteger<'v, F>, - ) -> AssignedValue<'v, F> { - let is_zero = big_is_zero::crt::(self.gate(), ctx, a); - - // underflow != 0 iff carry < p - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); - let (_, underflow) = - sub::crt::(self.range(), ctx, a, &p, self.limb_bits, self.limb_bases[1]); - let is_underflow_zero = self.gate().is_zero(ctx, &underflow); - let range_check = self.gate().not(ctx, Existing(&is_underflow_zero)); - - self.gate().and(ctx, Existing(&is_zero), Existing(&range_check)) + ctx: &mut Context, + a: ProperCrtUint, + ) -> Reduced, Fp> { + self.enforce_less_than_p(ctx, a.clone()); + Reduced(a, PhantomData) } - fn is_soft_nonzero<'v>( + /// Returns 1 iff `a` is 0 as a BigUint. This means that even if `a` is 0 modulo `p`, this may return 0. + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into>, + ) -> AssignedValue { + let a = a.into(); + big_is_zero::positive(self.gate(), ctx, a.0.truncation) + } + + /// Given proper CRT integer `a`, returns 1 iff `a < modulus::()` and `a != 0` as integers + /// + /// # Assumptions + /// * `a` is proper representation of BigUint + fn is_soft_nonzero( &self, - ctx: &mut Context<'v, F>, - a: &CRTInteger<'v, F>, - ) -> AssignedValue<'v, F> { - let is_zero = big_is_zero::crt::(self.gate(), ctx, a); - let is_nonzero = self.gate().not(ctx, Existing(&is_zero)); + ctx: &mut Context, + a: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); + let is_nonzero = self.gate().not(ctx, is_zero); // underflow != 0 iff carry < p - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); + let p = self.load_constant_uint(ctx, self.p.to_biguint().unwrap()); let (_, underflow) = - sub::crt::(self.range(), ctx, a, &p, self.limb_bits, self.limb_bases[1]); - let is_underflow_zero = self.gate().is_zero(ctx, &underflow); - let range_check = self.gate().not(ctx, Existing(&is_underflow_zero)); + sub::crt::(self.range(), ctx, a, p, self.limb_bits, self.limb_bases[1]); + let is_underflow_zero = self.gate().is_zero(ctx, underflow); + let no_underflow = self.gate().not(ctx, is_underflow_zero); - self.gate().and(ctx, Existing(&is_nonzero), Existing(&range_check)) + self.gate().and(ctx, is_nonzero, no_underflow) } // assuming `a` has been range checked to be a proper BigInt // constrain the witness `a` to be `< p` // then check if `a` is 0 - fn is_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) -> AssignedValue<'v, F> { - self.enforce_less_than_p(ctx, a); + fn is_zero(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + let a = a.into(); + self.enforce_less_than_p(ctx, a.clone()); // just check truncated limbs are all 0 since they determine the native value - big_is_zero::positive::(self.gate(), ctx, &a.truncation) + big_is_zero::positive(self.gate(), ctx, a.0.truncation) } - fn is_equal_unenforced<'v>( + fn is_equal_unenforced( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - big_is_equal::assign::(self.gate(), ctx, &a.truncation, &b.truncation) + ctx: &mut Context, + a: Reduced, Fp>, + b: Reduced, Fp>, + ) -> AssignedValue { + big_is_equal::assign::(self.gate(), ctx, a.0, b.0) } // assuming `a, b` have been range checked to be a proper BigInt // constrain the witnesses `a, b` to be `< p` // then assert `a == b` as BigInts - fn assert_equal<'v>( + fn assert_equal( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, ) { - self.enforce_less_than_p(ctx, a); - self.enforce_less_than_p(ctx, b); + let a = a.into(); + let b = b.into(); // a.native and b.native are derived from `a.truncation, b.truncation`, so no need to check if they're equal - for (limb_a, limb_b) in a.truncation.limbs.iter().zip(a.truncation.limbs.iter()) { - self.range.gate.assert_equal(ctx, Existing(limb_a), Existing(limb_b)); + for (limb_a, limb_b) in a.limbs().iter().zip(b.limbs().iter()) { + ctx.constrain_equal(limb_a, limb_b); } + self.enforce_less_than_p(ctx, a); + self.enforce_less_than_p(ctx, b); } } -impl Selectable for FpConfig { - type Point<'v> = CRTInteger<'v, F>; +impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpChip<'range, F, Fp> { + fn select( + &self, + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, + sel: AssignedValue, + ) -> CRTInteger { + select::crt(self.gate(), ctx, a, b, sel) + } + + fn select_by_indicator( + &self, + ctx: &mut Context, + a: &impl AsRef<[CRTInteger]>, + coeffs: &[AssignedValue], + ) -> CRTInteger { + select_by_indicator::crt(self.gate(), ctx, a.as_ref(), coeffs, &self.limb_bases) + } +} + +impl<'range, F: PrimeField, Fp: PrimeField> Selectable> + for FpChip<'range, F, Fp> +{ + fn select( + &self, + ctx: &mut Context, + a: ProperCrtUint, + b: ProperCrtUint, + sel: AssignedValue, + ) -> ProperCrtUint { + ProperCrtUint(select::crt(self.gate(), ctx, a.0, b.0, sel)) + } + + fn select_by_indicator( + &self, + ctx: &mut Context, + a: &impl AsRef<[ProperCrtUint]>, + coeffs: &[AssignedValue], + ) -> ProperCrtUint { + let out = select_by_indicator::crt(self.gate(), ctx, a.as_ref(), coeffs, &self.limb_bases); + ProperCrtUint(out) + } +} - fn select<'v>( +impl Selectable> for FC +where + FC: Selectable, +{ + fn select( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - sel: &AssignedValue<'v, F>, - ) -> CRTInteger<'v, F> { - select::crt::(self.range.gate(), ctx, a, b, sel) + ctx: &mut Context, + a: Reduced, + b: Reduced, + sel: AssignedValue, + ) -> Reduced { + Reduced(self.select(ctx, a.0, b.0, sel), PhantomData) } - fn select_by_indicator<'v>( + fn select_by_indicator( &self, - ctx: &mut Context<'_, F>, - a: &[CRTInteger<'v, F>], - coeffs: &[AssignedValue<'v, F>], - ) -> CRTInteger<'v, F> { - select_by_indicator::crt::(self.range.gate(), ctx, a, coeffs, &self.limb_bases) + ctx: &mut Context, + a: &impl AsRef<[Reduced]>, + coeffs: &[AssignedValue], + ) -> Reduced { + // this is inefficient, could do std::mem::transmute but that is unsafe. hopefully compiler optimizes it out + let a = a.as_ref().iter().map(|a| a.0.clone()).collect::>(); + Reduced(self.select_by_indicator(ctx, &a, coeffs), PhantomData) } } diff --git a/halo2-ecc/src/fields/fp12.rs b/halo2-ecc/src/fields/fp12.rs index f130fd52..156ca452 100644 --- a/halo2-ecc/src/fields/fp12.rs +++ b/halo2-ecc/src/fields/fp12.rs @@ -1,290 +1,167 @@ -use super::{FieldChip, FieldExtConstructor, FieldExtPoint, PrimeFieldChip}; -use crate::halo2_proofs::{arithmetic::Field, circuit::Value}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{fe_to_biguint, value_to_option, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; -use num_bigint::{BigInt, BigUint}; use std::marker::PhantomData; +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; + +use crate::impl_field_ext_chip_common; + +use super::{ + vector::{FieldVector, FieldVectorChip}, + FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, +}; + /// Represent Fp12 point as FqPoint with degree = 12 /// `Fp12 = Fp2[w] / (w^6 - u - xi)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to /// be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp12 point as `\sum_{i = 0}^6 (a_{i0} + a_{i1} * u) * w^i` /// This is encoded in an FqPoint of degree 12 as `(a_{00}, ..., a_{50}, a_{01}, ..., a_{51})` -pub struct Fp12Chip<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp12: Field, const XI_0: i64> -where - FpChip::FieldType: PrimeField, -{ - // for historical reasons, leaving this as a reference - // for the current implementation we could also just use the de-referenced version: `fp_chip: FpChip` - pub fp_chip: &'a FpChip, - _f: PhantomData, - _fp12: PhantomData, -} +#[derive(Clone, Copy, Debug)] +pub struct Fp12Chip<'a, F: PrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( + pub FieldVectorChip<'a, F, FpChip>, + PhantomData, +); impl<'a, F, FpChip, Fp12, const XI_0: i64> Fp12Chip<'a, F, FpChip, Fp12, XI_0> where F: PrimeField, FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp12: Field + FieldExtConstructor, + Fp12: ff::Field, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. - pub fn construct(fp_chip: &'a FpChip) -> Self { - Self { fp_chip, _f: PhantomData, _fp12: PhantomData } + pub fn new(fp_chip: &'a FpChip) -> Self { + assert_eq!( + modulus::() % 4usize, + BigUint::from(3u64), + "p must be 3 (mod 4) for the polynomial u^2 + 1 to be irreducible" + ); + Self(FieldVectorChip::new(fp_chip), PhantomData) + } + + pub fn fp_chip(&self) -> &FpChip { + self.0.fp_chip } - pub fn fp2_mul_no_carry<'v>( + pub fn fp2_mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - fp2_pt: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); - assert_eq!(fp2_pt.coeffs.len(), 2); + ctx: &mut Context, + fp12_pt: FieldVector, + fp2_pt: FieldVector, + ) -> FieldVector { + let fp12_pt = fp12_pt.0; + let fp2_pt = fp2_pt.0; + assert_eq!(fp12_pt.len(), 12); + assert_eq!(fp2_pt.len(), 2); + let fp_chip = self.fp_chip(); let mut out_coeffs = Vec::with_capacity(12); for i in 0..6 { - let coeff1 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &fp2_pt.coeffs[0]); - let coeff2 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &fp2_pt.coeffs[1]); - let coeff = self.fp_chip.sub_no_carry(ctx, &coeff1, &coeff2); + let coeff1 = fp_chip.mul_no_carry(ctx, fp12_pt[i].clone(), fp2_pt[0].clone()); + let coeff2 = fp_chip.mul_no_carry(ctx, fp12_pt[i + 6].clone(), fp2_pt[1].clone()); + let coeff = fp_chip.sub_no_carry(ctx, coeff1, coeff2); out_coeffs.push(coeff); } for i in 0..6 { - let coeff1 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &fp2_pt.coeffs[0]); - let coeff2 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &fp2_pt.coeffs[1]); - let coeff = self.fp_chip.add_no_carry(ctx, &coeff1, &coeff2); + let coeff1 = fp_chip.mul_no_carry(ctx, fp12_pt[i + 6].clone(), fp2_pt[0].clone()); + let coeff2 = fp_chip.mul_no_carry(ctx, fp12_pt[i].clone(), fp2_pt[1].clone()); + let coeff = fp_chip.add_no_carry(ctx, coeff1, coeff2); out_coeffs.push(coeff); } - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // for \sum_i (a_i + b_i u) w^i, returns \sum_i (-1)^i (a_i + b_i u) w^i - pub fn conjugate<'v>( + pub fn conjugate( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); + ctx: &mut Context, + a: FieldVector, + ) -> FieldVector { + let a = a.0; + assert_eq!(a.len(), 12); let coeffs = a - .coeffs - .iter() + .into_iter() .enumerate() - .map(|(i, c)| if i % 2 == 0 { c.clone() } else { self.fp_chip.negate(ctx, c) }) + .map(|(i, c)| if i % 2 == 0 { c } else { self.fp_chip().negate(ctx, c) }) .collect(); - FieldExtPoint::construct(coeffs) + FieldVector(coeffs) } } -/// multiply (a0 + a1 * u) * (XI0 + u) without carry -pub fn mul_no_carry_w6<'v, F: PrimeField, FC: FieldChip, const XI_0: i64>( +/// multiply Fp2 elts: (a0 + a1 * u) * (XI0 + u) without carry +/// +/// # Assumptions +/// * `a` is `Fp2` point represented as `FieldVector` with degree = 2 +pub fn mul_no_carry_w6, const XI_0: i64>( fp_chip: &FC, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, -) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); - let (a0, a1) = (&a.coeffs[0], &a.coeffs[1]); + ctx: &mut Context, + a: FieldVector, +) -> FieldVector { + let [a0, a1]: [_; 2] = a.0.try_into().unwrap(); // (a0 + a1 u) * (XI_0 + u) = (a0 * XI_0 - a1) + (a1 * XI_0 + a0) u with u^2 = -1 // This should fit in the overflow representation if limb_bits is large enough - let a0_xi0 = fp_chip.scalar_mul_no_carry(ctx, a0, XI_0); - let out0_0_nocarry = fp_chip.sub_no_carry(ctx, &a0_xi0, a1); + let a0_xi0 = fp_chip.scalar_mul_no_carry(ctx, a0.clone(), XI_0); + let out0_0_nocarry = fp_chip.sub_no_carry(ctx, a0_xi0, a1.clone()); let out0_1_nocarry = fp_chip.scalar_mul_and_add_no_carry(ctx, a1, a0, XI_0); - FieldExtPoint::construct(vec![out0_0_nocarry, out0_1_nocarry]) + FieldVector(vec![out0_0_nocarry, out0_1_nocarry]) } +// a lot of this is common to any field extension (lots of for loops), but due to the way rust traits work, it is hard to create a common generic trait that does this. The main problem is that if you had a `FieldExtCommon` trait and wanted to implement `FieldChip` for anything with `FieldExtCommon`, rust will stop you because someone could implement `FieldExtCommon` and `FieldChip` for the same type, causing a conflict. +// partially solved using macro + impl<'a, F, FpChip, Fp12, const XI_0: i64> FieldChip for Fp12Chip<'a, F, FpChip, Fp12, XI_0> where F: PrimeField, - FpChip: PrimeFieldChip, ConstantType = BigUint>, + FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp12: Field + FieldExtConstructor, + Fp12: ff::Field + FieldExtConstructor, + FieldVector: From>, + FieldVector: From>, { const PRIME_FIELD_NUM_BITS: u32 = FpChip::FieldType::NUM_BITS; - type ConstantType = Fp12; - type WitnessType = Vec>; - type FieldPoint<'v> = FieldExtPoint>; + type UnsafeFieldPoint = FieldVector; + type FieldPoint = FieldVector; + type ReducedFieldPoint = FieldVector; type FieldType = Fp12; type RangeChip = FpChip::RangeChip; - fn native_modulus(&self) -> &BigUint { - self.fp_chip.native_modulus() - } - fn range(&self) -> &Self::RangeChip { - self.fp_chip.range() - } - - fn limb_bits(&self) -> usize { - self.fp_chip.limb_bits() - } - - fn get_assigned_value(&self, x: &Self::FieldPoint<'_>) -> Value { - assert_eq!(x.coeffs.len(), 12); - let values = x.coeffs.iter().map(|v| self.fp_chip.get_assigned_value(v)); - let values_collected: Value> = values.into_iter().collect(); - values_collected.map(|c| Fp12::new(c.try_into().unwrap())) - } - - fn fe_to_constant(x: Self::FieldType) -> Self::ConstantType { - x - } - fn fe_to_witness(x: &Value) -> Vec> { - match value_to_option(*x) { - Some(x) => { - x.coeffs().iter().map(|c| Value::known(BigInt::from(fe_to_biguint(c)))).collect() - } - None => vec![Value::unknown(); 12], - } - } - - fn load_private<'v>( - &self, - ctx: &mut Context<'_, F>, - coeffs: Vec>, - ) -> Self::FieldPoint<'v> { - assert_eq!(coeffs.len(), 12); - let mut assigned_coeffs = Vec::with_capacity(12); - for a in coeffs { - let assigned_coeff = self.fp_chip.load_private(ctx, a.clone()); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - fn load_constant<'v>(&self, ctx: &mut Context<'_, F>, c: Fp12) -> Self::FieldPoint<'v> { - let mut assigned_coeffs = Vec::with_capacity(12); - for a in &c.coeffs() { - let assigned_coeff = self.fp_chip.load_constant(ctx, fe_to_biguint(a)); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - // signed overflow BigInt functions - fn add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn add_constant_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: Self::ConstantType, - ) -> Self::FieldPoint<'v> { - let c_coeffs = c.coeffs(); - assert_eq!(a.coeffs.len(), c_coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for (a, c) in a.coeffs.iter().zip(c_coeffs.into_iter()) { - let coeff = self.fp_chip.add_constant_no_carry(ctx, a, FpChip::fe_to_constant(c)); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn sub_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.sub_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn negate<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let out_coeff = self.fp_chip.negate(ctx, a_coeff); - out_coeffs.push(out_coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.scalar_mul_no_carry(ctx, &a.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_and_add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = - self.fp_chip.scalar_mul_and_add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Fp12 { + assert_eq!(x.0.len(), 12); + let values = x.0.iter().map(|v| self.fp_chip().get_assigned_value(v)).collect::>(); + Fp12::new(values.try_into().unwrap()) } // w^6 = u + xi for xi = 9 - fn mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), 12); - assert_eq!(b.coeffs.len(), 12); - + fn mul_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + let a = a.into().0; + let b = b.into().0; + assert_eq!(a.len(), 12); + assert_eq!(b.len(), 12); + + let fp_chip = self.fp_chip(); // a = \sum_{i = 0}^5 (a_i * w^i + a_{i + 6} * w^i * u) // b = \sum_{i = 0}^5 (b_i * w^i + b_{i + 6} * w^i * u) - let mut a0b0_coeffs = Vec::with_capacity(11); - let mut a0b1_coeffs = Vec::with_capacity(11); - let mut a1b0_coeffs = Vec::with_capacity(11); - let mut a1b1_coeffs = Vec::with_capacity(11); + let mut a0b0_coeffs: Vec = Vec::with_capacity(11); + let mut a0b1_coeffs: Vec = Vec::with_capacity(11); + let mut a1b0_coeffs: Vec = Vec::with_capacity(11); + let mut a1b1_coeffs: Vec = Vec::with_capacity(11); for i in 0..6 { for j in 0..6 { - let coeff00 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j]); - let coeff01 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j + 6]); - let coeff10 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &b.coeffs[j]); - let coeff11 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &b.coeffs[j + 6]); + let coeff00 = fp_chip.mul_no_carry(ctx, &a[i], &b[j]); + let coeff01 = fp_chip.mul_no_carry(ctx, &a[i], &b[j + 6]); + let coeff10 = fp_chip.mul_no_carry(ctx, &a[i + 6], &b[j]); + let coeff11 = fp_chip.mul_no_carry(ctx, &a[i + 6], &b[j + 6]); if i + j < a0b0_coeffs.len() { - a0b0_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a0b0_coeffs[i + j], &coeff00); - a0b1_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a0b1_coeffs[i + j], &coeff01); - a1b0_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a1b0_coeffs[i + j], &coeff10); - a1b1_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a1b1_coeffs[i + j], &coeff11); + a0b0_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a0b0_coeffs[i + j], coeff00); + a0b1_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a0b1_coeffs[i + j], coeff01); + a1b0_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a1b0_coeffs[i + j], coeff10); + a1b1_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a1b1_coeffs[i + j], coeff11); } else { a0b0_coeffs.push(coeff00); a0b1_coeffs.push(coeff01); @@ -297,10 +174,8 @@ where let mut a0b0_minus_a1b1 = Vec::with_capacity(11); let mut a0b1_plus_a1b0 = Vec::with_capacity(11); for i in 0..11 { - let a0b0_minus_a1b1_entry = - self.fp_chip.sub_no_carry(ctx, &a0b0_coeffs[i], &a1b1_coeffs[i]); - let a0b1_plus_a1b0_entry = - self.fp_chip.add_no_carry(ctx, &a0b1_coeffs[i], &a1b0_coeffs[i]); + let a0b0_minus_a1b1_entry = fp_chip.sub_no_carry(ctx, &a0b0_coeffs[i], &a1b1_coeffs[i]); + let a0b1_plus_a1b0_entry = fp_chip.add_no_carry(ctx, &a0b1_coeffs[i], &a1b0_coeffs[i]); a0b0_minus_a1b1.push(a0b0_minus_a1b1_entry); a0b1_plus_a1b0.push(a0b1_plus_a1b0_entry); @@ -311,13 +186,13 @@ where let mut out_coeffs = Vec::with_capacity(12); for i in 0..6 { if i < 5 { - let mut coeff = self.fp_chip.scalar_mul_and_add_no_carry( + let mut coeff = fp_chip.scalar_mul_and_add_no_carry( ctx, &a0b0_minus_a1b1[i + 6], &a0b0_minus_a1b1[i], XI_0, ); - coeff = self.fp_chip.sub_no_carry(ctx, &coeff, &a0b1_plus_a1b0[i + 6]); + coeff = fp_chip.sub_no_carry(ctx, coeff, &a0b1_plus_a1b0[i + 6]); out_coeffs.push(coeff); } else { out_coeffs.push(a0b0_minus_a1b1[i].clone()); @@ -326,152 +201,18 @@ where for i in 0..6 { if i < 5 { let mut coeff = - self.fp_chip.add_no_carry(ctx, &a0b1_plus_a1b0[i], &a0b0_minus_a1b1[i + 6]); - coeff = self.fp_chip.scalar_mul_and_add_no_carry( - ctx, - &a0b1_plus_a1b0[i + 6], - &coeff, - XI_0, - ); + fp_chip.add_no_carry(ctx, &a0b1_plus_a1b0[i], &a0b0_minus_a1b1[i + 6]); + coeff = + fp_chip.scalar_mul_and_add_no_carry(ctx, &a0b1_plus_a1b0[i + 6], coeff, XI_0); out_coeffs.push(coeff); } else { out_coeffs.push(a0b1_plus_a1b0[i].clone()); } } - Self::FieldPoint::construct(out_coeffs) - } - - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for coeff in &a.coeffs { - self.fp_chip.check_carry_mod_to_zero(ctx, coeff); - } - } - - fn carry_mod<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.carry_mod(ctx, a_coeff); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn range_check<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>, max_bits: usize) { - for a_coeff in &a.coeffs { - self.fp_chip.range_check(ctx, a_coeff, max_bits); - } - } - - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for a_coeff in &a.coeffs { - self.fp_chip.enforce_less_than(ctx, a_coeff) - } - } - - fn is_soft_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() + FieldVector(out_coeffs) } - fn is_soft_nonzero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().or(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&c))); - } else { - acc = Some(coeff); - } - } - acc.unwrap() - } - - fn is_equal_unenforced<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&c))); - } else { - acc = Some(coeff); - } - } - acc.unwrap() - } - - fn assert_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) { - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - self.fp_chip.assert_equal(ctx, a_coeff, b_coeff); - } - } + impl_field_ext_chip_common!(); } mod bn254 { diff --git a/halo2-ecc/src/fields/fp2.rs b/halo2-ecc/src/fields/fp2.rs index 633ae6fa..55e3243a 100644 --- a/halo2-ecc/src/fields/fp2.rs +++ b/halo2-ecc/src/fields/fp2.rs @@ -1,97 +1,66 @@ -use super::{FieldChip, FieldExtConstructor, FieldExtPoint, PrimeFieldChip, Selectable}; -use crate::halo2_proofs::{arithmetic::Field, circuit::Value}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{fe_to_biguint, value_to_option, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; -use num_bigint::{BigInt, BigUint}; +use std::fmt::Debug; use std::marker::PhantomData; -/// Represent Fp2 point as `FieldExtPoint` with degree = 2 +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; + +use crate::impl_field_ext_chip_common; + +use super::{ + vector::{FieldVector, FieldVectorChip}, + FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, +}; + +/// Represent Fp2 point as `FieldVector` with degree = 2 /// `Fp2 = Fp[u] / (u^2 + 1)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp2 point as `a_0 + a_1 * u` where `a_0, a_1 in Fp` -#[derive(Clone, Debug)] -pub struct Fp2Chip<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: Field> -where - FpChip::FieldType: PrimeField, -{ - // for historical reasons, leaving this as a reference - // for the current implementation we could also just use the de-referenced version: `fp_chip: FpChip` - pub fp_chip: &'a FpChip, - _f: PhantomData, - _fp2: PhantomData, -} +#[derive(Clone, Copy, Debug)] +pub struct Fp2Chip<'a, F: PrimeField, FpChip: FieldChip, Fp2>( + pub FieldVectorChip<'a, F, FpChip>, + PhantomData, +); -impl<'a, F, FpChip, Fp2> Fp2Chip<'a, F, FpChip, Fp2> +impl<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: ff::Field> Fp2Chip<'a, F, FpChip, Fp2> where - F: PrimeField, - FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp2: Field + FieldExtConstructor, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. - pub fn construct(fp_chip: &'a FpChip) -> Self { - Self { fp_chip, _f: PhantomData, _fp2: PhantomData } + pub fn new(fp_chip: &'a FpChip) -> Self { + assert_eq!( + modulus::() % 4usize, + BigUint::from(3u64), + "p must be 3 (mod 4) for the polynomial u^2 + 1 to be irreducible" + ); + Self(FieldVectorChip::new(fp_chip), PhantomData) } - pub fn fp_mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - fp_point: &FpChip::FieldPoint<'v>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); - - let mut out_coeffs = Vec::with_capacity(2); - for c in &a.coeffs { - let coeff = self.fp_chip.mul_no_carry(ctx, c, fp_point); - out_coeffs.push(coeff); - } - FieldExtPoint::construct(out_coeffs) + pub fn fp_chip(&self) -> &FpChip { + self.0.fp_chip } - pub fn conjugate<'v>( + pub fn conjugate( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); + ctx: &mut Context, + a: FieldVector, + ) -> FieldVector { + let mut a = a.0; + assert_eq!(a.len(), 2); - let neg_a1 = self.fp_chip.negate(ctx, &a.coeffs[1]); - FieldExtPoint::construct(vec![a.coeffs[0].clone(), neg_a1]) + let neg_a1 = self.fp_chip().negate(ctx, a.pop().unwrap()); + FieldVector(vec![a.pop().unwrap(), neg_a1]) } - pub fn neg_conjugate<'v>( + pub fn neg_conjugate( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); - - let neg_a0 = self.fp_chip.negate(ctx, &a.coeffs[0]); - FieldExtPoint::construct(vec![neg_a0, a.coeffs[1].clone()]) - } + ctx: &mut Context, + a: FieldVector, + ) -> FieldVector { + assert_eq!(a.0.len(), 2); + let mut a = a.0.into_iter(); - pub fn select<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &FieldExtPoint>, - b: &FieldExtPoint>, - sel: &AssignedValue<'v, F>, - ) -> FieldExtPoint> - where - FpChip: Selectable = FpChip::FieldPoint<'v>>, - { - let coeffs: Vec<_> = a - .coeffs - .iter() - .zip(b.coeffs.iter()) - .map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)) - .collect(); - FieldExtPoint::construct(coeffs) + let neg_a0 = self.fp_chip().negate(ctx, a.next().unwrap()); + FieldVector(vec![neg_a0, a.next().unwrap()]) } } @@ -99,302 +68,52 @@ impl<'a, F, FpChip, Fp2> FieldChip for Fp2Chip<'a, F, FpChip, Fp2> where F: PrimeField, FpChip::FieldType: PrimeField, - FpChip: PrimeFieldChip, ConstantType = BigUint>, - Fp2: Field + FieldExtConstructor, + FpChip: PrimeFieldChip, + Fp2: ff::Field + FieldExtConstructor, + FieldVector: From>, + FieldVector: From>, { const PRIME_FIELD_NUM_BITS: u32 = FpChip::FieldType::NUM_BITS; - type ConstantType = Fp2; - type WitnessType = Vec>; - type FieldPoint<'v> = FieldExtPoint>; + type UnsafeFieldPoint = FieldVector; + type FieldPoint = FieldVector; + type ReducedFieldPoint = FieldVector; type FieldType = Fp2; type RangeChip = FpChip::RangeChip; - fn native_modulus(&self) -> &BigUint { - self.fp_chip.native_modulus() - } - fn range(&self) -> &Self::RangeChip { - self.fp_chip.range() - } - - fn limb_bits(&self) -> usize { - self.fp_chip.limb_bits() - } - - fn get_assigned_value(&self, x: &Self::FieldPoint<'_>) -> Value { - assert_eq!(x.coeffs.len(), 2); - let c0 = self.fp_chip.get_assigned_value(&x.coeffs[0]); - let c1 = self.fp_chip.get_assigned_value(&x.coeffs[1]); - c0.zip(c1).map(|(c0, c1)| Fp2::new([c0, c1])) - } - - fn fe_to_constant(x: Fp2) -> Fp2 { - x + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Fp2 { + assert_eq!(x.0.len(), 2); + let c0 = self.fp_chip().get_assigned_value(&x[0]); + let c1 = self.fp_chip().get_assigned_value(&x[1]); + Fp2::new([c0, c1]) } - fn fe_to_witness(x: &Value) -> Vec> { - match value_to_option(*x) { - None => vec![Value::unknown(), Value::unknown()], - Some(x) => { - let coeffs = x.coeffs(); - assert_eq!(coeffs.len(), 2); - coeffs.iter().map(|c| Value::known(BigInt::from(fe_to_biguint(c)))).collect() - } - } - } - - fn load_private<'v>( - &self, - ctx: &mut Context<'_, F>, - coeffs: Vec>, - ) -> Self::FieldPoint<'v> { - assert_eq!(coeffs.len(), 2); - let mut assigned_coeffs = Vec::with_capacity(2); - for a in coeffs { - let assigned_coeff = self.fp_chip.load_private(ctx, a); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - fn load_constant<'v>(&self, ctx: &mut Context<'_, F>, c: Fp2) -> Self::FieldPoint<'v> { - let mut assigned_coeffs = Vec::with_capacity(2); - for a in &c.coeffs() { - let assigned_coeff = self.fp_chip.load_constant(ctx, fe_to_biguint(a)); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - // signed overflow BigInt functions - fn add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn add_constant_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: Self::ConstantType, - ) -> Self::FieldPoint<'v> { - let c_coeffs = c.coeffs(); - assert_eq!(a.coeffs.len(), c_coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for (a, c) in a.coeffs.iter().zip(c_coeffs.into_iter()) { - let coeff = self.fp_chip.add_constant_no_carry(ctx, a, FpChip::fe_to_constant(c)); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn sub_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.sub_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn negate<'v>( + fn mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let out_coeff = self.fp_chip.negate(ctx, a_coeff); - out_coeffs.push(out_coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.scalar_mul_no_carry(ctx, &a.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_and_add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = - self.fp_chip.scalar_mul_and_add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + let a = a.into().0; + let b = b.into().0; + assert_eq!(a.len(), 2); + assert_eq!(b.len(), 2); + let fp_chip = self.fp_chip(); // (a_0 + a_1 * u) * (b_0 + b_1 * u) = (a_0 b_0 - a_1 b_1) + (a_0 b_1 + a_1 b_0) * u - let mut ab_coeffs = Vec::with_capacity(a.coeffs.len() * b.coeffs.len()); - for i in 0..a.coeffs.len() { - for j in 0..b.coeffs.len() { - let coeff = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j]); + let mut ab_coeffs = Vec::with_capacity(4); + for a_i in a { + for b_j in b.iter() { + let coeff = fp_chip.mul_no_carry(ctx, &a_i, b_j); ab_coeffs.push(coeff); } } - let a0b0_minus_a1b1 = - self.fp_chip.sub_no_carry(ctx, &ab_coeffs[0], &ab_coeffs[b.coeffs.len() + 1]); - let a0b1_plus_a1b0 = - self.fp_chip.add_no_carry(ctx, &ab_coeffs[1], &ab_coeffs[b.coeffs.len()]); - - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - out_coeffs.push(a0b0_minus_a1b1); - out_coeffs.push(a0b1_plus_a1b0); - - Self::FieldPoint::construct(out_coeffs) - } - - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for coeff in &a.coeffs { - self.fp_chip.check_carry_mod_to_zero(ctx, coeff); - } - } - - fn carry_mod<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.carry_mod(ctx, a_coeff); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn range_check<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>, max_bits: usize) { - for a_coeff in &a.coeffs { - self.fp_chip.range_check(ctx, a_coeff, max_bits); - } - } - - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for a_coeff in &a.coeffs { - self.fp_chip.enforce_less_than(ctx, a_coeff) - } - } - - fn is_soft_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_soft_nonzero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().or(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } + let a0b0_minus_a1b1 = fp_chip.sub_no_carry(ctx, &ab_coeffs[0], &ab_coeffs[3]); + let a0b1_plus_a1b0 = fp_chip.add_no_carry(ctx, &ab_coeffs[1], &ab_coeffs[2]); - fn is_equal_unenforced<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&c))); - } else { - acc = Some(coeff); - } - } - acc.unwrap() + FieldVector(vec![a0b0_minus_a1b1, a0b1_plus_a1b0]) } - fn assert_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) { - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - self.fp_chip.assert_equal(ctx, a_coeff, b_coeff) - } - } + // ========= inherited from FieldVectorChip ========= + impl_field_ext_chip_common!(); } mod bn254 { diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index e5e65f16..0c55affa 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -1,40 +1,52 @@ -use crate::halo2_proofs::{arithmetic::Field, circuit::Value}; -use halo2_base::{gates::RangeInstructions, utils::PrimeField, AssignedValue, Context}; +use crate::halo2_proofs::arithmetic::Field; +use halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; use num_bigint::BigUint; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; pub mod fp; pub mod fp12; pub mod fp2; +pub mod vector; #[cfg(test)] mod tests; -#[derive(Clone, Debug)] -pub struct FieldExtPoint { - // `F_q` field extension of `F_p` where `q = p^degree` - // An `F_q` point consists of `degree` number of `F_p` points - // The `F_p` points are stored as `FieldPoint`s +pub trait PrimeField = BigPrimeField; - // We do not specify the irreducible `F_p` polynomial used to construct `F_q` here - that is implementation specific - pub coeffs: Vec, - // `degree = coeffs.len()` -} - -impl FieldExtPoint { - pub fn construct(coeffs: Vec) -> Self { - Self { coeffs } - } -} - -/// Common functionality for finite field chips -pub trait FieldChip { +/// Trait for common functionality for finite field chips. +/// Primarily intended to emulate a "non-native" finite field using "native" values in a prime field `F`. +/// Most functions are designed for the case when the non-native field is larger than the native field, but +/// the trait can still be implemented and used in other cases. +pub trait FieldChip: Clone + Send + Sync { const PRIME_FIELD_NUM_BITS: u32; - type ConstantType: Debug; - type WitnessType: Debug; - type FieldPoint<'v>: Clone + Debug; - // a type implementing `Field` trait to help with witness generation (for example with inverse) + /// A representation of a field element that is used for intermediate computations. + /// The representation can have "overflows" (e.g., overflow limbs or negative limbs). + type UnsafeFieldPoint: Clone + + Debug + + Send + + Sync + + From + + for<'a> From<&'a Self::UnsafeFieldPoint> + + for<'a> From<&'a Self::FieldPoint>; // Cloning all the time impacts readability, so we allow references to be cloned into owned values + + /// The "proper" representation of a field element. Allowed to be a non-unique representation of a field element (e.g., can be greater than modulus) + type FieldPoint: Clone + + Debug + + Send + + Sync + + From + + for<'a> From<&'a Self::FieldPoint>; + + /// A proper representation of field elements that guarantees a unique representation of each field element. Typically this means Uints that are less than the modulus. + type ReducedFieldPoint: Clone + Debug + Send + Sync; + + /// A type implementing `Field` trait to help with witness generation (for example with inverse) type FieldType: Field; type RangeChip: RangeInstructions; @@ -45,212 +57,242 @@ pub trait FieldChip { fn range(&self) -> &Self::RangeChip; fn limb_bits(&self) -> usize; - fn get_assigned_value(&self, x: &Self::FieldPoint<'_>) -> Value; + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Self::FieldType; - fn fe_to_constant(x: Self::FieldType) -> Self::ConstantType; - fn fe_to_witness(x: &Value) -> Self::WitnessType; + /// Assigns `fe` as private witness. Note that the witness may **not** be constrained to be a unique representation of the field element `fe`. + fn load_private(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint; - fn load_private<'v>( + /// Assigns `fe` as private witness and contrains the witness to be in reduced form. + fn load_private_reduced( &self, - ctx: &mut Context<'_, F>, - coeffs: Self::WitnessType, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + fe: Self::FieldType, + ) -> Self::ReducedFieldPoint { + let fe = self.load_private(ctx, fe); + self.enforce_less_than(ctx, fe) + } - fn load_constant<'v>( - &self, - ctx: &mut Context<'_, F>, - coeffs: Self::ConstantType, - ) -> Self::FieldPoint<'v>; + /// Assigns `fe` as constant. + fn load_constant(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint; - fn add_no_carry<'v>( + fn add_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; /// output: `a + c` - fn add_constant_no_carry<'v>( + fn add_constant_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: Self::ConstantType, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + c: Self::FieldType, + ) -> Self::UnsafeFieldPoint; - fn sub_no_carry<'v>( + fn sub_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; - fn negate<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + fn negate(&self, ctx: &mut Context, a: Self::FieldPoint) -> Self::FieldPoint; /// a * c - fn scalar_mul_no_carry<'v>( + fn scalar_mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, + ctx: &mut Context, + a: impl Into, c: i64, - ) -> Self::FieldPoint<'v>; + ) -> Self::UnsafeFieldPoint; /// a * c + b - fn scalar_mul_and_add_no_carry<'v>( + fn scalar_mul_and_add_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, + ctx: &mut Context, + a: impl Into, + b: impl Into, c: i64, - ) -> Self::FieldPoint<'v>; + ) -> Self::UnsafeFieldPoint; - fn mul_no_carry<'v>( + fn mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>); + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint); - fn carry_mod<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint; - fn range_check<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>, max_bits: usize); + fn range_check( + &self, + ctx: &mut Context, + a: impl Into, + max_bits: usize, + ); - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>); + /// Constrains that `a` is a reduced representation and returns the wrapped `a`. + fn enforce_less_than( + &self, + ctx: &mut Context, + a: Self::FieldPoint, + ) -> Self::ReducedFieldPoint; - // Assumes the witness for a is 0 - // Constrains that the underlying big integer is 0 and < p. + // Returns 1 iff the underlying big integer for `a` is 0. Otherwise returns 0. // For field extensions, checks coordinate-wise. - fn is_soft_zero<'v>( + fn is_soft_zero( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue; - // Constrains that the underlying big integer is in [1, p - 1]. + // Constrains that the underlying big integer is in [0, p - 1]. + // Then returns 1 iff the underlying big integer for `a` is 0. Otherwise returns 0. // For field extensions, checks coordinate-wise. - fn is_soft_nonzero<'v>( + fn is_soft_nonzero( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue; - fn is_zero<'v>( + fn is_zero(&self, ctx: &mut Context, a: impl Into) -> AssignedValue; + + fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: Self::ReducedFieldPoint, + b: Self::ReducedFieldPoint, + ) -> AssignedValue; + + fn assert_equal( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ); + + // =========== default implementations ============= // assuming `a, b` have been range checked to be a proper BigInt // constrain the witnesses `a, b` to be `< p` // then check `a == b` as BigInts - fn is_equal<'v>( + fn is_equal( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - self.enforce_less_than(ctx, a); - self.enforce_less_than(ctx, b); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> AssignedValue { + let a = self.enforce_less_than(ctx, a.into()); + let b = self.enforce_less_than(ctx, b.into()); // a.native and b.native are derived from `a.truncation, b.truncation`, so no need to check if they're equal self.is_equal_unenforced(ctx, a, b) } - fn is_equal_unenforced<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; - - fn assert_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ); - - fn mul<'v>( + /// If using `UnsafeFieldPoint`, make sure multiplication does not cause overflow. + fn mul( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { let no_carry = self.mul_no_carry(ctx, a, b); - self.carry_mod(ctx, &no_carry) + self.carry_mod(ctx, no_carry) } - fn divide<'v>( + /// Constrains that `b` is nonzero as a field element and then returns `a / b`. + fn divide( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let a_val = self.get_assigned_value(a); - let b_val = self.get_assigned_value(b); - let b_inv = b_val.map(|bv| bv.invert().unwrap()); - let quot_val = a_val.zip(b_inv).map(|(a, bi)| a * bi); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let b = b.into(); + let b_is_zero = self.is_zero(ctx, b.clone()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + + self.divide_unsafe(ctx, a.into(), b) + } - let quot = self.load_private(ctx, Self::fe_to_witness("_val)); + /// Returns `a / b` without constraining `b` to be nonzero. + /// + /// Warning: undefined behavior when `b` is zero. + /// + /// `a, b` must be such that `quot * b - a` without carry does not overflow, where `quot` is the output. + fn divide_unsafe( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let a = a.into(); + let b = b.into(); + let a_val = self.get_assigned_value(&a); + let b_val = self.get_assigned_value(&b); + let b_inv: Self::FieldType = Option::from(b_val.invert()).unwrap_or_default(); + let quot_val = a_val * b_inv; + + let quot = self.load_private(ctx, quot_val); // constrain quot * b - a = 0 mod p - let quot_b = self.mul_no_carry(ctx, ", b); - let quot_constraint = self.sub_no_carry(ctx, "_b, a); - self.check_carry_mod_to_zero(ctx, "_constraint); + let quot_b = self.mul_no_carry(ctx, quot.clone(), b); + let quot_constraint = self.sub_no_carry(ctx, quot_b, a); + self.check_carry_mod_to_zero(ctx, quot_constraint); quot } - // constrain and output -a / b + /// Constrains that `b` is nonzero as a field element and then returns `-a / b`. + fn neg_divide( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let b = b.into(); + let b_is_zero = self.is_zero(ctx, b.clone()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + + self.neg_divide_unsafe(ctx, a.into(), b) + } + + // Returns `-a / b` without constraining `b` to be nonzero. // this is usually cheaper constraint-wise than computing -a and then (-a) / b separately - fn neg_divide<'v>( + fn neg_divide_unsafe( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let a_val = self.get_assigned_value(a); - let b_val = self.get_assigned_value(b); - let b_inv = b_val.map(|bv| bv.invert().unwrap()); - let quot_val = a_val.zip(b_inv).map(|(a, b)| -a * b); - - let quot = self.load_private(ctx, Self::fe_to_witness("_val)); - self.range_check(ctx, ", Self::PRIME_FIELD_NUM_BITS as usize); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let a = a.into(); + let b = b.into(); + let a_val = self.get_assigned_value(&a); + let b_val = self.get_assigned_value(&b); + let b_inv: Self::FieldType = Option::from(b_val.invert()).unwrap_or_default(); + let quot_val = -a_val * b_inv; + + let quot = self.load_private(ctx, quot_val); // constrain quot * b + a = 0 mod p - let quot_b = self.mul_no_carry(ctx, ", b); - let quot_constraint = self.add_no_carry(ctx, "_b, a); - self.check_carry_mod_to_zero(ctx, "_constraint); + let quot_b = self.mul_no_carry(ctx, quot.clone(), b); + let quot_constraint = self.add_no_carry(ctx, quot_b, a); + self.check_carry_mod_to_zero(ctx, quot_constraint); quot } } -pub trait Selectable { - type Point<'v>; +pub trait Selectable { + fn select(&self, ctx: &mut Context, a: Pt, b: Pt, sel: AssignedValue) -> Pt; - fn select<'v>( + fn select_by_indicator( &self, - ctx: &mut Context<'_, F>, - a: &Self::Point<'v>, - b: &Self::Point<'v>, - sel: &AssignedValue<'v, F>, - ) -> Self::Point<'v>; - - fn select_by_indicator<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &[Self::Point<'v>], - coeffs: &[AssignedValue<'v, F>], - ) -> Self::Point<'v>; + ctx: &mut Context, + a: &impl AsRef<[Pt]>, + coeffs: &[AssignedValue], + ) -> Pt; } // Common functionality for prime field chips @@ -265,8 +307,13 @@ where // helper trait so we can actually construct and read the Fp2 struct // needs to be implemented for Fp2 struct for use cases below -pub trait FieldExtConstructor { +pub trait FieldExtConstructor { fn new(c: [Fp; DEGREE]) -> Self; fn coeffs(&self) -> Vec; } + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub enum FpStrategy { + Simple, +} diff --git a/halo2-ecc/src/fields/tests.rs b/halo2-ecc/src/fields/tests.rs deleted file mode 100644 index 36398e65..00000000 --- a/halo2-ecc/src/fields/tests.rs +++ /dev/null @@ -1,267 +0,0 @@ -mod fp { - use crate::fields::{ - fp::{FpConfig, FpStrategy}, - FieldChip, - }; - use crate::halo2_proofs::{ - circuit::*, - dev::MockProver, - halo2curves::bn256::{Fq, Fr}, - plonk::*, - }; - use group::ff::Field; - use halo2_base::{ - utils::{fe_to_biguint, modulus, PrimeField}, - SKIP_FIRST_PASS, - }; - use num_bigint::BigInt; - use rand::rngs::OsRng; - use std::marker::PhantomData; - - #[derive(Default)] - struct MyCircuit { - a: Value, - b: Value, - _marker: PhantomData, - } - - const NUM_ADVICE: usize = 1; - const NUM_FIXED: usize = 1; - const K: usize = 10; - - impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - 9, - 88, - 3, - modulus::(), - 0, - K, - ) - } - - fn synthesize( - &self, - chip: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "fp", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = chip.new_context(region); - let ctx = &mut aux; - - let a_assigned = - chip.load_private(ctx, self.a.map(|a| BigInt::from(fe_to_biguint(&a)))); - let b_assigned = - chip.load_private(ctx, self.b.map(|b| BigInt::from(fe_to_biguint(&b)))); - - // test fp_multiply - { - chip.mul(ctx, &a_assigned, &b_assigned); - } - - // IMPORTANT: this copies advice cells to enable lookup - // This is not optional. - chip.finalize(ctx); - - #[cfg(feature = "display")] - { - println!( - "Using {NUM_ADVICE} advice columns and {NUM_FIXED} fixed columns" - ); - println!("total cells: {}", ctx.total_advice); - - let (const_rows, _) = ctx.fixed_stats(); - println!("maximum rows used by a fixed column: {const_rows}"); - } - Ok(()) - }, - ) - } - } - - #[test] - fn test_fp() { - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); - - let circuit = - MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; - - let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - //assert_eq!(prover.verify(), Ok(())); - } - - #[cfg(feature = "dev-graph")] - #[test] - fn plot_fp() { - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let circuit = MyCircuit::::default(); - halo2_proofs::dev::CircuitLayout::default().render(K as u32, &circuit, &root).unwrap(); - } -} - -mod fp12 { - use crate::fields::{ - fp::{FpConfig, FpStrategy}, - fp12::*, - FieldChip, - }; - use crate::halo2_proofs::{ - circuit::*, - dev::MockProver, - halo2curves::bn256::{Fq, Fq12, Fr}, - plonk::*, - }; - use halo2_base::utils::modulus; - use halo2_base::{utils::PrimeField, SKIP_FIRST_PASS}; - use std::marker::PhantomData; - - #[derive(Default)] - struct MyCircuit { - a: Value, - b: Value, - _marker: PhantomData, - } - - const NUM_ADVICE: usize = 1; - const NUM_FIXED: usize = 1; - const XI_0: i64 = 9; - - impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - 22, - 88, - 3, - modulus::(), - 0, - 23, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_lookup_table(&mut layouter)?; - let chip = Fp12Chip::, Fq12, XI_0>::construct(&config); - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "fp12", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.new_context(region); - let ctx = &mut aux; - - let a_assigned = chip.load_private( - ctx, - Fp12Chip::, Fq12, XI_0>::fe_to_witness(&self.a), - ); - let b_assigned = chip.load_private( - ctx, - Fp12Chip::, Fq12, XI_0>::fe_to_witness(&self.b), - ); - - // test fp_multiply - { - chip.mul(ctx, &a_assigned, &b_assigned); - } - - // IMPORTANT: this copies advice cells to enable lookup - // This is not optional. - chip.fp_chip.finalize(ctx); - - #[cfg(feature = "display")] - { - println!( - "Using {NUM_ADVICE} advice columns and {NUM_FIXED} fixed columns" - ); - println!("total advice cells: {}", ctx.total_advice); - - let (const_rows, _) = ctx.fixed_stats(); - println!("maximum rows used by a fixed column: {const_rows}"); - } - Ok(()) - }, - ) - } - } - - #[test] - fn test_fp12() { - let k = 23; - let mut rng = rand::thread_rng(); - let a = Fq12::random(&mut rng); - let b = Fq12::random(&mut rng); - - let circuit = - MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - // assert_eq!(prover.verify(), Ok(())); - } - - #[cfg(feature = "dev-graph")] - #[test] - fn plot_fp12() { - let k = 9; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let circuit = MyCircuit::::default(); - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); - } -} diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs new file mode 100644 index 00000000..5aac74bf --- /dev/null +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -0,0 +1,82 @@ +use std::env::set_var; + +use ff::Field; +use halo2_base::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + tests::{check_proof, gen_proof}, + RangeChip, + }, + halo2_proofs::{ + halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, + poly::kzg::commitment::ParamsKZG, + }, +}; + +use crate::{bn254::FpChip, fields::FieldChip}; +use rand::thread_rng; + +// soundness checks for `` function +fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { + let mut rng = thread_rng(); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + + // first create proving and verifying key + let mut builder = GateThreadBuilder::keygen(); + let range = RangeChip::default(lookup_bits); + let chip = FpChip::new(&range, 88, 3); + + let ctx = builder.main(0); + let a = chip.load_private(ctx, Fq::zero()); + let b = chip.load_private(ctx, Fq::zero()); + chip.assert_equal(ctx, &a, &b); + // set env vars + builder.config(k as usize, Some(9)); + let circuit = RangeCircuitBuilder::keygen(builder); + + let params = ParamsKZG::setup(k, &mut rng); + // generate proving key + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = pk.get_vk(); // pk consumed vk + + // now create different proofs to test the soundness of the circuit + + let gen_pf = |a: Fq, b: Fq| { + let mut builder = GateThreadBuilder::prover(); + let range = RangeChip::default(lookup_bits); + let chip = FpChip::new(&range, 88, 3); + + let ctx = builder.main(0); + let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); + chip.assert_equal(ctx, &a, &b); + let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points + gen_proof(¶ms, &pk, circuit) + }; + + // expected answer + for _ in 0..num_tries { + let a = Fq::random(&mut rng); + let pf = gen_pf(a, a); + check_proof(¶ms, vk, &pf, true); + } + + // unequal + for _ in 0..num_tries { + let a = Fq::random(&mut rng); + let b = Fq::random(&mut rng); + if a == b { + continue; + } + let pf = gen_pf(a, b); + check_proof(¶ms, vk, &pf, false); + } +} + +#[test] +fn test_fp_assert_eq() { + test_fp_assert_eq_gen(10, 4, 100); + test_fp_assert_eq_gen(10, 8, 100); + test_fp_assert_eq_gen(10, 9, 100); + test_fp_assert_eq_gen(18, 17, 10); +} diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs new file mode 100644 index 00000000..9489abb5 --- /dev/null +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -0,0 +1,72 @@ +use crate::fields::fp::FpChip; +use crate::fields::{FieldChip, PrimeField}; +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Fq, Fr}, +}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::biguint_to_fe; +use halo2_base::utils::{fe_to_biguint, modulus}; +use halo2_base::Context; +use rand::rngs::OsRng; + +pub mod assert_eq; + +const K: usize = 10; + +fn fp_mul_test( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + _a: Fq, + _b: Fq, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let chip = FpChip::::new(&range, limb_bits, num_limbs); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); + + assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); + assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); + assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()) +} + +#[test] +fn test_fp() { + let k = K; + let a = Fq::random(OsRng); + let b = Fq::random(OsRng); + + let mut builder = GateThreadBuilder::::mock(); + fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_fp() { + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); + + let k = K; + let a = Fq::zero(); + let b = Fq::zero(); + + let mut builder = GateThreadBuilder::keygen(); + fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::keygen(builder); + halo2_proofs::dev::CircuitLayout::default().render(k as u32, &circuit, &root).unwrap(); +} diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs new file mode 100644 index 00000000..6fb631b9 --- /dev/null +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -0,0 +1,73 @@ +use crate::fields::fp::FpChip; +use crate::fields::fp12::Fp12Chip; +use crate::fields::{FieldChip, PrimeField}; +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Fq, Fq12, Fr}, +}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::RangeChip; +use halo2_base::Context; +use rand_core::OsRng; + +const XI_0: i64 = 9; + +fn fp12_mul_test( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + _a: Fq12, + _b: Fq12, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = Fp12Chip::::new(&fp_chip); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b).into(); + + assert_eq!(chip.get_assigned_value(&c), _a * _b); + for c in c.into_iter() { + assert_eq!(c.truncation.to_bigint(limb_bits), c.value); + } +} + +#[test] +fn test_fp12() { + let k = 12; + let a = Fq12::random(OsRng); + let b = Fq12::random(OsRng); + + let mut builder = GateThreadBuilder::::mock(); + fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_fp12() { + use ff::Field; + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); + + let k = 23; + let a = Fq12::zero(); + let b = Fq12::zero(); + + let mut builder = GateThreadBuilder::::mock(); + fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + + halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); +} diff --git a/halo2-ecc/src/fields/tests/mod.rs b/halo2-ecc/src/fields/tests/mod.rs new file mode 100644 index 00000000..460ae96a --- /dev/null +++ b/halo2-ecc/src/fields/tests/mod.rs @@ -0,0 +1,2 @@ +pub mod fp; +pub mod fp12; diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs new file mode 100644 index 00000000..6aea9d97 --- /dev/null +++ b/halo2-ecc/src/fields/vector.rs @@ -0,0 +1,495 @@ +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; +use std::{ + marker::PhantomData, + ops::{Index, IndexMut}, +}; + +use crate::bigint::{CRTInteger, ProperCrtUint}; + +use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, Selectable}; + +/// A fixed length vector of `FieldPoint`s +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct FieldVector(pub Vec); + +impl Index for FieldVector { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl IndexMut for FieldVector { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl AsRef<[T]> for FieldVector { + fn as_ref(&self) -> &[T] { + &self.0 + } +} + +impl<'a, T: Clone, U: From> From<&'a FieldVector> for FieldVector { + fn from(other: &'a FieldVector) -> Self { + FieldVector(other.clone().into_iter().map(Into::into).collect()) + } +} + +impl From>> for FieldVector> { + fn from(other: FieldVector>) -> Self { + FieldVector(other.into_iter().map(|x| x.0).collect()) + } +} + +impl From>> for FieldVector { + fn from(value: FieldVector>) -> Self { + FieldVector(value.0.into_iter().map(|x| x.0).collect()) + } +} + +impl IntoIterator for FieldVector { + type Item = T; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +/// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip` +#[derive(Clone, Copy, Debug)] +pub struct FieldVectorChip<'fp, F: PrimeField, FpChip: FieldChip> { + pub fp_chip: &'fp FpChip, + _f: PhantomData, +} + +impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip> +where + F: PrimeField, + FpChip: PrimeFieldChip, + FpChip::FieldType: PrimeField, +{ + pub fn new(fp_chip: &'fp FpChip) -> Self { + Self { fp_chip, _f: PhantomData } + } + + pub fn gate(&self) -> &impl GateInstructions { + self.fp_chip.gate() + } + + pub fn fp_mul_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + fp_point: impl Into, + ) -> FieldVector + where + FP: Into, + { + let fp_point = fp_point.into(); + FieldVector( + a.into_iter().map(|a| self.fp_chip.mul_no_carry(ctx, a, fp_point.clone())).collect(), + ) + } + + pub fn select( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + sel: AssignedValue, + ) -> FieldVector + where + FpChip: Selectable, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)).collect(), + ) + } + + pub fn load_private( + &self, + ctx: &mut Context, + fe: FieldExt, + ) -> FieldVector + where + FieldExt: FieldExtConstructor, + { + FieldVector(fe.coeffs().into_iter().map(|a| self.fp_chip.load_private(ctx, a)).collect()) + } + + pub fn load_constant( + &self, + ctx: &mut Context, + c: FieldExt, + ) -> FieldVector + where + FieldExt: FieldExtConstructor, + { + FieldVector(c.coeffs().into_iter().map(|a| self.fp_chip.load_constant(ctx, a)).collect()) + } + + // signed overflow BigInt functions + pub fn add_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.add_no_carry(ctx, a, b)).collect(), + ) + } + + pub fn add_constant_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + c: FieldExt, + ) -> FieldVector + where + A: Into, + FieldExt: FieldExtConstructor, + { + let c_coeffs = c.coeffs(); + FieldVector( + a.into_iter() + .zip_eq(c_coeffs) + .map(|(a, c)| self.fp_chip.add_constant_no_carry(ctx, a, c)) + .collect(), + ) + } + + pub fn sub_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.sub_no_carry(ctx, a, b)).collect(), + ) + } + + pub fn negate( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|a| self.fp_chip.negate(ctx, a)).collect()) + } + + pub fn scalar_mul_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + c: i64, + ) -> FieldVector + where + A: Into, + { + FieldVector(a.into_iter().map(|a| self.fp_chip.scalar_mul_no_carry(ctx, a, c)).collect()) + } + + pub fn scalar_mul_and_add_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + c: i64, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter() + .zip_eq(b) + .map(|(a, b)| self.fp_chip.scalar_mul_and_add_no_carry(ctx, a, b, c)) + .collect(), + ) + } + + pub fn check_carry_mod_to_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) { + for coeff in a { + self.fp_chip.check_carry_mod_to_zero(ctx, coeff); + } + } + + pub fn carry_mod( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|coeff| self.fp_chip.carry_mod(ctx, coeff)).collect()) + } + + pub fn range_check( + &self, + ctx: &mut Context, + a: impl IntoIterator, + max_bits: usize, + ) where + A: Into, + { + for coeff in a { + self.fp_chip.range_check(ctx, coeff, max_bits); + } + } + + pub fn enforce_less_than( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|coeff| self.fp_chip.enforce_less_than(ctx, coeff)).collect()) + } + + pub fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().and(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().or(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_zero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().and(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> AssignedValue { + let mut acc = None; + for (a_coeff, b_coeff) in a.into_iter().zip_eq(b) { + let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); + if let Some(c) = acc { + acc = Some(self.gate().and(ctx, coeff, c)); + } else { + acc = Some(coeff); + } + } + acc.unwrap() + } + + pub fn assert_equal( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) { + for (a_coeff, b_coeff) in a.into_iter().zip(b) { + self.fp_chip.assert_equal(ctx, a_coeff, b_coeff) + } + } +} + +#[macro_export] +macro_rules! impl_field_ext_chip_common { + // Implementation of the functions in `FieldChip` trait for field extensions that can be derived from `FieldVectorChip` + () => { + fn native_modulus(&self) -> &BigUint { + self.0.fp_chip.native_modulus() + } + + fn range(&self) -> &Self::RangeChip { + self.0.fp_chip.range() + } + + fn limb_bits(&self) -> usize { + self.0.fp_chip.limb_bits() + } + + fn load_private(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint { + self.0.load_private(ctx, fe) + } + + fn load_constant(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint { + self.0.load_constant(ctx, fe) + } + + fn add_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + self.0.add_no_carry(ctx, a.into(), b.into()) + } + + fn add_constant_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + c: Self::FieldType, + ) -> Self::UnsafeFieldPoint { + self.0.add_constant_no_carry(ctx, a.into(), c) + } + + fn sub_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + self.0.sub_no_carry(ctx, a.into(), b.into()) + } + + fn negate(&self, ctx: &mut Context, a: Self::FieldPoint) -> Self::FieldPoint { + self.0.negate(ctx, a) + } + + fn scalar_mul_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + c: i64, + ) -> Self::UnsafeFieldPoint { + self.0.scalar_mul_no_carry(ctx, a.into(), c) + } + + fn scalar_mul_and_add_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + c: i64, + ) -> Self::UnsafeFieldPoint { + self.0.scalar_mul_and_add_no_carry(ctx, a.into(), b.into(), c) + } + + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) { + self.0.check_carry_mod_to_zero(ctx, a); + } + + fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint { + self.0.carry_mod(ctx, a) + } + + fn range_check( + &self, + ctx: &mut Context, + a: impl Into, + max_bits: usize, + ) { + self.0.range_check(ctx, a.into(), max_bits) + } + + fn enforce_less_than( + &self, + ctx: &mut Context, + a: Self::FieldPoint, + ) -> Self::ReducedFieldPoint { + self.0.enforce_less_than(ctx, a) + } + + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_soft_zero(ctx, a) + } + + fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_soft_nonzero(ctx, a) + } + + fn is_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_zero(ctx, a) + } + + fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: Self::ReducedFieldPoint, + b: Self::ReducedFieldPoint, + ) -> AssignedValue { + self.0.is_equal_unenforced(ctx, a, b) + } + + fn assert_equal( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) { + let a = a.into(); + let b = b.into(); + self.0.assert_equal(ctx, a, b) + } + }; +} diff --git a/halo2-ecc/src/lib.rs b/halo2-ecc/src/lib.rs index ddf2763d..10da56bc 100644 --- a/halo2-ecc/src/lib.rs +++ b/halo2-ecc/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::op_ref)] #![allow(clippy::type_complexity)] #![feature(int_log)] +#![feature(trait_alias)] pub mod bigint; pub mod ecc; diff --git a/halo2-ecc/src/secp256k1/mod.rs b/halo2-ecc/src/secp256k1/mod.rs index c81e136f..ca4528e4 100644 --- a/halo2-ecc/src/secp256k1/mod.rs +++ b/halo2-ecc/src/secp256k1/mod.rs @@ -1,14 +1,12 @@ -use crate::halo2_proofs::halo2curves::secp256k1::Fp; +use crate::halo2_proofs::halo2curves::secp256k1::{Fp, Fq}; use crate::ecc; use crate::fields::fp; -#[allow(dead_code)] -type FpChip = fp::FpConfig; -#[allow(dead_code)] -type Secp256k1Chip = ecc::EccChip>; -#[allow(dead_code)] -const SECP_B: u64 = 7; +pub type FpChip<'range, F> = fp::FpChip<'range, F, Fp>; +pub type FqChip<'range, F> = fp::FpChip<'range, F, Fq>; +pub type Secp256k1Chip<'chip, F> = ecc::EccChip<'chip, F, FpChip<'chip, F>>; +pub const SECP_B: u64 = 7; #[cfg(test)] mod tests; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 73389d79..af7050f9 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,14 +1,7 @@ #![allow(non_snake_case)] -use ark_std::{end_timer, start_timer}; -use halo2_base::{utils::PrimeField, SKIP_FIRST_PASS}; -use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::marker::PhantomData; -use std::{env::var, io::Write}; - +use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, - circuit::*, dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, @@ -16,17 +9,35 @@ use crate::halo2_proofs::{ poly::commitment::ParamsProver, transcript::{Blake2bRead, Blake2bWrite, Challenge255}, }; -use rand_core::OsRng; - -use crate::fields::fp::FpConfig; -use crate::secp256k1::FpChip; +use crate::halo2_proofs::{ + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{fp::FpStrategy, FieldChip}, + fields::{FieldChip, PrimeField}, }; +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, +}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::fs::gen_srs; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::Context; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::BufReader; +use std::io::Write; +use std::{fs, io::BufRead}; -#[derive(Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { strategy: FpStrategy, degree: u32, @@ -38,283 +49,120 @@ struct CircuitParams { num_limbs: usize, } -pub struct ECDSACircuit { - pub r: Option, - pub s: Option, - pub msghash: Option, - pub pk: Option, - pub G: Secp256k1Affine, - pub _marker: PhantomData, -} -impl Default for ECDSACircuit { - fn default() -> Self { - Self { - r: None, - s: None, - msghash: None, - pk: None, - G: Secp256k1Affine::generator(), - _marker: PhantomData, - } - } -} - -impl Circuit for ECDSACircuit { - type Config = FpChip; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("ECDSA_CONFIG") - .unwrap_or_else(|_| "./src/secp256k1/configs/ecdsa_circuit.config".to_string()); - let params: CircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - FpChip::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - modulus::(), - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - fp_chip: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - fp_chip.range.load_lookup_table(&mut layouter)?; - - let limb_bits = fp_chip.limb_bits; - let num_limbs = fp_chip.num_limbs; - let _num_fixed = fp_chip.range.gate.constants.len(); - let _lookup_bits = fp_chip.range.lookup_bits; - let _num_advice = fp_chip.range.gate.num_advice; - - let mut first_pass = SKIP_FIRST_PASS; - // ECDSA verify - layouter.assign_region( - || "ECDSA", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = fp_chip.new_context(region); - let ctx = &mut aux; - - let (r_assigned, s_assigned, m_assigned) = { - let fq_chip = FpConfig::::construct( - fp_chip.range.clone(), - limb_bits, - num_limbs, - modulus::(), - ); - - let m_assigned = fq_chip.load_private( - ctx, - FpConfig::::fe_to_witness( - &self.msghash.map_or(Value::unknown(), Value::known), - ), - ); - - let r_assigned = fq_chip.load_private( - ctx, - FpConfig::::fe_to_witness( - &self.r.map_or(Value::unknown(), Value::known), - ), - ); - let s_assigned = fq_chip.load_private( - ctx, - FpConfig::::fe_to_witness( - &self.s.map_or(Value::unknown(), Value::known), - ), - ); - (r_assigned, s_assigned, m_assigned) - }; - - let ecc_chip = EccChip::>::construct(fp_chip.clone()); - let pk_assigned = ecc_chip.load_private( - ctx, - ( - self.pk.map_or(Value::unknown(), |pt| Value::known(pt.x)), - self.pk.map_or(Value::unknown(), |pt| Value::known(pt.y)), - ), - ); - // test ECDSA - let ecdsa = ecdsa_verify_no_pubkey_check::( - &ecc_chip.field_chip, - ctx, - &pk_assigned, - &r_assigned, - &s_assigned, - &m_assigned, - 4, - 4, - ); - - // IMPORTANT: this copies cells to the lookup advice column to perform range check lookups - // This is not optional. - fp_chip.finalize(ctx); - - #[cfg(feature = "display")] - if self.r.is_some() { - println!("ECDSA res {ecdsa:?}"); - - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } +fn ecdsa_test( + ctx: &mut Context, + params: CircuitParams, + r: Fq, + s: Fq, + msghash: Fq, + pk: Secp256k1Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + + let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + + let ecc_chip = EccChip::>::new(&fp_chip); + let pk = ecc_chip.load_private_unchecked(ctx, (pk.x, pk.y)); + // test ECDSA + let res = ecdsa_verify_no_pubkey_check::( + &ecc_chip, ctx, pk, r, s, m, 4, 4, + ); + assert_eq!(res.value(), &F::one()); } -#[cfg(test)] -#[test] -fn test_secp256k1_ecdsa() { - let mut folder = std::path::PathBuf::new(); - folder.push("./src/secp256k1"); - folder.push("configs/ecdsa_circuit.config"); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/secp256k1/configs/ecdsa_circuit.config file should exist"); - let params: CircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let K = params.degree; - - // generate random pub key and sign random message - let G = Secp256k1Affine::generator(); +fn random_ecdsa_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(G * sk); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::random(OsRng); let k = ::ScalarExt::random(OsRng); let k_inv = k.invert().unwrap(); - let r_point = Secp256k1Affine::from(G * k).coordinates().unwrap(); + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - let circuit = ECDSACircuit:: { - r: Some(r), - s: Some(s), - msghash: Some(msg_hash), - pk: Some(pubkey), - G, - _marker: PhantomData, + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; + end_timer!(start0); + circuit +} - let prover = MockProver::run(K, &circuit, vec![]).unwrap(); - //prover.assert_satisfied(); - assert_eq!(prover.verify(), Ok(())); +#[test] +fn test_secp256k1_ecdsa() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } -#[cfg(test)] #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { - use halo2_base::utils::fs::gen_srs; - - use crate::halo2_proofs::{ - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, - }; - use std::{env::set_var, fs, io::BufRead}; - - let _rng = OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/secp256k1"); - - folder.push("configs/bench_ecdsa.config"); - let bench_params_file = std::fs::File::open(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); - - folder.push("results/ecdsa_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let mut rng = OsRng; + let config_path = "configs/secp256k1/bench_ecdsa.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/secp256k1").unwrap(); + fs::create_dir_all("data").unwrap(); + let results_path = "results/secp256k1/ecdsa_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); - { - folder.pop(); - folder.push("configs/ecdsa_circuit.tmp.config"); - set_var("ECDSA_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params_time = start_timer!(|| "Time elapsed in circuit & params construction"); - let params = gen_srs(bench_params.degree); - let circuit = ECDSACircuit::::default(); - end_timer!(params_time); + let params = gen_srs(k); + println!("{bench_params:?}"); + + let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None); - let vk_time = start_timer!(|| "Time elapsed in generating vkey"); + let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - let pk_time = start_timer!(|| "Time elapsed in generating pkey"); + let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - // generate random pub key and sign random message - let G = Secp256k1Affine::generator(); - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(G * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); - let k_inv = k.invert().unwrap(); - - let r_point = Secp256k1Affine::from(G * k).coordinates().unwrap(); - let x = r_point.x(); - let x_bigint = fe_to_biguint(x); - let r = biguint_to_fe::(&x_bigint); - let s = k_inv * (msg_hash + (r * sk)); - - let proof_circuit = ECDSACircuit:: { - r: Some(r), - s: Some(s), - msghash: Some(msg_hash), - pk: Some(pubkey), - G, - _marker: PhantomData, - }; - let mut rng = OsRng; - + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = + random_ecdsa_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -322,14 +170,14 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - ECDSACircuit, - >(¶ms, &pk, &[proof_circuit], &[&[]], &mut rng, &mut transcript)?; + _, + >(¶ms, &pk, &[circuit], &[&[]], &mut rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -337,27 +185,27 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("ECDSA_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs new file mode 100644 index 00000000..45e251f3 --- /dev/null +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -0,0 +1,191 @@ +#![allow(non_snake_case)] +use crate::halo2_proofs::{ + arithmetic::CurveAffine, + dev::MockProver, + halo2curves::bn256::Fr, + halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, +}; +use crate::secp256k1::{FpChip, FqChip}; +use crate::{ + ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, + fields::{FieldChip, PrimeField}, +}; +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, +}; + +use halo2_base::gates::RangeChip; +use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::Context; +use rand::random; +use rand_core::OsRng; +use std::fs::File; +use test_case::test_case; + +use super::CircuitParams; + +fn ecdsa_test( + ctx: &mut Context, + params: CircuitParams, + r: Fq, + s: Fq, + msghash: Fq, + pk: Secp256k1Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + + let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + + let ecc_chip = EccChip::>::new(&fp_chip); + let pk = ecc_chip.assign_point(ctx, pk); + // test ECDSA + let res = ecdsa_verify_no_pubkey_check::( + &ecc_chip, ctx, pk, r, s, m, 4, 4, + ); + assert_eq!(res.value(), &F::one()); +} + +fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { + let sk = ::ScalarExt::random(OsRng); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::random(OsRng); + + let k = ::ScalarExt::random(OsRng); + let k_inv = k.invert().unwrap(); + + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let x = r_point.x(); + let x_bigint = fe_to_biguint(x); + + let r = biguint_to_fe::(&(x_bigint % modulus::())); + let s = k_inv * (msg_hash + (r * sk)); + + (r, s, msg_hash, pubkey) +} + +fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp256k1Affine) { + let sk = ::ScalarExt::from(sk); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::from(msg_hash); + + let k = ::ScalarExt::from(k); + let k_inv = k.invert().unwrap(); + + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let x = r_point.x(); + let x_bigint = fe_to_biguint(x); + + let r = biguint_to_fe::(&(x_bigint % modulus::())); + let s = k_inv * (msg_hash + (r * sk)); + + (r, s, msg_hash, pubkey) +} + +fn ecdsa_circuit( + r: Fq, + s: Fq, + msg_hash: Fq, + pubkey: Secp256k1Affine, + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +#[should_panic(expected = "assertion failed: `(left == right)`")] +fn test_ecdsa_msg_hash_zero() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +#[should_panic(expected = "assertion failed: `(left == right)`")] +fn test_ecdsa_private_key_zero() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_ecdsa_random_valid_inputs() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test_case(1, 1, 1; "")] +fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test_case(1, 1, 1; "")] +fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); + let s = -s; + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index ecc8b287..803ac232 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1 +1,162 @@ +#![allow(non_snake_case)] +use std::fs::File; + +use ff::Field; +use group::Curve; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::{ + bn256::Fr, + secp256k1::{Fq, Secp256k1Affine}, + }, + }, + utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + Context, +}; +use num_bigint::BigUint; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; + +use crate::{ + ecc::EccChip, + fields::{FieldChip, FpStrategy}, + secp256k1::{FpChip, FqChip}, +}; + pub mod ecdsa; +pub mod ecdsa_tests; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn sm_test( + ctx: &mut Context, + params: CircuitParams, + base: Secp256k1Affine, + scalar: Fq, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::>::new(&fp_chip); + + let s = fq_chip.load_private(ctx, scalar); + let P = ecc_chip.assign_point(ctx, base); + + let sm = ecc_chip.scalar_mult::( + ctx, + P, + s.limbs().to_vec(), + fq_chip.limb_bits, + window_bits, + ); + + let sm_answer = (base * scalar).to_affine(); + + let sm_x = sm.x.value(); + let sm_y = sm.y.value(); + assert_eq!(sm_x, fe_to_biguint(&sm_answer.x)); + assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); +} + +fn sm_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + base: Secp256k1Affine, + scalar: Fq, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); + + sm_test(builder.main(0), params, base, scalar, 4); + + match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + } +} + +#[test] +fn test_secp_sm_random() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = sm_circuit( + params, + CircuitBuilderStage::Mock, + None, + Secp256k1Affine::random(OsRng), + Fq::random(OsRng), + ); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_secp_sm_minus_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let mut s = -Fq::one(); + let mut n = fe_to_biguint(&s); + loop { + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + if &n % BigUint::from(2usize) == BigUint::from(0usize) { + break; + } + n /= 2usize; + s = biguint_to_fe(&n); + } +} + +#[test] +fn test_secp_sm_0_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let s = Fq::zero(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + + let s = Fq::one(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs index 085ff9c6..55be8306 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi.rs @@ -16,7 +16,7 @@ use crate::halo2_proofs::{ }, poly::Rotation, }; -use halo2_base::AssignedValue; +use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; @@ -285,6 +285,7 @@ impl CellManager { let column = if column_idx < self.columns.len() { self.columns[column_idx].advice } else { + assert!(column_idx == self.columns.len()); let advice = meta.advice_column(); let mut expr = 0.expr(); meta.create_gate("Query column", |meta| { @@ -337,7 +338,7 @@ impl CellManager { // Make sure all rows start at the same column let width = self.get_width(); #[cfg(debug_assertions)] - for row in self.rows.iter_mut() { + for row in self.rows.iter() { self.num_unused_cells += width - *row; } self.rows = vec![width; self.height]; @@ -382,33 +383,26 @@ impl KeccakTable { } } +#[cfg(feature = "halo2-axiom")] +type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; +#[cfg(not(feature = "halo2-axiom"))] +type KeccakAssignedValue<'v, F> = AssignedCell; + pub fn assign_advice_custom<'v, F: Field>( region: &mut Region, column: Column, offset: usize, value: Value, -) -> AssignedValue<'v, F> { +) -> KeccakAssignedValue<'v, F> { #[cfg(feature = "halo2-axiom")] { - AssignedValue { - cell: region.assign_advice(column, offset, value).unwrap(), - #[cfg(feature = "display")] - context_id: usize::MAX, - } + region.assign_advice(column, offset, value) } #[cfg(feature = "halo2-pse")] { - AssignedValue { - cell: region - .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) - .unwrap() - .cell(), - value, - row_offset: offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id: usize::MAX, - } + region + .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) + .unwrap() } } @@ -1142,7 +1136,7 @@ impl KeccakCircuitConfig { for i in 0..5 { let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() + input[(i + 1) % 5].clone() - - input[(i + 2) % 5].clone().clone(); + - input[(i + 2) % 5].clone(); let output = output[i].clone(); meta.lookup("chi base", |_| { vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] @@ -1604,7 +1598,7 @@ pub fn keccak_phase1<'v, F: Field>( keccak_table: &KeccakTable, bytes: &[u8], challenge: Value, - input_rlcs: &mut Vec>, + input_rlcs: &mut Vec>, offset: &mut usize, ) { let num_chunks = get_num_keccak_f(bytes.len()); @@ -1948,7 +1942,7 @@ pub fn keccak_phase0( .take(4) .map(|a| { pack_with_base::(&unpack(a[0]), 2) - .to_repr() + .to_bytes_le() .into_iter() .take(8) .collect::>() @@ -1967,7 +1961,7 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( bytes: impl IntoIterator, challenge: Value, squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, -) -> (Vec>, Vec>) { +) -> (Vec>, Vec>) { let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); let mut output_rlcs = Vec::with_capacity(squeeze_digests.len()); diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs index 7af3ba4d..4619a197 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs @@ -38,6 +38,9 @@ impl Circuit for KeccakCircuit { } fn configure(meta: &mut ConstraintSystem) -> Self::Config { + // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase + meta.advice_column(); + let challenge = meta.challenge_usable_after(FirstPhase); KeccakCircuitConfig::new(meta, challenge) } diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs index 868c366c..b3e2e2b5 100644 --- a/hashes/zkevm-keccak/src/util.rs +++ b/hashes/zkevm-keccak/src/util.rs @@ -183,7 +183,7 @@ pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { /// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { let mut bits = [0; NUM_BITS_PER_WORD]; - let packed = Word::from_little_endian(packed.to_repr().as_ref()); + let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); let mask = Word::from(BIT_SIZE - 1); for (idx, bit) in bits.iter_mut().enumerate() { *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; @@ -200,10 +200,10 @@ pub fn pack_u64(value: u64) -> F { /// Calculates a ^ b with a and b field elements pub fn field_xor(a: F, b: F) -> F { let mut bytes = [0u8; 32]; - for (idx, (a, b)) in a.to_repr().as_ref().iter().zip(b.to_repr().as_ref().iter()).enumerate() { - bytes[idx] = *a ^ *b; + for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { + bytes[idx] = a ^ b; } - F::from_repr(bytes).unwrap() + F::from_bytes_le(&bytes) } /// Returns the size (in bits) of each part size when splitting up a keccak word diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm-keccak/src/util/constraint_builder.rs index 94f47c8c..bae9f4a4 100644 --- a/hashes/zkevm-keccak/src/util/constraint_builder.rs +++ b/hashes/zkevm-keccak/src/util/constraint_builder.rs @@ -53,7 +53,7 @@ impl BaseConstraintBuilder { pub(crate) fn validate_degree(&self, degree: usize, name: &'static str) { if self.max_degree > 0 { - debug_assert!( + assert!( degree <= self.max_degree, "Expression {} degree too high: {} > {}", name, diff --git a/hashes/zkevm-keccak/src/util/eth_types.rs b/hashes/zkevm-keccak/src/util/eth_types.rs index 3217f810..6fed74a5 100644 --- a/hashes/zkevm-keccak/src/util/eth_types.rs +++ b/hashes/zkevm-keccak/src/util/eth_types.rs @@ -71,7 +71,7 @@ impl ToScalar for U256 { fn to_scalar(&self) -> Option { let mut bytes = [0u8; 32]; self.to_little_endian(&mut bytes); - F::from_repr(bytes).into() + Some(F::from_bytes_le(&bytes)) } } @@ -113,7 +113,7 @@ impl ToScalar for Address { let mut bytes = [0u8; 32]; bytes[32 - Self::len_bytes()..].copy_from_slice(self.as_bytes()); bytes.reverse(); - F::from_repr(bytes).into() + Some(F::from_bytes_le(&bytes)) } }