Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(avm): bugfixing witness generation for add, sub, mul for FF #9938

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions barretenberg/cpp/pil/avm/alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ namespace alu(256);

// This holds the product over the integers
// (u1 multiplication only cares about a_lo and b_lo)
// TODO(9937): The following is not well constrained as this expression overflows the field.
pol PRODUCT = a_lo * b_lo + (1 - u1_tag) * (LIMB_BITS_POW * partial_prod_lo + MAX_BITS_POW * (partial_prod_hi + a_hi * b_hi));

// =============== ADDITION/SUBTRACTION Operation Constraints =================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,74 +403,83 @@ std::vector<std::array<FF, 3>> positive_op_div_test_values = { {
// Test on basic addition over finite field type.
TEST_F(AvmArithmeticTestsFF, addition)
{
std::vector<FF> const calldata = { 37, 4, 11 };
const FF a = FF::modulus - 19;
const FF b = FF::modulus - 5;
const FF c = FF::modulus - 24; // c = a + b
std::vector<FF> const calldata = { a, b, 4 };
gen_trace_builder(calldata);
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);

// Memory layout: [37,4,11,0,0,0,....]
trace_builder.op_add(0, 0, 1, 4); // [37,4,11,0,41,0,....]
// Memory layout: [a,b,4,0,0,....]
trace_builder.op_add(0, 0, 1, 4); // [a,b,4,0,c,0,....]
trace_builder.op_set(0, 5, 100, AvmMemoryTag::U32);
trace_builder.op_return(0, 0, 100);
auto trace = trace_builder.finalize();

auto alu_row = common_validate_add(trace, FF(37), FF(4), FF(41), FF(0), FF(1), FF(4), AvmMemoryTag::FF);
auto alu_row = common_validate_add(trace, a, b, c, FF(0), FF(1), FF(4), AvmMemoryTag::FF);

EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
EXPECT_EQ(alu_row.alu_cf, FF(0));

std::vector<FF> const returndata = { 37, 4, 11, 0, 41 };
std::vector<FF> const returndata = { a, b, 4, 0, c };

validate_trace(std::move(trace), public_inputs, calldata, returndata);
}

// Test on basic subtraction over finite field type.
TEST_F(AvmArithmeticTestsFF, subtraction)
{
std::vector<FF> const calldata = { 8, 4, 17 };
const FF a = 8;
const FF b = FF::modulus - 5;
const FF c = 13; // c = a - b
std::vector<FF> const calldata = { b, 4, a };
gen_trace_builder(calldata);
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);

// Memory layout: [8,4,17,0,0,0,....]
trace_builder.op_sub(0, 2, 0, 1); // [8,9,17,0,0,0....]
// Memory layout: [b,4,a,0,0,0,....]
trace_builder.op_sub(0, 2, 0, 1); // [b,c,a,0,0,0....]
trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32);
trace_builder.op_return(0, 0, 100);
auto trace = trace_builder.finalize();

auto alu_row = common_validate_sub(trace, FF(17), FF(8), FF(9), FF(2), FF(0), FF(1), AvmMemoryTag::FF);
auto alu_row = common_validate_sub(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF);

EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
EXPECT_EQ(alu_row.alu_cf, FF(0));

std::vector<FF> const returndata = { 8, 9, 17 };
std::vector<FF> const returndata = { b, c, a };
validate_trace(std::move(trace), public_inputs, calldata, returndata);
}

// Test on basic multiplication over finite field type.
TEST_F(AvmArithmeticTestsFF, multiplication)
{
std::vector<FF> const calldata = { 5, 0, 20 };
const FF a = FF::modulus - 1;
const FF b = 278;
const FF c = FF::modulus - 278;
std::vector<FF> const calldata = { b, 0, a };
gen_trace_builder(calldata);
trace_builder.op_set(0, 0, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);

// Memory layout: [5,0,20,0,0,0,....]
trace_builder.op_mul(0, 2, 0, 1); // [5,100,20,0,0,0....]
// Memory layout: [b,0,a,0,0,0,....]
trace_builder.op_mul(0, 2, 0, 1); // [b,c,a,0,0,0....]
trace_builder.op_set(0, 3, 100, AvmMemoryTag::U32);
trace_builder.op_return(0, 0, 100);
auto trace = trace_builder.finalize();

auto alu_row_index = common_validate_mul(trace, FF(20), FF(5), FF(100), FF(2), FF(0), FF(1), AvmMemoryTag::FF);
auto alu_row_index = common_validate_mul(trace, a, b, c, FF(2), FF(0), FF(1), AvmMemoryTag::FF);
auto alu_row = trace.at(alu_row_index);

EXPECT_EQ(alu_row.alu_ff_tag, FF(1));
EXPECT_EQ(alu_row.alu_cf, FF(0));

std::vector<FF> const returndata = { 5, 100, 20 };
std::vector<FF> const returndata = { b, c, a };
validate_trace(std::move(trace), public_inputs, calldata, returndata);
}

Expand Down
66 changes: 43 additions & 23 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ void AvmAluTraceBuilder::reset()
FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
{
bool carry = false;
uint256_t c_u256 = uint256_t(a) + uint256_t(b);
FF c = cast_to_mem_tag(c_u256, in_tag);
FF c;

if (in_tag == AvmMemoryTag::FF) {
c = a + b;
} else {
uint256_t c_u256 = uint256_t(a) + uint256_t(b);
c = cast_to_mem_tag(c_u256, in_tag);

if (in_tag != AvmMemoryTag::FF) {
// a_u128 + b_u128 >= 2^128 <==> c_u128 < a_u128
if (uint128_t(c) < uint128_t(a)) {
carry = true;
Expand Down Expand Up @@ -150,10 +154,14 @@ FF AvmAluTraceBuilder::op_add(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
{
bool carry = false;
uint256_t c_u256 = uint256_t(a) - uint256_t(b);
FF c = cast_to_mem_tag(c_u256, in_tag);
FF c;

if (in_tag == AvmMemoryTag::FF) {
c = a - b;
} else {
uint256_t c_u256 = uint256_t(a) - uint256_t(b);
c = cast_to_mem_tag(c_u256, in_tag);

if (in_tag != AvmMemoryTag::FF) {
// Underflow when a_u128 < b_u128
if (uint128_t(a) < uint128_t(b)) {
carry = true;
Expand Down Expand Up @@ -189,29 +197,41 @@ FF AvmAluTraceBuilder::op_sub(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
*/
FF AvmAluTraceBuilder::op_mul(FF const& a, FF const& b, AvmMemoryTag in_tag, uint32_t const clk)
{
uint256_t a_u256{ a };
uint256_t b_u256{ b };
uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128)
FF c = 0;
uint256_t alu_a_lo = 0;
uint256_t alu_a_hi = 0;
uint256_t alu_b_lo = 0;
uint256_t alu_b_hi = 0;
uint256_t c_hi = 0;
uint256_t partial_prod_lo = 0;
uint256_t partial_prod_hi = 0;

FF c = cast_to_mem_tag(c_u256, in_tag);
if (in_tag == AvmMemoryTag::FF) {
c = a * b;
} else {

uint8_t bits = mem_tag_bits(in_tag);
// limbs are size 1 for u1
uint8_t limb_bits = bits == 1 ? 1 : bits / 2;
uint8_t num_bits = bits;
uint256_t a_u256{ a };
uint256_t b_u256{ b };
uint256_t c_u256 = a_u256 * b_u256; // Multiplication over the integers (not mod. 2^128)

// Decompose a
auto [alu_a_lo, alu_a_hi] = decompose(a_u256, limb_bits);
// Decompose b
auto [alu_b_lo, alu_b_hi] = decompose(b_u256, limb_bits);
c = cast_to_mem_tag(c_u256, in_tag);

uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo;
// Decompose the partial product
auto [partial_prod_lo, partial_prod_hi] = decompose(partial_prod, limb_bits);
uint8_t bits = mem_tag_bits(in_tag);
// limbs are size 1 for u1
uint8_t limb_bits = bits == 1 ? 1 : bits / 2;
uint8_t num_bits = bits;

auto c_hi = c_u256 >> num_bits;
// Decompose a
std::tie(alu_a_lo, alu_a_hi) = decompose(a_u256, limb_bits);
// Decompose b
std::tie(alu_b_lo, alu_b_hi) = decompose(b_u256, limb_bits);

uint256_t partial_prod = alu_a_lo * alu_b_hi + alu_a_hi * alu_b_lo;
// Decompose the partial product
std::tie(partial_prod_lo, partial_prod_hi) = decompose(partial_prod, limb_bits);

c_hi = c_u256 >> num_bits;

if (in_tag != AvmMemoryTag::FF) {
cmp_builder.range_check_builder.assert_range(uint128_t(c), mem_tag_bits(in_tag), EventEmitter::ALU, clk);
}

Expand Down
Loading