Skip to content

Commit

Permalink
change to inner product (#1280)
Browse files Browse the repository at this point in the history
* change to inner product

* Update baby_bear.rs

* Update special_inner_product

* Update baby_bear.rs

* Update extensions/native/compiler/src/constraints/halo2/baby_bear.rs

---------

Co-authored-by: Jonathan Wang <[email protected]>
  • Loading branch information
MonkeyKing-1 and jonathanpwang authored Jan 26, 2025
1 parent 854ffd0 commit 352ca42
Showing 1 changed file with 64 additions and 10 deletions.
74 changes: 64 additions & 10 deletions extensions/native/compiler/src/constraints/halo2/baby_bear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,66 @@ impl BabyBearChip {
c
}

// This inner product function will be used exclusively for optimizing extension element multiplication.
fn special_inner_product(
&self,
ctx: &mut Context<Fr>,
a: &mut [AssignedBabyBear],
b: &mut [AssignedBabyBear],
s: usize,
) -> AssignedBabyBear {
assert!(a.len() == b.len());
assert!(a.len() == 4);
let mut max_bits = 0;
let lb = if s > 3 { s - 3 } else { 0 };
let ub = 4.min(s + 1);
let range = lb..ub;
let other_range = (s + 1 - ub)..(s + 1 - lb);
let len = if s < 3 { s + 1 } else { 7 - s };
for (i, (c, d)) in a[range.clone()]
.iter_mut()
.zip(b[other_range.clone()].iter_mut().rev())
.enumerate()
{
if c.max_bits + d.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i {
if c.max_bits >= d.max_bits {
*c = self.reduce(ctx, *c);
if c.max_bits + d.max_bits
> Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i
{
*d = self.reduce(ctx, *d);
}
} else {
*d = self.reduce(ctx, *d);
if c.max_bits + d.max_bits
> Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i
{
*c = self.reduce(ctx, *c);
}
}
}
if i == 0 {
max_bits = c.max_bits + d.max_bits;
} else {
max_bits = max_bits.max(c.max_bits + d.max_bits) + 1
}
}
let a_raw = a[range]
.iter()
.map(|a| QuantumCell::Existing(a.value))
.collect_vec();
let b_raw = b[other_range]
.iter()
.rev()
.map(|b| QuantumCell::Existing(b.value))
.collect_vec();
let prod = self.gate().inner_product(ctx, a_raw, b_raw);
AssignedBabyBear {
value: prod,
max_bits,
}
}

pub fn select(
&self,
ctx: &mut Context<Fr>,
Expand Down Expand Up @@ -488,18 +548,12 @@ impl BabyBearExt4Chip {
pub fn mul(
&self,
ctx: &mut Context<Fr>,
a: AssignedBabyBearExt4,
b: AssignedBabyBearExt4,
mut a: AssignedBabyBearExt4,
mut b: AssignedBabyBearExt4,
) -> AssignedBabyBearExt4 {
let mut coeffs = Vec::with_capacity(7);
for i in 0..4 {
for j in 0..4 {
if i + j < coeffs.len() {
coeffs[i + j] = self.base.mul_add(ctx, a.0[i], b.0[j], coeffs[i + j]);
} else {
coeffs.push(self.base.mul(ctx, a.0[i], b.0[j]));
}
}
for s in 0..7 {
coeffs.push(self.base.special_inner_product(ctx, &mut a.0, &mut b.0, s));
}
let w = self
.base
Expand Down

0 comments on commit 352ca42

Please sign in to comment.