Skip to content

Commit

Permalink
chore: aggregate with short scalars in UH Recursion (#11478)
Browse files Browse the repository at this point in the history
Take advantage of short challenges to create fewer gates while
aggregating pairing points.

Fixed several short scalar issues and re-used
`bn254_endo_batch_mul(...)` to define the scalar mul operator in the relevant
contexts.

**UH Recursive Verifier finalized num gates**

Before:  866732
After:  729534
  • Loading branch information
iakovenkos authored Feb 6, 2025
1 parent 62e5de7 commit a6fcdb0
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ UltraRecursiveVerifier_<Flavor>::Output UltraRecursiveVerifier_<Flavor>::verify_
// TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate this challenge properly.
typename Curve::ScalarField recursion_separator =
Curve::ScalarField::from_witness_index(builder, builder->add_variable(42));
agg_obj.aggregate(nested_agg_obj, recursion_separator);
agg_obj.template aggregate<Builder>(nested_agg_obj, recursion_separator);

// Execute Sumcheck Verifier and extract multivariate opening point u = (u_0, ..., u_{d-1}) and purported
// multivariate evaluations at u
Expand Down Expand Up @@ -143,11 +143,11 @@ UltraRecursiveVerifier_<Flavor>::Output UltraRecursiveVerifier_<Flavor>::verify_
pairing_points[0] = pairing_points[0].normalize();
pairing_points[1] = pairing_points[1].normalize();
// TODO(https://github.com/AztecProtocol/barretenberg/issues/995): generate recursion separator challenge properly.
agg_obj.aggregate(pairing_points, recursion_separator);
agg_obj.template aggregate<Builder>(pairing_points, recursion_separator);
output.agg_obj = std::move(agg_obj);

// Extract the IPA claim from the public inputs
// Parse out the nested IPA claim using key->ipa_claim_public_input_indices and runs the native IPA verifier.
// Parse out the nested IPA claim using key->ipa_claim_public_input_indices and run the native IPA verifier.
if constexpr (HasIPAAccumulator<Flavor>) {
const auto recover_fq_from_public_inputs = [](std::array<FF, Curve::BaseField::NUM_LIMBS>& limbs) {
for (size_t k = 0; k < Curve::BaseField::NUM_LIMBS; k++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,36 @@ template <typename Curve> struct aggregation_state {
{
return P0 == other.P0 && P1 == other.P1;
};

// Folds another aggregation state into this one: each of `other`'s points (P0, P1) is multiplied by
// `recursion_separator` and added to the matching point of this state.
// `BuilderType` selects the multiplication strategy at compile time (see the `if constexpr` below).
// NOTE(review): this listing is a rendered diff with its add/remove markers stripped. The two unconditional `+=`
// statements at the top of the body look like the removed pre-change implementation — confirm against the real
// header before relying on this text.
template <typename BuilderType = void>
void aggregate(aggregation_state const& other, typename Curve::ScalarField recursion_separator)
{
P0 += other.P0 * recursion_separator;
P1 += other.P1 * recursion_separator;
// For MegaCircuitBuilder the plain `operator*` multiplication is kept.
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) {
P0 += other.P0 * recursion_separator;
P1 += other.P1 * recursion_separator;
} else {
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1,
// recursion_separator} directly to avoid edge cases.
// The literal 128 is passed as the second argument of `scalar_mul` (presumably its `max_num_bits` bound, i.e.
// the separator is assumed to fit in 128 bits — TODO confirm).
typename Curve::Group point_to_aggregate = other.P0.scalar_mul(recursion_separator, 128);
P0 += point_to_aggregate;
point_to_aggregate = other.P1.scalar_mul(recursion_separator, 128);
P1 += point_to_aggregate;
}
}

// Folds a pair of pairing points into this state: `other[0]` and `other[1]` are each multiplied by
// `recursion_separator` and added to P0 and P1 respectively.
// `BuilderType` selects the multiplication strategy at compile time (see the `if constexpr` below).
// NOTE(review): this listing is a rendered diff with its add/remove markers stripped. The two unconditional `+=`
// statements at the top of the body look like the removed pre-change implementation — confirm against the real
// header before relying on this text.
template <typename BuilderType = void>
void aggregate(std::array<typename Curve::Group, 2> const& other, typename Curve::ScalarField recursion_separator)
{
P0 += other[0] * recursion_separator;
P1 += other[1] * recursion_separator;
// For MegaCircuitBuilder the plain `operator*` multiplication is kept.
if constexpr (std::is_same_v<BuilderType, MegaCircuitBuilder>) {
P0 += other[0] * recursion_separator;
P1 += other[1] * recursion_separator;
} else {
// Save gates using short scalars. We don't apply `bn254_endo_batch_mul` to the vector {1,
// recursion_separator} directly to avoid edge cases.
// The literal 128 is passed as the second argument of `scalar_mul` (presumably its `max_num_bits` bound, i.e.
// the separator is assumed to fit in 128 bits — TODO confirm).
typename Curve::Group point_to_aggregate = other[0].scalar_mul(recursion_separator, 128);
P0 += point_to_aggregate;
point_to_aggregate = other[1].scalar_mul(recursion_separator, 128);
P1 += point_to_aggregate;
}
}

PairingPointAccumulatorIndices get_witness_indices()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
result.y.assert_is_in_field();
return result;
}
element scalar_mul(const Fr& scalar, const size_t max_num_bits = 0) const;

element reduce() const
{
Expand Down Expand Up @@ -525,7 +526,10 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
num_fives = num_points / 5;
num_sixes = 0;
// size-6 table is expensive and only benefits us if creating them reduces the number of total tables
if (num_fives * 5 == (num_points - 1)) {
if (num_points == 1) {
num_fives = 0;
num_sixes = 0;
} else if (num_fives * 5 == (num_points - 1)) {
num_fives -= 1;
num_sixes = 1;
} else if (num_fives * 5 == (num_points - 2) && num_fives >= 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,142 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
EXPECT_CIRCUIT_CORRECTNESS(builder);
}

// Shared driver for the short-scalar multiplication tests.
//
// For every even bit length in [first_num_bits, end_num_bits) it draws a random point and a random scalar of at
// most that many bits, multiplies them in-circuit via `scalar_mul(x, num_bits)`, and checks the origin tag and the
// coordinates of the result against the natively computed product. All multiplications share one builder, whose
// correctness is checked once at the end.
//
// We only test even bit lengths, because `bn254_endo_batch_mul` used in `scalar_mul` can't handle odd lengths.
//
// @param first_num_bits       first (even) bit length to exercise
// @param end_num_bits         exclusive upper bound on the bit length
// @param replace_zero_scalar  when true, a scalar that happens to be 0 is bumped to 1 before use
static void test_short_scalar_mul_even_lengths(const size_t first_num_bits,
                                               const size_t end_num_bits,
                                               const bool replace_zero_scalar)
{
    Builder builder;

    for (size_t num_bits = first_num_bits; num_bits < end_num_bits; num_bits += 2) {
        affine_element input(element::random_element());
        // Get a random 256-bit integer
        uint256_t scalar_raw = engine.get_random_uint256();
        // Produce a scalar of length <= num_bits.
        scalar_raw = scalar_raw >> (256 - num_bits);
        fr scalar = fr(scalar_raw);

        // Avoid multiplication by 0 that may occur when `num_bits` is small
        if (replace_zero_scalar && scalar == fr(0)) {
            scalar += 1;
        }

        element_ct P = element_ct::from_witness(&builder, input);
        scalar_ct x = scalar_ct::from_witness(&builder, scalar);

        // Set input tags
        x.set_origin_tag(challenge_origin_tag);
        P.set_origin_tag(submitted_value_origin_tag);

        std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
        // Multiply using specified scalar length
        element_ct c = P.scalar_mul(x, num_bits);
        std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl;
        affine_element c_expected(element(input) * scalar);

        // Check the result of the multiplication has a tag that's the union of inputs' tags
        EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
        fq c_x_result(c.x.get_value().lo);
        fq c_y_result(c.y.get_value().lo);

        EXPECT_EQ(c_x_result, c_expected.x);

        EXPECT_EQ(c_y_result, c_expected.y);
    }

    EXPECT_CIRCUIT_CORRECTNESS(builder);
}

// Test short scalar mul with variable even bit length. For efficiency, it's split into two tests.
// This half covers the even bit lengths 2, 4, ..., 126 and replaces zero scalars by 1.
static void test_short_scalar_mul_2_126()
{
    test_short_scalar_mul_even_lengths(/*first_num_bits=*/2, /*end_num_bits=*/128, /*replace_zero_scalar=*/true);
}

// Second half of the short scalar mul test: even bit lengths 128, 130, ..., 252. Scalars are used as drawn.
static void test_short_scalar_mul_128_252()
{
    test_short_scalar_mul_even_lengths(/*first_num_bits=*/128, /*end_num_bits=*/254, /*replace_zero_scalar=*/false);
}

static void test_short_scalar_mul_infinity()
{
// We check that a point at infinity preserves `is_point_at_infinity()` flag after being multiplied against a
// short scalar and also check that the number of gates in this case is equal to the number of gates spent on a
// finite point.

// Populate test points.
std::vector<element> points(2);

points[0] = element::infinity();
points[1] = element::random_element();
// Containter for gate counts.
std::vector<size_t> gates(2);

// We initialize this flag as `true`, because the first result is expected to be the point at infinity.
bool expect_infinity = true;

for (auto [point, num_gates] : zip_view(points, gates)) {
Builder builder;

const size_t max_num_bits = 128;
// Get a random 256-bit integer
uint256_t scalar_raw = engine.get_random_uint256();
// Produce a length =< max_num_bits scalar.
scalar_raw = scalar_raw >> (256 - max_num_bits);
fr scalar = fr(scalar_raw);

element_ct P = element_ct::from_witness(&builder, point);
scalar_ct x = scalar_ct::from_witness(&builder, scalar);

// Set input tags
x.set_origin_tag(challenge_origin_tag);
P.set_origin_tag(submitted_value_origin_tag);

std::cerr << "gates before mul " << builder.get_estimated_num_finalized_gates() << std::endl;
element_ct c = P.scalar_mul(x, max_num_bits);
std::cerr << "builder aftr mul " << builder.get_estimated_num_finalized_gates() << std::endl;
num_gates = builder.get_estimated_num_finalized_gates();
// Check the result of the multiplication has a tag that's the union of inputs' tags
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);

EXPECT_EQ(c.is_point_at_infinity().get_value(), expect_infinity);
EXPECT_CIRCUIT_CORRECTNESS(builder);
// The second point is finite, hence we flip the flag
expect_infinity = false;
}
// Check that the numbers of gates are equal in both cases.
EXPECT_EQ(gates[0], gates[1]);
}

static void test_twin_mul()
{
Builder builder;
Expand Down Expand Up @@ -950,26 +1086,39 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
// Checks `element_ct::compute_naf`: the origin tag of the scalar must reach every NAF bit, and the scalar
// reconstructed from the NAF bits must equal the original value.
// NOTE(review): this listing is a rendered diff with its add/remove markers stripped, so pre-change and post-change
// lines are interleaved (e.g. the `num_repetitions` loop next to the `length` loop, and both `compute_naf` calls).
// Confirm which lines are live against the real test file.
static void test_compute_naf()
{
Builder builder = Builder();
size_t num_repetitions(32);
for (size_t i = 0; i < num_repetitions; i++) {
fr scalar_val = fr::random_element();
size_t max_num_bits = 254;
// Our design of NAF and the way it is used assumes the even length of scalars.
for (size_t length = 2; length < max_num_bits; length += 2) {

fr scalar_val;

// Shift a random 256-bit integer right so that at most `length` bits remain.
uint256_t scalar_raw = engine.get_random_uint256();
scalar_raw = scalar_raw >> (256 - length);

scalar_val = fr(scalar_raw);

// NAF with short scalars doesn't handle 0
if (scalar_val == fr(0)) {
scalar_val += 1;
};
scalar_ct scalar = scalar_ct::from_witness(&builder, scalar_val);
// Set tag for scalar
scalar.set_origin_tag(submitted_value_origin_tag);
auto naf = element_ct::compute_naf(scalar);
auto naf = element_ct::compute_naf(scalar, length);

for (const auto& bit : naf) {
// Check that the tag is propagated to bits
EXPECT_EQ(bit.get_origin_tag(), submitted_value_origin_tag);
}
// scalar = -naf[254] + \sum_{i=0}^{253}(1-2*naf[i]) 2^{253-i}
// For a `length`-bit NAF the same identity reads:
// scalar = -naf[length] + \sum_{i=0}^{length-1}(1-2*naf[i]) 2^{length-1-i}
fr reconstructed_val(0);
for (size_t i = 0; i < 254; i++) {
reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (253 - i));
for (size_t i = 0; i < length; i++) {
reconstructed_val += (fr(1) - fr(2) * fr(naf[i].witness_bool)) * fr(uint256_t(1) << (length - 1 - i));
};
reconstructed_val -= fr(naf[254].witness_bool);
reconstructed_val -= fr(naf[length].witness_bool);
EXPECT_EQ(scalar_val, reconstructed_val);
}

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

Expand Down Expand Up @@ -1614,6 +1763,33 @@ HEAVY_TYPED_TEST(stdlib_biggroup, mul)
{
TestFixture::test_mul();
}

// Runs the short scalar multiplication test for even bit lengths 2..126.
// The test is skipped for type parameters that satisfy `HasGoblinBuilder`.
HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_2_126_bits)
{
    if constexpr (!HasGoblinBuilder<TypeParam>) {
        TestFixture::test_short_scalar_mul_2_126();
    } else {
        GTEST_SKIP();
    }
}
// Runs the short scalar multiplication test for even bit lengths 128..252.
// The test is skipped for type parameters that satisfy `HasGoblinBuilder`.
HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_128_252_bits)
{
    if constexpr (!HasGoblinBuilder<TypeParam>) {
        TestFixture::test_short_scalar_mul_128_252();
    } else {
        GTEST_SKIP();
    }
}

// Runs the short scalar multiplication test on the point at infinity.
// The test is skipped for type parameters that satisfy `HasGoblinBuilder`.
HEAVY_TYPED_TEST(stdlib_biggroup, short_scalar_mul_infinity)
{
    if constexpr (!HasGoblinBuilder<TypeParam>) {
        TestFixture::test_short_scalar_mul_infinity();
    } else {
        GTEST_SKIP();
    }
}

HEAVY_TYPED_TEST(stdlib_biggroup, twin_mul)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::bn254_endo_batch_mul(const std::vec
const std::vector<Fr>& small_scalars,
const size_t max_num_small_bits)
{
ASSERT(max_num_small_bits >= 128);

ASSERT(max_num_small_bits % 2 == 0);

const size_t num_big_points = big_points.size();
const size_t num_small_points = small_points.size();
C* ctx = nullptr;
Expand Down
Loading

0 comments on commit a6fcdb0

Please sign in to comment.