diff --git a/barretenberg/cpp/pil/spike/README.md b/barretenberg/cpp/pil/spike/README.md new file mode 100644 index 00000000000..69e4f55ac79 --- /dev/null +++ b/barretenberg/cpp/pil/spike/README.md @@ -0,0 +1,3 @@ +## Spike machine + +A spike machine for testing new PIL functionality \ No newline at end of file diff --git a/barretenberg/cpp/pil/spike/spike.pil b/barretenberg/cpp/pil/spike/spike.pil new file mode 100644 index 00000000000..1361c446923 --- /dev/null +++ b/barretenberg/cpp/pil/spike/spike.pil @@ -0,0 +1,8 @@ + +namespace Spike(16); + +pol constant first = [1] + [0]*; +pol commit x; +pol public kernel_inputs; + +x - first = 0; \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/relations/generated/spike/declare_views.hpp b/barretenberg/cpp/src/barretenberg/relations/generated/spike/declare_views.hpp new file mode 100644 index 00000000000..df901e8d155 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/generated/spike/declare_views.hpp @@ -0,0 +1,7 @@ + +#define Spike_DECLARE_VIEWS(index) \ + using Accumulator = typename std::tuple_element::type; \ + using View = typename Accumulator::View; \ + [[maybe_unused]] auto Spike_first = View(new_term.Spike_first); \ + [[maybe_unused]] auto Spike_kernel_inputs = View(new_term.Spike_kernel_inputs); \ + [[maybe_unused]] auto Spike_x = View(new_term.Spike_x); diff --git a/barretenberg/cpp/src/barretenberg/relations/generated/spike/spike.hpp b/barretenberg/cpp/src/barretenberg/relations/generated/spike/spike.hpp new file mode 100644 index 00000000000..2a99922e200 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/generated/spike/spike.hpp @@ -0,0 +1,48 @@ + +#pragma once +#include "../../relation_parameters.hpp" +#include "../../relation_types.hpp" +#include "./declare_views.hpp" + +namespace bb::Spike_vm { + +template struct SpikeRow { + FF Spike_first{}; + FF Spike_x{}; +}; + +inline std::string get_relation_label_spike(int index) +{ + switch (index) {} + return std::to_string(index); +} + +template class spikeImpl { + public: + using FF = FF_; + + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 2, + }; + + template + void static accumulate(ContainerOverSubrelations& evals, + const AllEntities& new_term, + [[maybe_unused]] const RelationParameters&, + [[maybe_unused]] const FF& scaling_factor) + { + + // Contribution 0 + { + Spike_DECLARE_VIEWS(0); + + auto tmp = (Spike_x - Spike_first); + tmp *= scaling_factor; + std::get<0>(evals) += tmp; + } + } +}; + +template using spike = Relation>; + +} // namespace bb::Spike_vm \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp index 1921397837f..bb97c6808e4 100644 --- a/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp @@ -2026,6 +2026,14 @@ class AvmFlavor { */ template using ProverUnivariates = AllEntities>; + /** + * @brief A container for univariates used during Protogalaxy folding and sumcheck with some of the computation + * optmistically ignored + * @details During folding and sumcheck, the prover evaluates the relations on these univariates. + */ + template + using OptimisedProverUnivariates = AllEntities>; + /** * @brief A container for univariates produced during the hot loop in sumcheck. */ diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/avm_verifier.cpp b/barretenberg/cpp/src/barretenberg/vm/generated/avm_verifier.cpp index ecce0af1b4d..ba34ca33fd0 100644 --- a/barretenberg/cpp/src/barretenberg/vm/generated/avm_verifier.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/generated/avm_verifier.cpp @@ -3,9 +3,11 @@ #include "./avm_verifier.hpp" #include "barretenberg/commitment_schemes/zeromorph/zeromorph.hpp" #include "barretenberg/numeric/bitop/get_msb.hpp" +#include "barretenberg/polynomials/polynomial.hpp" #include "barretenberg/transcript/transcript.hpp" namespace bb { + AvmVerifier::AvmVerifier(std::shared_ptr verifier_key) : key(verifier_key) {} diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_circuit_builder.hpp new file mode 100644 index 00000000000..255ceed71c8 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_circuit_builder.hpp @@ -0,0 +1,110 @@ + + +// AUTOGENERATED FILE +#pragma once + +#include "barretenberg/common/constexpr_utils.hpp" +#include "barretenberg/common/throw_or_abort.hpp" +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/honk/proof_system/logderivative_library.hpp" +#include "barretenberg/relations/generic_lookup/generic_lookup_relation.hpp" +#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp" +#include "barretenberg/stdlib_circuit_builders/circuit_builder_base.hpp" + +#include "barretenberg/relations/generated/spike/spike.hpp" +#include "barretenberg/vm/generated/spike_flavor.hpp" + +namespace bb { + +template struct SpikeFullRow { + FF Spike_first{}; + FF Spike_kernel_inputs{}; + FF Spike_x{}; +}; + +class SpikeCircuitBuilder { + public: + using Flavor = bb::SpikeFlavor; + using FF = Flavor::FF; + using Row = SpikeFullRow; + + // TODO: template + using Polynomial = Flavor::Polynomial; + using ProverPolynomials = Flavor::ProverPolynomials; + + static constexpr size_t num_fixed_columns = 3; + static constexpr size_t num_polys = 3; + std::vector rows; + + void set_trace(std::vector&& trace) { rows = std::move(trace); } + + ProverPolynomials compute_polynomials() + { + const auto num_rows = get_circuit_subgroup_size(); + ProverPolynomials polys; + + // Allocate mem for each column + for (auto& poly : polys.get_all()) { + poly = Polynomial(num_rows); + } + + for (size_t i = 0; i < rows.size(); i++) { + polys.Spike_first[i] = rows[i].Spike_first; + polys.Spike_kernel_inputs[i] = rows[i].Spike_kernel_inputs; + polys.Spike_x[i] = rows[i].Spike_x; + } + + return polys; + } + + [[maybe_unused]] bool check_circuit() + { + + auto polys = compute_polynomials(); + const size_t num_rows = polys.get_polynomial_size(); + + const auto evaluate_relation = [&](const std::string& relation_name, + std::string (*debug_label)(int)) { + typename Relation::SumcheckArrayOfValuesOverSubrelations result; + for (auto& r : result) { + r = 0; + } + constexpr size_t NUM_SUBRELATIONS = result.size(); + + for (size_t i = 0; i < num_rows; ++i) { + Relation::accumulate(result, polys.get_row(i), {}, 1); + + bool x = true; + for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) { + if (result[j] != 0) { + std::string row_name = debug_label(static_cast(j)); + throw_or_abort( + format("Relation ", relation_name, ", subrelation index ", row_name, " failed at row ", i)); + x = false; + } + } + if (!x) { + return false; + } + } + return true; + }; + + if (!evaluate_relation.template operator()>("spike", Spike_vm::get_relation_label_spike)) { + return false; + } + + return true; + } + + [[nodiscard]] size_t get_num_gates() const { return rows.size(); } + + [[nodiscard]] size_t get_circuit_subgroup_size() const + { + const size_t num_rows = get_num_gates(); + const auto num_rows_log2 = static_cast(numeric::get_msb64(num_rows)); + size_t num_rows_pow2 = 1UL << (num_rows_log2 + (1UL << num_rows_log2 == num_rows ? 0 : 1)); + return num_rows_pow2; + } +}; +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.cpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.cpp new file mode 100644 index 00000000000..9745b6accda --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.cpp @@ -0,0 +1,86 @@ + + +#include "./spike_composer.hpp" +#include "barretenberg/plonk_honk_shared/composer/composer_lib.hpp" +#include "barretenberg/plonk_honk_shared/composer/permutation_lib.hpp" +#include "barretenberg/vm/generated/spike_circuit_builder.hpp" +#include "barretenberg/vm/generated/spike_verifier.hpp" + +namespace bb { + +using Flavor = SpikeFlavor; +void SpikeComposer::compute_witness(CircuitConstructor& circuit) +{ + if (computed_witness) { + return; + } + + auto polynomials = circuit.compute_polynomials(); + + for (auto [key_poly, prover_poly] : zip_view(proving_key->get_all(), polynomials.get_unshifted())) { + ASSERT(flavor_get_label(*proving_key, key_poly) == flavor_get_label(polynomials, prover_poly)); + key_poly = prover_poly; + } + + computed_witness = true; +} + +SpikeProver SpikeComposer::create_prover(CircuitConstructor& circuit_constructor) +{ + compute_proving_key(circuit_constructor); + compute_witness(circuit_constructor); + compute_commitment_key(circuit_constructor.get_circuit_subgroup_size()); + + SpikeProver output_state(proving_key, proving_key->commitment_key); + + return output_state; +} + +SpikeVerifier SpikeComposer::create_verifier(CircuitConstructor& circuit_constructor) +{ + auto verification_key = compute_verification_key(circuit_constructor); + + SpikeVerifier output_state(verification_key); + + auto pcs_verification_key = std::make_unique(); + + output_state.pcs_verification_key = std::move(pcs_verification_key); + + return output_state; +} + +std::shared_ptr SpikeComposer::compute_proving_key(CircuitConstructor& circuit_constructor) +{ + if (proving_key) { + return proving_key; + } + + // Initialize proving_key + { + const size_t subgroup_size = circuit_constructor.get_circuit_subgroup_size(); + proving_key = std::make_shared(subgroup_size, 0); + } + + proving_key->contains_recursive_proof = false; + + return proving_key; +} + +std::shared_ptr SpikeComposer::compute_verification_key( + CircuitConstructor& circuit_constructor) +{ + if (verification_key) { + return verification_key; + } + + if (!proving_key) { + compute_proving_key(circuit_constructor); + } + + verification_key = + std::make_shared(proving_key->circuit_size, proving_key->num_public_inputs); + + return verification_key; +} + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.hpp new file mode 100644 index 00000000000..10ddf7dbd93 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.hpp @@ -0,0 +1,69 @@ + + +#pragma once + +#include "barretenberg/plonk_honk_shared/composer/composer_lib.hpp" +#include "barretenberg/srs/global_crs.hpp" +#include "barretenberg/vm/generated/spike_circuit_builder.hpp" +#include "barretenberg/vm/generated/spike_prover.hpp" +#include "barretenberg/vm/generated/spike_verifier.hpp" + +namespace bb { +class SpikeComposer { + public: + using Flavor = SpikeFlavor; + using CircuitConstructor = SpikeCircuitBuilder; + using ProvingKey = Flavor::ProvingKey; + using VerificationKey = Flavor::VerificationKey; + using PCS = Flavor::PCS; + using CommitmentKey = Flavor::CommitmentKey; + using VerifierCommitmentKey = Flavor::VerifierCommitmentKey; + + // TODO: which of these will we really need + static constexpr std::string_view NAME_STRING = "Spike"; + static constexpr size_t NUM_RESERVED_GATES = 0; + static constexpr size_t NUM_WIRES = Flavor::NUM_WIRES; + + std::shared_ptr proving_key; + std::shared_ptr verification_key; + + // The crs_factory holds the path to the srs and exposes methods to extract the srs elements + std::shared_ptr> crs_factory_; + + // The commitment key is passed to the prover but also used herein to compute the verfication key commitments + std::shared_ptr commitment_key; + + std::vector recursive_proof_public_input_indices; + bool contains_recursive_proof = false; + bool computed_witness = false; + + SpikeComposer() { crs_factory_ = bb::srs::get_bn254_crs_factory(); } + + SpikeComposer(std::shared_ptr p_key, std::shared_ptr v_key) + : proving_key(std::move(p_key)) + , verification_key(std::move(v_key)) + {} + + SpikeComposer(SpikeComposer&& other) noexcept = default; + SpikeComposer(SpikeComposer const& other) noexcept = default; + SpikeComposer& operator=(SpikeComposer&& other) noexcept = default; + SpikeComposer& operator=(SpikeComposer const& other) noexcept = default; + ~SpikeComposer() = default; + + std::shared_ptr compute_proving_key(CircuitConstructor& circuit_constructor); + std::shared_ptr compute_verification_key(CircuitConstructor& circuit_constructor); + + void compute_witness(CircuitConstructor& circuit_constructor); + + SpikeProver create_prover(CircuitConstructor& circuit_constructor); + SpikeVerifier create_verifier(CircuitConstructor& circuit_constructor); + + void add_table_column_selector_poly_to_proving_key(bb::polynomial& small, const std::string& tag); + + void compute_commitment_key(size_t circuit_size) + { + proving_key->commitment_key = std::make_shared(circuit_size); + }; +}; + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_flavor.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_flavor.hpp new file mode 100644 index 00000000000..b841904764d --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_flavor.hpp @@ -0,0 +1,286 @@ + + +#pragma once +#include "barretenberg/commitment_schemes/kzg/kzg.hpp" +#include "barretenberg/ecc/curves/bn254/g1.hpp" +#include "barretenberg/flavor/relation_definitions.hpp" +#include "barretenberg/polynomials/barycentric.hpp" +#include "barretenberg/polynomials/univariate.hpp" + +#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp" + +#include "barretenberg/flavor/flavor.hpp" +#include "barretenberg/flavor/flavor_macros.hpp" +#include "barretenberg/polynomials/evaluation_domain.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/relations/generated/spike/spike.hpp" +#include "barretenberg/transcript/transcript.hpp" + +namespace bb { + +class SpikeFlavor { + public: + using Curve = curve::BN254; + using G1 = Curve::Group; + using PCS = KZG; + + using FF = G1::subgroup_field; + using Polynomial = bb::Polynomial; + using PolynomialHandle = std::span; + using GroupElement = G1::element; + using Commitment = G1::affine_element; + using CommitmentHandle = G1::affine_element; + using CommitmentKey = bb::CommitmentKey; + using VerifierCommitmentKey = bb::VerifierCommitmentKey; + using RelationSeparator = FF; + + static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 1; + static constexpr size_t NUM_WITNESS_ENTITIES = 2; + static constexpr size_t NUM_WIRES = NUM_WITNESS_ENTITIES + NUM_PRECOMPUTED_ENTITIES; + // We have two copies of the witness entities, so we subtract the number of fixed ones (they have no shift), one for + // the unshifted and one for the shifted + static constexpr size_t NUM_ALL_ENTITIES = 3; + + using Relations = std::tuple>; + + static constexpr size_t MAX_PARTIAL_RELATION_LENGTH = compute_max_partial_relation_length(); + + // BATCHED_RELATION_PARTIAL_LENGTH = algebraic degree of sumcheck relation *after* multiplying by the `pow_zeta` + // random polynomial e.g. For \sum(x) [A(x) * B(x) + C(x)] * PowZeta(X), relation length = 2 and random relation + // length = 3 + static constexpr size_t BATCHED_RELATION_PARTIAL_LENGTH = MAX_PARTIAL_RELATION_LENGTH + 1; + static constexpr size_t NUM_RELATIONS = std::tuple_size_v; + + template + using ProtogalaxyTupleOfTuplesOfUnivariates = + decltype(create_protogalaxy_tuple_of_tuples_of_univariates()); + using SumcheckTupleOfTuplesOfUnivariates = decltype(create_sumcheck_tuple_of_tuples_of_univariates()); + using TupleOfArraysOfValues = decltype(create_tuple_of_arrays_of_values()); + + static constexpr bool has_zero_row = true; + + private: + template class PrecomputedEntities : public PrecomputedEntitiesBase { + public: + using DataType = DataType_; + + DEFINE_FLAVOR_MEMBERS(DataType, Spike_first) + + RefVector get_selectors() { return { Spike_first }; }; + RefVector get_sigma_polynomials() { return {}; }; + RefVector get_id_polynomials() { return {}; }; + RefVector get_table_polynomials() { return {}; }; + }; + + template class WitnessEntities { + public: + DEFINE_FLAVOR_MEMBERS(DataType, Spike_kernel_inputs, Spike_x) + + RefVector get_wires() { return { Spike_kernel_inputs, Spike_x }; }; + }; + + template class AllEntities { + public: + DEFINE_FLAVOR_MEMBERS(DataType, Spike_first, Spike_kernel_inputs, Spike_x) + + RefVector get_wires() { return { Spike_first, Spike_kernel_inputs, Spike_x }; }; + RefVector get_unshifted() { return { Spike_first, Spike_kernel_inputs, Spike_x }; }; + RefVector get_to_be_shifted() { return {}; }; + RefVector get_shifted() { return {}; }; + }; + + public: + class ProvingKey + : public ProvingKeyAvm_, WitnessEntities, CommitmentKey> { + public: + // Expose constructors on the base class + using Base = ProvingKeyAvm_, WitnessEntities, CommitmentKey>; + using Base::Base; + + RefVector get_to_be_shifted() { return {}; }; + }; + + using VerificationKey = VerificationKey_, VerifierCommitmentKey>; + + using FoldedPolynomials = AllEntities>; + + class AllValues : public AllEntities { + public: + using Base = AllEntities; + using Base::Base; + }; + + /** + * @brief A container for the prover polynomials handles. + */ + class ProverPolynomials : public AllEntities { + public: + // Define all operations as default, except copy construction/assignment + ProverPolynomials() = default; + ProverPolynomials& operator=(const ProverPolynomials&) = delete; + ProverPolynomials(const ProverPolynomials& o) = delete; + ProverPolynomials(ProverPolynomials&& o) noexcept = default; + ProverPolynomials& operator=(ProverPolynomials&& o) noexcept = default; + ~ProverPolynomials() = default; + + ProverPolynomials(ProvingKey& proving_key) + { + for (auto [prover_poly, key_poly] : zip_view(this->get_unshifted(), proving_key.get_all())) { + ASSERT(flavor_get_label(*this, prover_poly) == flavor_get_label(proving_key, key_poly)); + prover_poly = key_poly.share(); + } + for (auto [prover_poly, key_poly] : zip_view(this->get_shifted(), proving_key.get_to_be_shifted())) { + ASSERT(flavor_get_label(*this, prover_poly) == (flavor_get_label(proving_key, key_poly) + "_shift")); + prover_poly = key_poly.shifted(); + } + } + + [[nodiscard]] size_t get_polynomial_size() const { return Spike_kernel_inputs.size(); } + /** + * @brief Returns the evaluations of all prover polynomials at one point on the boolean hypercube, which + * represents one row in the execution trace. + */ + [[nodiscard]] AllValues get_row(size_t row_idx) const + { + AllValues result; + for (auto [result_field, polynomial] : zip_view(result.get_all(), this->get_all())) { + result_field = polynomial[row_idx]; + } + return result; + } + }; + + using RowPolynomials = AllEntities; + + class PartiallyEvaluatedMultivariates : public AllEntities { + public: + PartiallyEvaluatedMultivariates() = default; + PartiallyEvaluatedMultivariates(const size_t circuit_size) + { + // Storage is only needed after the first partial evaluation, hence polynomials of size (n / 2) + for (auto& poly : get_all()) { + poly = Polynomial(circuit_size / 2); + } + } + }; + + /** + * @brief A container for univariates used during Protogalaxy folding and sumcheck. + * @details During folding and sumcheck, the prover evaluates the relations on these univariates. + */ + template using ProverUnivariates = AllEntities>; + + /** + * @brief A container for univariates used during Protogalaxy folding and sumcheck with some of the computation + * optmistically ignored + * @details During folding and sumcheck, the prover evaluates the relations on these univariates. + */ + template + using OptimisedProverUnivariates = AllEntities>; + + /** + * @brief A container for univariates produced during the hot loop in sumcheck. + */ + using ExtendedEdges = ProverUnivariates; + + /** + * @brief A container for the witness commitments. + * + */ + using WitnessCommitments = WitnessEntities; + + class CommitmentLabels : public AllEntities { + private: + using Base = AllEntities; + + public: + CommitmentLabels() + : AllEntities() + { + Base::Spike_first = "SPIKE_FIRST"; + Base::Spike_kernel_inputs = "SPIKE_KERNEL_INPUTS"; + Base::Spike_x = "SPIKE_X"; + }; + }; + + class VerifierCommitments : public AllEntities { + private: + using Base = AllEntities; + + public: + VerifierCommitments(const std::shared_ptr& verification_key) + { + Spike_first = verification_key->Spike_first; + } + }; + + class Transcript : public NativeTranscript { + public: + uint32_t circuit_size; + + Commitment Spike_kernel_inputs; + Commitment Spike_x; + + std::vector> sumcheck_univariates; + std::array sumcheck_evaluations; + std::vector zm_cq_comms; + Commitment zm_cq_comm; + Commitment zm_pi_comm; + + Transcript() = default; + + Transcript(const std::vector& proof) + : NativeTranscript(proof) + {} + + void deserialize_full_transcript() + { + size_t num_frs_read = 0; + circuit_size = deserialize_from_buffer(proof_data, num_frs_read); + size_t log_n = numeric::get_msb(circuit_size); + + Spike_kernel_inputs = deserialize_from_buffer(Transcript::proof_data, num_frs_read); + Spike_x = deserialize_from_buffer(Transcript::proof_data, num_frs_read); + + for (size_t i = 0; i < log_n; ++i) { + sumcheck_univariates.emplace_back( + deserialize_from_buffer>(Transcript::proof_data, + num_frs_read)); + } + sumcheck_evaluations = + deserialize_from_buffer>(Transcript::proof_data, num_frs_read); + for (size_t i = 0; i < log_n; ++i) { + zm_cq_comms.push_back(deserialize_from_buffer(proof_data, num_frs_read)); + } + zm_cq_comm = deserialize_from_buffer(proof_data, num_frs_read); + zm_pi_comm = deserialize_from_buffer(proof_data, num_frs_read); + } + + void serialize_full_transcript() + { + size_t old_proof_length = proof_data.size(); + Transcript::proof_data.clear(); + size_t log_n = numeric::get_msb(circuit_size); + + serialize_to_buffer(circuit_size, Transcript::proof_data); + + serialize_to_buffer(Spike_kernel_inputs, Transcript::proof_data); + serialize_to_buffer(Spike_x, Transcript::proof_data); + + for (size_t i = 0; i < log_n; ++i) { + serialize_to_buffer(sumcheck_univariates[i], Transcript::proof_data); + } + serialize_to_buffer(sumcheck_evaluations, Transcript::proof_data); + for (size_t i = 0; i < log_n; ++i) { + serialize_to_buffer(zm_cq_comms[i], proof_data); + } + serialize_to_buffer(zm_cq_comm, proof_data); + serialize_to_buffer(zm_pi_comm, proof_data); + + // sanity check to make sure we generate the same length of proof as before. + ASSERT(proof_data.size() == old_proof_length); + } + }; +}; + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_prover.cpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_prover.cpp new file mode 100644 index 00000000000..1f2925eecd1 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_prover.cpp @@ -0,0 +1,135 @@ + + +#include "spike_prover.hpp" +#include "barretenberg/commitment_schemes/claim.hpp" +#include "barretenberg/commitment_schemes/commitment_key.hpp" +#include "barretenberg/honk/proof_system/logderivative_library.hpp" +#include "barretenberg/honk/proof_system/permutation_library.hpp" +#include "barretenberg/plonk_honk_shared/library/grand_product_library.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/relations/lookup_relation.hpp" +#include "barretenberg/relations/permutation_relation.hpp" +#include "barretenberg/sumcheck/sumcheck.hpp" + +namespace bb { + +using Flavor = SpikeFlavor; +using FF = Flavor::FF; + +/** + * Create SpikeProver from proving key, witness and manifest. + * + * @param input_key Proving key. + * @param input_manifest Input manifest + * + * @tparam settings Settings class. + * */ +SpikeProver::SpikeProver(std::shared_ptr input_key, + std::shared_ptr commitment_key) + : key(input_key) + , commitment_key(commitment_key) +{ + for (auto [prover_poly, key_poly] : zip_view(prover_polynomials.get_unshifted(), key->get_all())) { + ASSERT(bb::flavor_get_label(prover_polynomials, prover_poly) == bb::flavor_get_label(*key, key_poly)); + prover_poly = key_poly.share(); + } + for (auto [prover_poly, key_poly] : zip_view(prover_polynomials.get_shifted(), key->get_to_be_shifted())) { + ASSERT(bb::flavor_get_label(prover_polynomials, prover_poly) == + bb::flavor_get_label(*key, key_poly) + "_shift"); + prover_poly = key_poly.shifted(); + } +} + +/** + * @brief Add circuit size, public input size, and public inputs to transcript + * + */ +void SpikeProver::execute_preamble_round() +{ + const auto circuit_size = static_cast(key->circuit_size); + + transcript->send_to_verifier("circuit_size", circuit_size); +} + +/** + * @brief Compute commitments to all of the witness wires (apart from the logderivative inverse wires) + * + */ +void SpikeProver::execute_wire_commitments_round() +{ + + // Commit to all polynomials (apart from logderivative inverse polynomials, which are committed to in the later + // logderivative phase) + witness_commitments.Spike_kernel_inputs = commitment_key->commit(key->Spike_kernel_inputs); + witness_commitments.Spike_x = commitment_key->commit(key->Spike_x); + + // Send all commitments to the verifier + transcript->send_to_verifier(commitment_labels.Spike_kernel_inputs, witness_commitments.Spike_kernel_inputs); + transcript->send_to_verifier(commitment_labels.Spike_x, witness_commitments.Spike_x); +} + +void SpikeProver::execute_log_derivative_inverse_round() {} + +/** + * @brief Run Sumcheck resulting in u = (u_1,...,u_d) challenges and all evaluations at u being calculated. + * + */ +void SpikeProver::execute_relation_check_rounds() +{ + using Sumcheck = SumcheckProver; + + auto sumcheck = Sumcheck(key->circuit_size, transcript); + + FF alpha = transcript->template get_challenge("Sumcheck:alpha"); + std::vector gate_challenges(numeric::get_msb(key->circuit_size)); + + for (size_t idx = 0; idx < gate_challenges.size(); idx++) { + gate_challenges[idx] = transcript->template get_challenge("Sumcheck:gate_challenge_" + std::to_string(idx)); + } + sumcheck_output = sumcheck.prove(prover_polynomials, relation_parameters, alpha, gate_challenges); +} + +/** + * @brief Execute the ZeroMorph protocol to prove the multilinear evaluations produced by Sumcheck + * @details See https://hackmd.io/dlf9xEwhTQyE3hiGbq4FsA?view for a complete description of the unrolled protocol. + * + * */ +void SpikeProver::execute_zeromorph_rounds() +{ + ZeroMorph::prove(prover_polynomials.get_unshifted(), + prover_polynomials.get_to_be_shifted(), + sumcheck_output.claimed_evaluations.get_unshifted(), + sumcheck_output.claimed_evaluations.get_shifted(), + sumcheck_output.challenge, + commitment_key, + transcript); +} + +HonkProof& SpikeProver::export_proof() +{ + proof = transcript->proof_data; + return proof; +} + +HonkProof& SpikeProver::construct_proof() +{ + // Add circuit size public input size and public inputs to transcript. + execute_preamble_round(); + + // Compute wire commitments + execute_wire_commitments_round(); + + // Compute sorted list accumulator and commitment + + // Fiat-Shamir: alpha + // Run sumcheck subprotocol. + execute_relation_check_rounds(); + + // Fiat-Shamir: rho, y, x, z + // Execute Zeromorph multilinear PCS + execute_zeromorph_rounds(); + + return export_proof(); +} + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_prover.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_prover.hpp new file mode 100644 index 00000000000..e80b92f384f --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_prover.hpp @@ -0,0 +1,64 @@ + + +#pragma once +#include "barretenberg/commitment_schemes/zeromorph/zeromorph.hpp" +#include "barretenberg/plonk/proof_system/types/proof.hpp" +#include "barretenberg/relations/relation_parameters.hpp" +#include "barretenberg/sumcheck/sumcheck_output.hpp" +#include "barretenberg/transcript/transcript.hpp" + +#include "barretenberg/vm/generated/spike_flavor.hpp" + +namespace bb { + +class SpikeProver { + + using Flavor = SpikeFlavor; + using FF = Flavor::FF; + using PCS = Flavor::PCS; + using PCSCommitmentKey = Flavor::CommitmentKey; + using ProvingKey = Flavor::ProvingKey; + using Polynomial = Flavor::Polynomial; + using ProverPolynomials = Flavor::ProverPolynomials; + using CommitmentLabels = Flavor::CommitmentLabels; + using Transcript = Flavor::Transcript; + + public: + explicit SpikeProver(std::shared_ptr input_key, std::shared_ptr commitment_key); + + void execute_preamble_round(); + void execute_wire_commitments_round(); + void execute_log_derivative_inverse_round(); + void execute_relation_check_rounds(); + void execute_zeromorph_rounds(); + + HonkProof& export_proof(); + HonkProof& construct_proof(); + + std::shared_ptr transcript = std::make_shared(); + + std::vector public_inputs; + + bb::RelationParameters relation_parameters; + + std::shared_ptr key; + + // Container for spans of all polynomials required by the prover (i.e. all multivariates evaluated by Sumcheck). + ProverPolynomials prover_polynomials; + + CommitmentLabels commitment_labels; + typename Flavor::WitnessCommitments witness_commitments; + + Polynomial quotient_W; + + SumcheckOutput sumcheck_output; + + std::shared_ptr commitment_key; + + using ZeroMorph = ZeroMorphProver_; + + private: + HonkProof proof; +}; + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_verifier.cpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_verifier.cpp new file mode 100644 index 00000000000..52660b91ed9 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_verifier.cpp @@ -0,0 +1,110 @@ + + +#include "./spike_verifier.hpp" +#include "barretenberg/commitment_schemes/zeromorph/zeromorph.hpp" +#include "barretenberg/numeric/bitop/get_msb.hpp" +#include "barretenberg/polynomials/polynomial.hpp" +#include "barretenberg/transcript/transcript.hpp" + +namespace bb { + +SpikeVerifier::SpikeVerifier(std::shared_ptr verifier_key) + : key(verifier_key) +{} + +SpikeVerifier::SpikeVerifier(SpikeVerifier&& other) noexcept + : key(std::move(other.key)) + , pcs_verification_key(std::move(other.pcs_verification_key)) +{} + +SpikeVerifier& SpikeVerifier::operator=(SpikeVerifier&& other) noexcept +{ + key = other.key; + pcs_verification_key = (std::move(other.pcs_verification_key)); + commitments.clear(); + return *this; +} + +using FF = SpikeFlavor::FF; + +// Evaluate the given public input column over the multivariate challenge points +[[maybe_unused]] FF evaluate_public_input_column(std::vector points, std::vector challenges) +{ + Polynomial polynomial(points); + return polynomial.evaluate_mle(challenges); +} + +/** + * @brief This function verifies an Spike Honk proof for given program settings. + * + */ +bool SpikeVerifier::verify_proof(const HonkProof& proof, const std::vector& public_inputs) +{ + using Flavor = SpikeFlavor; + using FF = Flavor::FF; + using Commitment = Flavor::Commitment; + // using PCS = Flavor::PCS; + // using ZeroMorph = ZeroMorphVerifier_; + using VerifierCommitments = Flavor::VerifierCommitments; + using CommitmentLabels = Flavor::CommitmentLabels; + + RelationParameters relation_parameters; + + transcript = std::make_shared(proof); + + VerifierCommitments commitments{ key }; + CommitmentLabels commitment_labels; + + const auto circuit_size = transcript->template receive_from_prover("circuit_size"); + + if (circuit_size != key->circuit_size) { + return false; + } + + // Get commitments to VM wires + commitments.Spike_kernel_inputs = + transcript->template receive_from_prover(commitment_labels.Spike_kernel_inputs); + commitments.Spike_x = transcript->template receive_from_prover(commitment_labels.Spike_x); + + // Get commitments to inverses + + // Execute Sumcheck Verifier + const size_t log_circuit_size = numeric::get_msb(circuit_size); + auto sumcheck = SumcheckVerifier(log_circuit_size, transcript); + + FF alpha = transcript->template get_challenge("Sumcheck:alpha"); + + auto gate_challenges = std::vector(log_circuit_size); + for (size_t idx = 0; idx < log_circuit_size; idx++) { + gate_challenges[idx] = transcript->template get_challenge("Sumcheck:gate_challenge_" + std::to_string(idx)); + } + + auto [multivariate_challenge, claimed_evaluations, sumcheck_verified] = + sumcheck.verify(relation_parameters, alpha, gate_challenges); + + // If Sumcheck did not verify, return false + if (sumcheck_verified.has_value() && !sumcheck_verified.value()) { + return false; + } + + FF public_column_evaluation = evaluate_public_input_column(public_inputs, multivariate_challenge); + if (public_column_evaluation != claimed_evaluations.Spike_kernel_inputs) { + return false; + } + + // Execute ZeroMorph rounds. See https://hackmd.io/dlf9xEwhTQyE3hiGbq4FsA?view for a complete description of the + // unrolled protocol. + // NOTE: temporarily disabled - facing integration issues + // auto pairing_points = ZeroMorph::verify(commitments.get_unshifted(), + // commitments.get_to_be_shifted(), + // claimed_evaluations.get_unshifted(), + // claimed_evaluations.get_shifted(), + // multivariate_challenge, + // transcript); + + // auto verified = pcs_verification_key->pairing_check(pairing_points[0], pairing_points[1]); + // return sumcheck_verified.value() && verified; + return sumcheck_verified.value(); +} + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/generated/spike_verifier.hpp b/barretenberg/cpp/src/barretenberg/vm/generated/spike_verifier.hpp new file mode 100644 index 00000000000..c4fb767455a --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/generated/spike_verifier.hpp @@ -0,0 +1,33 @@ + + +#pragma once +#include "barretenberg/plonk/proof_system/types/proof.hpp" +#include "barretenberg/sumcheck/sumcheck.hpp" +#include "barretenberg/vm/generated/spike_flavor.hpp" + +namespace bb { +class SpikeVerifier { + using Flavor = SpikeFlavor; + using FF = Flavor::FF; + using Commitment = Flavor::Commitment; + using VerificationKey = Flavor::VerificationKey; + using VerifierCommitmentKey = Flavor::VerifierCommitmentKey; + using Transcript = Flavor::Transcript; + + public: + explicit SpikeVerifier(std::shared_ptr verifier_key = nullptr); + SpikeVerifier(SpikeVerifier&& other) noexcept; + SpikeVerifier(const SpikeVerifier& other) = delete; + + SpikeVerifier& operator=(const SpikeVerifier& other) = delete; + SpikeVerifier& operator=(SpikeVerifier&& other) noexcept; + + bool verify_proof(const HonkProof& proof, const std::vector& public_inputs); + + std::shared_ptr key; + std::map commitments; + std::shared_ptr pcs_verification_key; + std::shared_ptr transcript; +}; + +} // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/spike.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/spike.test.cpp new file mode 100644 index 00000000000..1b30f1f4a6c --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/tests/spike.test.cpp @@ -0,0 +1,73 @@ +#include "barretenberg/crypto/generators/generator_data.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/vm/generated/spike_circuit_builder.hpp" +#include "barretenberg/vm/generated/spike_flavor.hpp" + +// Proofs +#include "barretenberg/vm/generated/spike_composer.hpp" +#include "barretenberg/vm/generated/spike_prover.hpp" +#include "barretenberg/vm/generated/spike_verifier.hpp" + +#include + +using namespace bb; +namespace { +auto& engine = numeric::get_debug_randomness(); +} + +class SpikePublicColumnsTests : public ::testing::Test { + protected: + // TODO(640): The Standard Honk on Grumpkin test suite fails unless the SRS is initialised for every test. + void SetUp() override { srs::init_crs_factory("../srs_db/ignition"); }; +}; + +// Test file for testing public inputs evaluations are the same in the verifier and in sumcheck +// +// The first test runs the verification with the same public inputs in the verifier and in the prover, prover inputs are +// set in the below function The second failure test runs the verification with the different public inputs +bool verify_spike_with_public_with_public_inputs(std::vector verifier_public__inputs) +{ + using Builder = SpikeCircuitBuilder; + using Row = Builder::Row; + Builder circuit_builder; + + srs::init_crs_factory("../srs_db/ignition"); + + const size_t circuit_size = 16; + std::vector rows; + + // Add to the public input column that is increasing + for (size_t i = 0; i < circuit_size; i++) { + // Make sure the external and trace public inputs are the same + Row row{ .Spike_kernel_inputs = i + 1 }; + rows.push_back(row); + } + + circuit_builder.set_trace(std::move(rows)); + + // Create a prover and verifier + auto composer = SpikeComposer(); + auto prover = composer.create_prover(circuit_builder); + HonkProof proof = prover.construct_proof(); + + auto verifier = composer.create_verifier(circuit_builder); + + return verifier.verify_proof(proof, verifier_public__inputs); +} + +TEST(SpikePublicColumnsTests, VerificationSuccess) +{ + std::vector public_inputs = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; + bool verified = verify_spike_with_public_with_public_inputs(public_inputs); + ASSERT_TRUE(verified); +} + +TEST(SpikePublicColumnsTests, VerificationFailure) +{ + std::vector public_inputs = { + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160 + }; + bool verified = verify_spike_with_public_with_public_inputs(public_inputs); + ASSERT_FALSE(verified); +} \ No newline at end of file