Skip to content

Commit

Permalink
feat: added LUT provider stub and plain impl for MemOps
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis authored and dkales committed Sep 2, 2024
1 parent 9b3da88 commit 3d2377f
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 19 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ eyre = "0.6"
figment = { version = "0.10.19", features = ["toml", "env"] }
futures = "0.3.30"
hex-literal = "0.4.1"
intmap = "2.0.0"
itertools = "0.13.0"
mpc-core = { version = "0.4.0", path = "mpc-core" }
mpc-net = { version = "0.1.2", path = "mpc-net" }
Expand Down
1 change: 1 addition & 0 deletions co-noir/co-acvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rust-version.workspace = true
acir.workspace = true
acvm.workspace = true
eyre.workspace = true
intmap.workspace = true
mpc-core.workspace = true
noirc-abi.workspace = true
noirc-artifacts.workspace = true
Expand Down
22 changes: 18 additions & 4 deletions co-noir/co-acvm/src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use acir::{
native_types::{WitnessMap, WitnessStack},
AcirField, FieldElement,
};
use intmap::IntMap;
use mpc_core::{
protocols::{
plain::PlainDriver,
Expand All @@ -19,19 +20,18 @@ use std::{io, path::PathBuf};
pub(crate) const CO_EXPRESSION_WIDTH: ExpressionWidth = ExpressionWidth::Bounded { width: 4 };

mod assert_zero_solver;
mod memory_solver;
pub type PlainCoSolver<F> = CoSolver<PlainDriver<F>, F>;
pub type Rep3CoSolver<F, N> = CoSolver<Rep3Protocol<F, N>, F>;

type CoAcvmResult<T> = std::result::Result<T, CoAcvmError>;

#[derive(Debug, thiserror::Error)]
pub enum CoAcvmError {
#[error("Expected at most one mul term, but got {0}")]
TooManyMulTerm(usize),
#[error(transparent)]
IOError(#[from] io::Error),
#[error("unsolvable, too many unknown terms")]
TooManyUnknowns,
#[error(transparent)]
UnrecoverableError(#[from] eyre::Report),
}

pub struct CoSolver<T, F>
Expand All @@ -46,6 +46,8 @@ where
witness_map: Vec<WitnessMap<T::AcvmType>>,
// there will a more fields added as we add functionality
function_index: usize,
// the memory blocks
memory_access: IntMap<T::LUT>,
}

impl<T> CoSolver<T, FieldElement>
Expand Down Expand Up @@ -97,6 +99,7 @@ where
.collect::<Vec<_>>(),
witness_map,
function_index: 0,
memory_access: IntMap::new(),
})
}
}
Expand Down Expand Up @@ -152,6 +155,7 @@ where
T: NoirWitnessExtensionProtocol<F>,
F: AcirField,
{
#[inline(always)]
fn witness(&mut self) -> &mut WitnessMap<T::AcvmType> {
&mut self.witness_map[self.function_index]
}
Expand All @@ -167,6 +171,16 @@ where
for opcode in functions[self.function_index].opcodes.iter() {
match opcode {
Opcode::AssertZero(expr) => self.solve_assert_zero(expr)?,
Opcode::MemoryInit {
block_id,
init,
block_type: _, // apparently not used
} => self.solve_memory_init_block(*block_id, init)?,
Opcode::MemoryOp {
block_id,
op,
predicate,
} => self.solve_memory_op(*block_id, op, predicate.to_owned())?,
_ => todo!("non assert zero opcode detected, not supported yet"),
//Opcode::Call {
// id,
Expand Down
37 changes: 24 additions & 13 deletions co-noir/co-acvm/src/solver/assert_zero_solver.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use acir::{native_types::Expression, AcirField};
use mpc_core::traits::NoirWitnessExtensionProtocol;

use super::{CoAcvmError, CoAcvmResult, CoSolver};
use super::{CoAcvmResult, CoSolver};

impl<T, F> CoSolver<T, F>
where
Expand Down Expand Up @@ -44,20 +44,16 @@ where
let partly_solved = self.driver.acvm_mul_with_public(*c, rhs)?;
acc.linear_combinations.push((partly_solved, *lhs));
}
(None, None) => {
tracing::debug!(
"two unknowns in evaluate mul term. Not solvable for expr: {:?}",
expr
);
return Err(CoAcvmError::TooManyUnknowns);
}
(None, None) => Err(eyre::eyre!(
"two unknowns in evaluate mul term. Not solvable for expr: {:?}",
expr
))?,
};
tracing::trace!("after eval mul term: {acc:?}");
Ok(())
}
} else {
tracing::debug!("more than one mul term found!");
Err(CoAcvmError::TooManyMulTerm(expr.mul_terms.len()))
Err(eyre::eyre!("more than one mul term found!"))?
}
}

Expand All @@ -76,7 +72,7 @@ where
}
}

fn simplify_expression(
pub(crate) fn simplify_expression(
&mut self,
expr: &Expression<F>,
) -> CoAcvmResult<Expression<T::AcvmType>> {
Expand Down Expand Up @@ -119,8 +115,23 @@ where
self.witness().insert(w_l, witness);
Ok(())
} else {
tracing::debug!("too many unknowns. not solvable for expression: {:?}", expr);
Err(CoAcvmError::TooManyUnknowns)
Err(eyre::eyre!(
"too many unknowns. not solvable for expression: {:?}",
expr
))?
}
}

pub(crate) fn evaluate_expression(
&mut self,
expr: &Expression<F>,
) -> CoAcvmResult<T::AcvmType> {
Ok(self
.simplify_expression(expr)?
.to_const()
.cloned()
.ok_or(eyre::eyre!(
"cannot evaluate expression to const - has unknown"
))?)
}
}
97 changes: 97 additions & 0 deletions co-noir/co-acvm/src/solver/memory_solver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use acir::{
circuit::opcodes::{BlockId, MemOp},
native_types::{Expression, Witness},
AcirField,
};
use mpc_core::traits::NoirWitnessExtensionProtocol;

use super::{CoAcvmResult, CoSolver};

impl<T, F> CoSolver<T, F>
where
T: NoirWitnessExtensionProtocol<F>,
F: AcirField,
{
pub(super) fn solve_memory_init_block(
&mut self,
block_id: BlockId,
init: &[Witness],
) -> CoAcvmResult<()> {
// TODO: should we trust on the compiler here?
// this should not be possible so maybe we do not need the check?
if self.memory_access.get(block_id.0.into()).is_some() {
//there is already a block? This should no be possible
Err(eyre::eyre!(
"There is already a block for id {}",
block_id.0
))?;
}
// let get all witnesses
let witness_map = self.witness();
let init = init
.iter()
.map(|witness| witness_map.get(witness).cloned())
.collect::<Option<Vec<_>>>()
.ok_or(eyre::eyre!(
"tried to write not initialized witness to memory - this is a bug"
))?;
let lut = self.driver.init_lut(init);
self.memory_access.insert(block_id.0.into(), lut);
Ok(())
}

pub(super) fn solve_memory_op(
&mut self,
block_id: BlockId,
op: &MemOp<F>,
_predicate: Option<Expression<F>>,
) -> CoAcvmResult<()> {
let index = self.evaluate_expression(&op.index)?;
let value = self.simplify_expression(&op.value)?;
let witness = if value.is_degree_one_univariate() {
//we can get the witness
let (_coef, witness) = &value.linear_combinations[0];
let _q_c = value.q_c;
Ok(*witness)
//todo check if coef is one and q_c is zero!
} else {
Err(eyre::eyre!(
"value for mem op must be a degree one univariate polynomial"
))
}?;
//TODO CHECK PREDICATE - do we need to cmux here?
if op.operation.q_c.is_zero() {
// read the value from the LUT
let lut = self
.memory_access
.get(block_id.0.into())
.ok_or(eyre::eyre!(
"tried to access block {} but not present",
block_id.0
))?;
let value = self.driver.get_from_lut(&index, lut);
self.witness().insert(witness, value);
} else if op.operation.q_c.is_one() {
// write value to LUT
let value = self
.witness()
.get(&witness)
.cloned()
.ok_or(eyre::eyre!("Trying to write unknown witness in mem block"))?;
let lut = self
.memory_access
.get_mut(block_id.0.into())
.ok_or(eyre::eyre!(
"tried to access block {} but not present",
block_id.0
))?;
self.driver.write_to_lut(index, value, lut);
} else {
Err(eyre::eyre!(
"Got unknown operation {} for mem op - this is a bug",
op.operation.q_c
))?
}
Ok(())
}
}
25 changes: 24 additions & 1 deletion mpc-core/src/protocols/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
//!
//! This module contains the reference implementation without MPC. It will be used by the VM for computing on public values and can be used to test MPC circuits.
use std::collections::HashMap;

use crate::{
traits::{
CircomWitnessExtensionProtocol, EcMpcProtocol, FFTProvider, FieldShareVecTrait,
MSMProvider, NoirWitnessExtensionProtocol, PairingEcMpcProtocol, PrimeFieldMpcProtocol,
LookupTableProvider, MSMProvider, NoirWitnessExtensionProtocol, PairingEcMpcProtocol,
PrimeFieldMpcProtocol,
},
RngType,
};
Expand Down Expand Up @@ -634,3 +637,23 @@ impl<F: AcirField> NoirWitnessExtensionProtocol<F> for PlainDriver<F> {
Ok(-c / q_l)
}
}

impl<F: AcirField> LookupTableProvider<F> for PlainDriver<F> {
type LUT = HashMap<F, F>;

fn init_lut(&mut self, values: Vec<F>) -> Self::LUT {
let mut lut = HashMap::with_capacity(values.len());
for (idx, value) in values.into_iter().enumerate() {
lut.insert(F::from(idx), value);
}
lut
}

fn get_from_lut(&mut self, index: &F, lut: &Self::LUT) -> F {
lut[index]
}

fn write_to_lut(&mut self, index: F, value: F, lut: &mut Self::LUT) {
lut.insert(index, value);
}
}
23 changes: 23 additions & 0 deletions mpc-core/src/protocols/rep3/acvm_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::io;
use acir::AcirField;
use eyre::bail;

use crate::traits::LookupTableProvider;

use super::{id::PartyID, network::Rep3Network, Rep3Protocol};

//TODO maybe merge this with the VM type? Can be generic over F and then either use AcirField or PrimeField
Expand Down Expand Up @@ -395,3 +397,24 @@ impl<F: AcirField, N: Rep3Network> Rep3Protocol<F, N> {
Ok(a.a + a.b + c)
}
}

impl<F: AcirField, N: Rep3Network> LookupTableProvider<Rep3AcvmType<F>> for Rep3Protocol<F, N> {
type LUT = ();

fn init_lut(&mut self, _values: Vec<Rep3AcvmType<F>>) -> Self::LUT {
todo!()
}

fn get_from_lut(&mut self, _index: &Rep3AcvmType<F>, _lut: &Self::LUT) -> Rep3AcvmType<F> {
todo!()
}

fn write_to_lut(
&mut self,
_index: Rep3AcvmType<F>,
_value: Rep3AcvmType<F>,
_lut: &mut Self::LUT,
) {
todo!()
}
}
16 changes: 15 additions & 1 deletion mpc-core/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,26 @@ pub trait PrimeFieldMpcProtocol<F: PrimeField> {
) -> std::io::Result<Vec<F>>;
}

/// This is some place holder definition. This will change most likely
pub trait LookupTableProvider<F> {
/// A type that holds the data of the LUT.
type LUT;

/// Initializes a new LUT from the provided values. The index shall be the order
/// of the values in the `Vec`.
fn init_lut(&mut self, values: Vec<F>) -> Self::LUT;
/// Reads a value from the LUT.
fn get_from_lut(&mut self, index: &F, lut: &Self::LUT) -> F;
/// Writes a value to the LUT.
fn write_to_lut(&mut self, index: F, value: F, lut: &mut Self::LUT);
}

/// A trait representing the MPC operations required for extending the secret-shared Noir witness in MPC.
/// The operations are generic over public and private (i.e., secret-shared) inputs.
/// In contrast to the other traits, we have to be generic over [`AcirField`], as the ACVM wraps
/// the [`PrimeField`] of arkworks in another trait. This may be subject to change if we add functionality as
/// we have to implement a lot of the stuff twice.
pub trait NoirWitnessExtensionProtocol<F: AcirField> {
pub trait NoirWitnessExtensionProtocol<F: AcirField>: LookupTableProvider<Self::AcvmType> {
/// A type representing the values encountered during Circom compilation. It should at least contain public field elements and shared values.
type AcvmType: Clone + Default + fmt::Debug + fmt::Display + From<F>;

Expand Down
7 changes: 7 additions & 0 deletions test_vectors/noir/slice/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "slice"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
1 change: 1 addition & 0 deletions test_vectors/noir/slice/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vk = ["1", "2"]
Binary file added test_vectors/noir/slice/kat/slice.gz
Binary file not shown.
1 change: 1 addition & 0 deletions test_vectors/noir/slice/kat/slice.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"noir_version":"0.33.0+325dac54efb6f99201de9fdeb0a507d45189607d","hash":11907287274716899065,"abi":{"parameters":[{"name":"vk","type":{"kind":"array","length":2,"type":{"kind":"field"}},"visibility":"private"}],"return_type":null,"error_types":{}},"bytecode":"H4sIAAAAAAAA/9WWbQ7CIAyGYTAXb9PyMco/ryIZ3P8IOsfizDQxDhJ5EtKGH81LWwqcLQz3dc6+zLZjTzjbs+5dstUwGhOdiqjxCsoHsmBsGAkJLdlJkdaRDDkfvAOPRkdM1usEC90mFhxEsPeUiv8pB3AMLJmDWhpFBY1VGkoWPnSLDSUb0NiX1CizwHV6zYXrNz5vsIgtTIVTSY2tFqqF2zawSuNbFBb6p935ksBfn6iYHkxzcb/9b+24AResqhe9CQAA","debug_symbols":"pZHBDoMgEET/Zc9cUBDLrzRNg4oNCQEj2KQh/nuh1aa1Xoi3nd15k00mQCeb6XZVprcO+DmAtq3wypqoAuDXyg3CJOW8GD1wTCgCabo40XpG0CstgVM2XxAUuUCZC5BcgOYC1S5Q4g9w+gbQn7UqFyfDm2R2MJmRxVkXm+R6N7lYP8EV+QGiuItRiUbL1HO6TaZda4/SP4b3JXqf","file_map":{"57":{"source":"use dep::std;\n\nfn flatten_slice(slice: [Field]) -> Field {\n slice[0] + slice[1]\n}\n\nfn main(vk: [Field; 2]) {\n assert(vk[0] + flatten_slice(vk.as_slice()) == 4);\n}\n","path":"/home/fnieddu/repos/collaborative-circom/test_vectors/noir/slice/src/main.nr"}},"names":["main"]}
9 changes: 9 additions & 0 deletions test_vectors/noir/slice/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use dep::std;

fn flatten_slice(slice: [Field]) -> Field {
slice[0] + slice[1]
}

fn main(vk: [Field; 2]) {
assert(vk[0] + flatten_slice(vk.as_slice()) == 4);
}

0 comments on commit 3d2377f

Please sign in to comment.