Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Runtime for short weierstrass curve operations over Fp #566

Merged
merged 8 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions circuits/ecc/src/field_expression/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use afs_primitives::{
OverflowInt,
},
sub_chip::{AirConfig, LocalTraceInstructions, SubAir},
var_range::VariableRangeCheckerChip,
var_range::{bus::VariableRangeCheckerBus, VariableRangeCheckerChip},
};
use afs_stark_backend::{
interaction::InteractionBuilder,
Expand Down Expand Up @@ -135,8 +135,7 @@ pub struct FieldExpr {

pub check_carry_mod_to_zero: CheckCarryModToZeroSubAir,

pub range_bus: usize,
pub range_max_bits: usize,
pub range_bus: VariableRangeCheckerBus,
}

impl Deref for FieldExpr {
Expand Down Expand Up @@ -204,8 +203,8 @@ impl<AB: InteractionBuilder> SubAir<AB> for FieldExpr {
for limb in var.limbs.iter() {
range_check(
builder,
self.range_bus,
self.range_max_bits,
self.range_bus.index,
self.range_bus.range_max_bits,
self.limb_bits,
limb.clone(),
is_valid,
Expand Down
22 changes: 22 additions & 0 deletions circuits/ecc/src/field_expression/field_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,28 @@ impl FieldVariable {
}
}

pub fn square(&mut self) -> FieldVariable {
let builder = self.builder.borrow();
let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
let max_overflow_bits = log2_ceil_usize(limb_max_abs);
let (_, carry_bits) = get_carry_max_abs_and_bits(max_overflow_bits, builder.limb_bits);
drop(builder);
if carry_bits > self.range_checker_bits {
self.save();
}

let limb_max_abs = self.limb_max_abs * self.limb_max_abs * self.expr_limbs;
let max_overflow_bits = log2_ceil_usize(limb_max_abs);
FieldVariable {
expr: SymbolicExpr::Mul(Box::new(self.expr.clone()), Box::new(self.expr.clone())),
builder: self.builder.clone(),
limb_max_abs,
max_overflow_bits,
expr_limbs: self.expr_limbs * 2 - 1,
range_checker_bits: self.range_checker_bits,
}
}

pub fn int_mul(&mut self, scalar: isize) -> FieldVariable {
let builder = self.builder.borrow();
let max_limb_bits = builder.max_limb_bits;
Expand Down
27 changes: 9 additions & 18 deletions circuits/ecc/src/field_expression/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ fn test_add() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);

Expand Down Expand Up @@ -130,8 +129,7 @@ fn test_div() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);

Expand Down Expand Up @@ -176,8 +174,7 @@ fn test_auto_carry_mul() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);
let x = generate_random_biguint(&prime);
Expand Down Expand Up @@ -222,8 +219,7 @@ fn test_auto_carry_intmul() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);
let x = generate_random_biguint(&prime);
Expand Down Expand Up @@ -277,8 +273,7 @@ fn test_auto_carry_add() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);

Expand Down Expand Up @@ -324,8 +319,7 @@ fn test_ec_add() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);
let (x1, y1) = SampleEcPoints[0].clone();
Expand Down Expand Up @@ -370,8 +364,7 @@ fn test_ec_double() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);

Expand Down Expand Up @@ -417,8 +410,7 @@ fn test_select() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);

Expand Down Expand Up @@ -463,8 +455,7 @@ fn test_select2() {
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&expr);

Expand Down
6 changes: 2 additions & 4 deletions circuits/ecc/src/field_extension/fp2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ mod tests {
let air = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&air);

Expand Down Expand Up @@ -211,8 +210,7 @@ mod tests {
let air = FieldExpr {
builder: builder.clone(),
check_carry_mod_to_zero: subair,
range_bus: range_checker.bus().index,
range_max_bits: range_checker.range_max_bits(),
range_bus: range_checker.bus(),
};
let width = BaseAir::<BabyBear>::width(&air);

Expand Down
4 changes: 4 additions & 0 deletions vm/src/intrinsics/ecc_v2/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub mod sw;

// Babybear
pub const FIELD_ELEMENT_BITS: usize = 30;
1 change: 1 addition & 0 deletions vm/src/intrinsics/ecc_v2/sw/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Short Weierstrass (sw) curve (with a = 0) operations
194 changes: 194 additions & 0 deletions vm/src/intrinsics/ecc_v2/sw/add_ne.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
use std::{cell::RefCell, rc::Rc};

use afs_primitives::{
bigint::check_carry_mod_to_zero::CheckCarryModToZeroSubAir,
var_range::bus::VariableRangeCheckerBus,
};
use afs_stark_backend::rap::BaseAirWithPublicValues;
use ax_ecc_primitives::field_expression::{ExprBuilder, FieldExpr};
use num_bigint_dig::BigUint;
use p3_air::{AirBuilder, BaseAir};
use p3_field::{Field, PrimeField32};

use super::super::FIELD_ELEMENT_BITS;
use crate::{
arch::{
AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
VmCoreAir, VmCoreChip,
},
system::program::Instruction,
utils::{biguint_to_limbs_vec, limbs_to_biguint},
};

#[derive(Clone)]
pub struct SwEcAddNeCoreAir {
pub expr: FieldExpr,
pub offset: usize,
}

impl SwEcAddNeCoreAir {
pub fn new(
modulus: BigUint, // The coordinate field.
num_limbs: usize,
limb_bits: usize,
max_limb_bits: usize,
range_bus: VariableRangeCheckerBus,
offset: usize,
) -> Self {
assert!(modulus.bits() <= num_limbs * limb_bits);
let subair = CheckCarryModToZeroSubAir::new(
modulus.clone(),
limb_bits,
range_bus.index,
range_bus.range_max_bits,
FIELD_ELEMENT_BITS,
);
let builder = ExprBuilder::new(
modulus,
limb_bits,
num_limbs,
range_bus.range_max_bits,
max_limb_bits,
);
let builder = Rc::new(RefCell::new(builder));

let x1 = ExprBuilder::new_input(builder.clone());
let y1 = ExprBuilder::new_input(builder.clone());
let x2 = ExprBuilder::new_input(builder.clone());
let y2 = ExprBuilder::new_input(builder.clone());
let mut lambda = (y2 - y1.clone()) / (x2.clone() - x1.clone());
let mut x3 = lambda.square() - x1.clone() - x2;
x3.save();
let mut y3 = lambda * (x1 - x3.clone()) - y1;
y3.save();

let builder = builder.borrow().clone();
let expr = FieldExpr {
builder,
check_carry_mod_to_zero: subair,
range_bus,
};
Self { expr, offset }
}
}

impl<F: Field> BaseAir<F> for SwEcAddNeCoreAir {
fn width(&self) -> usize {
BaseAir::<F>::width(&self.expr)
}
}

impl<F: Field> BaseAirWithPublicValues<F> for SwEcAddNeCoreAir {}

impl<AB: AirBuilder, I> VmCoreAir<AB, I> for SwEcAddNeCoreAir
where
I: VmAdapterInterface<AB::Expr>,
I::Reads: From<Vec<AB::Expr>>,
I::Writes: From<Vec<AB::Expr>>,
I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
fn eval(
&self,
_builder: &mut AB,
_local: &[AB::Var],
_from_pc: AB::Var,
) -> AdapterAirContext<AB::Expr, I> {
todo!()
}
}

pub struct SwEcAddNeCoreChip {
pub air: SwEcAddNeCoreAir,
}

impl SwEcAddNeCoreChip {
pub fn new(
modulus: BigUint,
num_limbs: usize,
limb_bits: usize,
max_limb_bits: usize,
range_bus: VariableRangeCheckerBus,
offset: usize,
) -> Self {
let air = SwEcAddNeCoreAir::new(
modulus,
num_limbs,
limb_bits,
max_limb_bits,
range_bus,
offset,
);
Self { air }
}
}

impl<F: PrimeField32, I> VmCoreChip<F, I> for SwEcAddNeCoreChip
where
I: VmAdapterInterface<F>,
I::Reads: Into<Vec<F>>,
I::Writes: From<Vec<F>>,
{
type Record = ();
type Air = SwEcAddNeCoreAir;

fn execute_instruction(
&self,
_instruction: &Instruction<F>,
_from_pc: u32,
reads: I::Reads,
) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
// Input: 2 EcPoint<Fp>, so total 4 field elements.
let field_element_limbs = self.air.expr.canonical_num_limbs();
let limb_bits = self.air.expr.canonical_limb_bits();
let data: Vec<F> = reads.into();
assert_eq!(data.len(), 4 * field_element_limbs);
let data_u32: Vec<u32> = data.iter().map(|x| x.as_canonical_u32()).collect();

let x1 = limbs_to_biguint(&data_u32[..field_element_limbs], limb_bits);
let y1 = limbs_to_biguint(
&data_u32[field_element_limbs..2 * field_element_limbs],
limb_bits,
);
let x2 = limbs_to_biguint(
&data_u32[2 * field_element_limbs..3 * field_element_limbs],
limb_bits,
);
let y2 = limbs_to_biguint(
&data_u32[3 * field_element_limbs..4 * field_element_limbs],
limb_bits,
);

let vars = self.air.expr.execute(vec![x1, y1, x2, y2], vec![]);
assert_eq!(vars.len(), 3); // lambda, x3, y3
let x3 = vars[1].clone();
let y3 = vars[2].clone();

let x3_limbs = biguint_to_limbs_vec(x3, limb_bits, field_element_limbs);
let y3_limbs = biguint_to_limbs_vec(y3, limb_bits, field_element_limbs);

Ok((
AdapterRuntimeContext {
to_pc: None,
writes: [x3_limbs, y3_limbs]
.concat()
.into_iter()
.map(|x| F::from_canonical_u32(x))
.collect::<Vec<_>>()
.into(),
},
(),
))
}

fn get_opcode_name(&self, _opcode: usize) -> String {
"SwEcAddNe".to_string()
}

fn generate_trace_row(&self, _row_slice: &mut [F], _record: Self::Record) {
todo!()
}

fn air(&self) -> &Self::Air {
&self.air
}
}
Loading