Skip to content

Commit

Permalink
feat: eDSL support for secp256k1 add (#449)
Browse files Browse the repository at this point in the history
* feat: eDSL support for secp256k1 add

* fix test

---------

Co-authored-by: luffykai <[email protected]>
  • Loading branch information
jonathanpwang and luffykai authored Sep 19, 2024
1 parent c870c0a commit 4212486
Show file tree
Hide file tree
Showing 26 changed files with 406 additions and 290 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ rand = "0.8.5"
k256 = "0.13.3"
elliptic-curve = "0.13.8"
num-bigint-dig = "0.8.4"
num-bigint = "0.4.6"
num-integer = "0.1.46"
num-traits = "0.2.19"
hex-literal = "0.4.1"
metrics = "0.23.0"
cfg-if = "1.0.0"
Expand Down
8 changes: 4 additions & 4 deletions compiler/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ serde.workspace = true
serde_json.workspace = true
backtrace = { version = "0.3.71", features = ["serde"] }
strum_macros = "0.26.4"
num-bigint-dig = { workspace = true }
num-bigint = "0.4.6"
num-integer = "0.1.46"
num-traits = "0.2.19"
num-bigint-dig.workspace = true
num-bigint.workspace = true
num-integer.workspace = true
num-traits.workspace = true
# disable jemalloc to be compatible with afs-starkbackend
snark-verifier-sdk = { workspace = true, optional = true }
zkhash = { workspace = true }
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,18 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
},
_ => unimplemented!(),
},
DslIr::Secp256k1AddUnequal(dst, p, q) => {
self.push(
AsmInstruction::Secp256k1AddUnequal(dst.ptr_fp(), p.ptr_fp(), q.ptr_fp()),
debug_info,
);
}
DslIr::Secp256k1Double(dst, p) => {
self.push(
AsmInstruction::Secp256k1Double(dst.ptr_fp(), p.ptr_fp()),
debug_info,
);
}
DslIr::Error() => self.push(AsmInstruction::j(self.trap_label), debug_info),
DslIr::PrintF(dst) => {
self.push(AsmInstruction::PrintF(dst.fp()), debug_info);
Expand Down
14 changes: 14 additions & 0 deletions compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ pub enum AsmInstruction<F, EF> {
/// Same as `Keccak256`, but with fixed length input (hence length is an immediate value).
Keccak256FixLen(i32, i32, F),

/// (dst_ptr_ptr, p_ptr_ptr, q_ptr_ptr) are pointers to pointers to (dst, p, q).
/// Reads p,q from memory and writes p+q to dst.
/// Assumes p != +-q as secp256k1 points.
Secp256k1AddUnequal(i32, i32, i32),
/// (dst_ptr_ptr, p_ptr_ptr) are pointers to pointers to (dst, p).
/// Reads p,q from memory and writes 2*p to dst.
Secp256k1Double(i32, i32),

/// Print a variable.
PrintV(i32),

Expand Down Expand Up @@ -393,6 +401,12 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
AsmInstruction::Keccak256FixLen(dst, src, len) => {
write!(f, "keccak256 ({dst})fp, ({src})fp, {len}",)
}
AsmInstruction::Secp256k1AddUnequal(dst, p, q) => {
write!(f, "secp256k1_add_unequal ({})fp, ({})fp, ({})fp", dst, p, q)
}
AsmInstruction::Secp256k1Double(dst, p) => {
write!(f, "secp256k1_double ({})fp, ({})fp", dst, p)
}
AsmInstruction::PrintF(dst) => {
write!(f, "print_f ({})fp", dst)
}
Expand Down
17 changes: 17 additions & 0 deletions compiler/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,23 @@ fn convert_instruction<F: PrimeField32, EF: ExtensionField<F>>(
// AS::Immediate,
// )
}
AsmInstruction::Secp256k1AddUnequal(dst_ptr_ptr, p_ptr_ptr, q_ptr_ptr) => vec![inst_med(
SECP256K1_EC_ADD_NE,
i32_f(dst_ptr_ptr),
i32_f(p_ptr_ptr),
i32_f(q_ptr_ptr),
AS::Memory,
AS::Memory,
AS::Memory,
)],
AsmInstruction::Secp256k1Double(dst_ptr_ptr, p_ptr_ptr) => vec![inst(
SECP256K1_EC_DOUBLE,
i32_f(dst_ptr_ptr),
i32_f(p_ptr_ptr),
F::zero(),
AS::Memory,
AS::Memory,
)],
AsmInstruction::CycleTrackerStart(name) => {
if options.enable_cycle_tracker {
vec![dbg(CT_START, name)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ pub struct Builder<C: Config> {
pub(crate) witness_var_count: u32,
pub(crate) witness_felt_count: u32,
pub(crate) witness_ext_count: u32,
pub(crate) bigint_repr_size: u32,
pub bigint_repr_size: u32,
pub flags: BuilderFlags,
pub is_sub_builder: bool,
}
Expand Down
15 changes: 5 additions & 10 deletions compiler/src/ir/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<C: Config, V: MemVariable<C>> Array<C, V> {
panic!("Cannot slice a fixed array with a variable start or end");
}
}
Self::Dyn(_, len) => {
Self::Dyn(ptr, len) => {
if builder.flags.debug {
let valid = builder.lt(start, end);
builder.assert_var_eq(valid, C::N::one());
Expand All @@ -132,15 +132,10 @@ impl<C: Config, V: MemVariable<C>> Array<C, V> {
builder.assert_var_eq(valid, C::N::one());
}

let slice_len = builder.eval_expr(end - start);
let slice = builder.dyn_array(slice_len);
builder.range(0, slice_len).for_each(|i, builder| {
let idx = builder.eval_expr(start + i);
let value = builder.get(self, idx);
builder.set(&slice, i, value);
});

slice
let slice_len = builder.eval(end - start);
let address = builder.eval(ptr.address + start);
let ptr = Ptr { address };
Array::Dyn(ptr, Usize::Var(slice_len))
}
}
}
Expand Down
148 changes: 75 additions & 73 deletions compiler/src/ir/elliptic_curve.rs
Original file line number Diff line number Diff line change
@@ -1,102 +1,104 @@
use num_bigint_dig::BigUint;
use num_traits::{FromPrimitive, Zero};
use p3_field::{AbstractField, PrimeField64};

use super::{Array, DslIr};
use crate::ir::{modular_arithmetic::BigUintVar, Builder, Config, Var};

impl<C: Config> Builder<C>
where
C::N: PrimeField64,
{
pub fn ec_add(
/// Computes `p + q`, handling cases where `p` or `q` are identity.
///
/// A point is stored as a tuple of affine coordinates, contiguously in memory as 64 bytes.
/// Identity point is represented as (0, 0).
pub fn secp256k1_add(
&mut self,
point_1: &(BigUintVar<C>, BigUintVar<C>),
point_2: &(BigUintVar<C>, BigUintVar<C>),
) -> (BigUintVar<C>, BigUintVar<C>) {
let (x1, y1) = point_1;
let (x2, y2) = point_2;
point_1: Array<C, Var<C::N>>,
point_2: Array<C, Var<C::N>>,
) -> Array<C, Var<C::N>> {
// number of limbs to represent one coordinate
let num_limbs = ((256 + self.bigint_repr_size - 1) / self.bigint_repr_size) as usize;
// Assuming point_1.len() = 2 * num_limbs
let x1 = point_1.slice(self, 0, num_limbs);
let y1 = point_1.slice(self, num_limbs, 2 * num_limbs);

let x1_zero = self.secp256k1_coord_is_zero(x1);
let y1_zero = self.secp256k1_coord_is_zero(y1);
let x2_zero = self.secp256k1_coord_is_zero(x2);
let y2_zero = self.secp256k1_coord_is_zero(y2);
let xs_equal = self.secp256k1_coord_eq(x1, x2);
let ys_equal = self.secp256k1_coord_eq(y1, y2);
let y_sum = self.secp256k1_coord_add(y1, y2);
let ys_opposite = self.secp256k1_coord_is_zero(&y_sum);
let result_x = self.uninit();
let result_y = self.uninit();
let res = self.uninit();
let x1_zero = self.secp256k1_coord_is_zero(&x1);
let y1_zero = self.secp256k1_coord_is_zero(&y1);

// if point_1 is identity
self.if_eq(x1_zero * y1_zero, C::N::one()).then_or_else(
|builder| {
builder.assign(&result_x, x2.clone());
builder.assign(&result_y, y2.clone());
builder.assign(&res, point_2.clone());
},
|builder| {
let x2 = point_2.slice(builder, 0, num_limbs);
let y2 = point_2.slice(builder, num_limbs, 2 * num_limbs);
let x2_zero = builder.secp256k1_coord_is_zero(&x2);
let y2_zero = builder.secp256k1_coord_is_zero(&y2);
// else if point_2 is identity
builder.if_eq(x2_zero * y2_zero, C::N::one()).then_or_else(
|builder| {
builder.assign(&result_x, x1.clone());
builder.assign(&result_y, y1.clone());
builder.assign(&res, point_1.clone());
},
|builder| {
// else if point_1 = -point_2
builder
.if_eq(xs_equal * ys_opposite, C::N::one())
.then_or_else(
|builder| {
let zero = builder.eval_biguint(BigUint::zero());
builder.assign(&result_x, zero.clone());
builder.assign(&result_y, zero);
},
|builder| {
let lambda = builder.uninit();
// else if point_1 = point_2
builder
.if_eq(xs_equal * ys_equal, C::N::one())
.then_or_else(
|builder| {
let two = builder
.eval_biguint(BigUint::from_u8(2).unwrap());
let three = builder
.eval_biguint(BigUint::from_u8(3).unwrap());
let two_y = builder.secp256k1_coord_mul(&two, y1);
let x_squared = builder.secp256k1_coord_mul(x1, x1);
let three_x_squared =
builder.secp256k1_coord_mul(&three, &x_squared);
let lambda_value = builder
.secp256k1_coord_div(&three_x_squared, &two_y);
builder.assign(&lambda, lambda_value);
},
|builder| {
// else (general case)
let dy = builder.secp256k1_coord_sub(y2, y1);
let dx = builder.secp256k1_coord_sub(x2, x1);
let lambda_value =
builder.secp256k1_coord_div(&dy, &dx);
builder.assign(&lambda, lambda_value);
},
);
let lambda_squared =
builder.secp256k1_coord_mul(&lambda, &lambda);
let x_sum = builder.secp256k1_coord_add(x1, x2);
let x3 = builder.secp256k1_coord_sub(&lambda_squared, &x_sum);
let x1_minus_x3 = builder.secp256k1_coord_sub(x1, &x3);
let lambda_times_x1_minus_x3 =
builder.secp256k1_coord_mul(&lambda, &x1_minus_x3);
let y3 =
builder.secp256k1_coord_sub(&lambda_times_x1_minus_x3, y1);
builder.assign(&result_x, x3);
builder.assign(&result_y, y3);
},
);
let xs_equal = builder.secp256k1_coord_eq(&x1, &x2);
builder.if_eq(xs_equal, C::N::one()).then_or_else(
|builder| {
// if x1 == x2
let ys_equal = builder.secp256k1_coord_eq(&y1, &y2);
builder.if_eq(ys_equal, C::N::one()).then_or_else(
|builder| {
// if y1 == y2 => point_1 == point_2, do double
let res_double = builder.secp256k1_double(point_1.clone());
builder.assign(&res, res_double);
},
|builder| {
// else y1 != y2 => x1 = x2, y1 = - y2 so point_1 + point_2 = identity
let identity = builder.array(2 * num_limbs);
for i in 0..2 * num_limbs {
builder.set(&identity, i, C::N::zero());
}
builder.assign(&res, identity)
},
)
},
|builder| {
// if x1 != x2
let res_ne =
builder.secp256k1_add_unequal(point_1.clone(), point_2.clone());
builder.assign(&res, res_ne);
},
)
},
)
},
);
res
}

/// Assumes that `point_1 != +- point_2` which is equivalent to `point_1.x != point_2.x`.
/// Does not handle identity points.
///
/// A point is stored as a tuple of affine coordinates, contiguously in memory as 64 bytes.
pub fn secp256k1_add_unequal(
&mut self,
point_1: Array<C, Var<C::N>>,
point_2: Array<C, Var<C::N>>,
) -> Array<C, Var<C::N>> {
// TODO: enforce this is constant length
let dst = self.array(point_1.len());
self.push(DslIr::Secp256k1AddUnequal(dst.clone(), point_1, point_2));
dst
}

(result_x, result_y)
/// Does not handle identity points.
///
/// A point is stored as a tuple of affine coordinates, contiguously in memory as 64 bytes.
pub fn secp256k1_double(&mut self, point: Array<C, Var<C::N>>) -> Array<C, Var<C::N>> {
let dst = self.array(point.len());
self.push(DslIr::Secp256k1Double(dst.clone(), point));
dst
}

/// Assert (x, y) is on the curve.
Expand Down
18 changes: 18 additions & 0 deletions compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ pub enum DslIr<C: Config> {
/// **little-endian**. The `output` is exactly 16 limbs (32 bytes).
Keccak256(Array<C, Var<C::N>>, Array<C, Var<C::N>>),

/// ```ignore
/// Secp256k1AddUnequal(dst, p, q)
/// ```
/// Reads `p,q` from heap and writes `dst = p + q` to heap. A point is represented on the heap
/// as two affine coordinates concatenated together into a byte array.
/// Assumes that `p.x != q.x` which is equivalent to `p != +-q`.
Secp256k1AddUnequal(
Array<C, Var<C::N>>,
Array<C, Var<C::N>>,
Array<C, Var<C::N>>,
),
/// ```ignore
/// Secp256k1Double(dst, p)
/// ```
/// Reads `p` from heap and writes `dst = p + p` to heap. A point is represented on the heap
/// as two affine coordinates concatenated together into a byte array.
Secp256k1Double(Array<C, Var<C::N>>, Array<C, Var<C::N>>),

// Miscellaneous instructions.
/// Prints a variable.
PrintV(Var<C::N>),
Expand Down
1 change: 1 addition & 0 deletions compiler/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub fn execute_program(program: Program<BabyBear>, input_stream: Vec<Vec<BabyBea
num_public_values: 4,
max_segment_len: (1 << 25) - 100,
modular_multiplication_enabled: true,
secp256k1_enabled: true,
bigint_limb_size: 8,
..Default::default()
},
Expand Down
Loading

0 comments on commit 4212486

Please sign in to comment.