diff --git a/Cargo.lock b/Cargo.lock index f5a360067ce..145adba6a7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,7 +31,9 @@ dependencies = [ "ark-bls12-381", "ark-bn254 0.5.0", "ark-ff 0.5.0", + "ark-std 0.5.0", "cfg-if", + "criterion", "hex", "num-bigint", "proptest", diff --git a/acvm-repo/acir_field/Cargo.toml b/acvm-repo/acir_field/Cargo.toml index bc4be250cfd..78602eefe3c 100644 --- a/acvm-repo/acir_field/Cargo.toml +++ b/acvm-repo/acir_field/Cargo.toml @@ -23,12 +23,18 @@ serde.workspace = true ark-bn254.workspace = true ark-bls12-381 = { workspace = true, optional = true } ark-ff.workspace = true +ark-std.workspace = true cfg-if.workspace = true [dev-dependencies] proptest.workspace = true +criterion.workspace = true [features] bn254 = [] bls12_381 = ["dep:ark-bls12-381"] + +[[bench]] +name = "field_element" +harness = false diff --git a/acvm-repo/acir_field/benches/field_element.rs b/acvm-repo/acir_field/benches/field_element.rs new file mode 100644 index 00000000000..6666e6dd9a1 --- /dev/null +++ b/acvm-repo/acir_field/benches/field_element.rs @@ -0,0 +1,11 @@ +use acir_field::{AcirField, FieldElement}; +use criterion::{criterion_group, criterion_main, Criterion}; +use std::hint::black_box; + +fn criterion_benchmark(c: &mut Criterion) { + let field_element = FieldElement::from(123456789_u128); + c.bench_function("FieldElement::num_bits", |b| b.iter(|| black_box(field_element).num_bits())); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/acvm-repo/acir_field/src/field_element.rs b/acvm-repo/acir_field/src/field_element.rs index 0249b410aa7..e53fb760476 100644 --- a/acvm-repo/acir_field/src/field_element.rs +++ b/acvm-repo/acir_field/src/field_element.rs @@ -1,5 +1,6 @@ use ark_ff::PrimeField; use ark_ff::Zero; +use ark_std::io::Write; use num_bigint::BigUint; use serde::{Deserialize, Serialize}; use std::borrow::Cow; @@ -195,26 +196,9 @@ impl AcirField for FieldElement { /// This is the number of bits required to represent this specific field element fn num_bits(&self) -> u32 { - let bytes = self.to_be_bytes(); - - // Iterate through the byte decomposition and pop off all leading zeroes - let mut iter = bytes.iter().skip_while(|x| (**x) == 0); - - // The first non-zero byte in the decomposition may have some leading zero-bits. - let Some(head_byte) = iter.next() else { - // If we don't have a non-zero byte then the field element is zero, - // which we consider to require a single bit to represent. - return 1; - }; - let num_bits_for_head_byte = head_byte.ilog2(); - - // Each remaining byte in the byte decomposition requires 8 bits. - // - // Note: count will panic if it goes over usize::MAX. - // This may not be suitable for devices whose usize < u16 - let tail_length = iter.count() as u32; - - 8 * tail_length + num_bits_for_head_byte + 1 + let mut bit_counter = BitCounter::default(); + self.0.serialize_uncompressed(&mut bit_counter).unwrap(); + bit_counter.bits() } fn to_u128(self) -> u128 { @@ -354,6 +338,52 @@ impl SubAssign for FieldElement { } } +#[derive(Default, Debug)] +struct BitCounter { + /// Total number of non-zero bytes we found. + count: usize, + /// Total bytes we found. + total: usize, + /// The last non-zero byte we found. + head_byte: u8, +} + +impl BitCounter { + fn bits(&self) -> u32 { + // If we don't have a non-zero byte then the field element is zero, + // which we consider to require a single bit to represent. + if self.count == 0 { + return 1; + } + + let num_bits_for_head_byte = self.head_byte.ilog2(); + + // Each remaining byte in the byte decomposition requires 8 bits. + // + // Note: count will panic if it goes over usize::MAX. + // This may not be suitable for devices whose usize < u16 + let tail_length = (self.count - 1) as u32; + 8 * tail_length + num_bits_for_head_byte + 1 + } +} + +impl Write for BitCounter { + fn write(&mut self, buf: &[u8]) -> ark_std::io::Result { + for byte in buf { + self.total += 1; + if *byte != 0 { + self.count = self.total; + self.head_byte = *byte; + } + } + Ok(buf.len()) + } + + fn flush(&mut self) -> ark_std::io::Result<()> { + Ok(()) + } +} + #[cfg(test)] mod tests { use super::{AcirField, FieldElement};