Skip to content

Commit

Permalink
chore: Less cloning (#420)
Browse files Browse the repository at this point in the history
* refactor: Avoid unnecessary cloning of lang parameter

- Update multiframes creation to take a reference
- Add derive traits to HashArity enum in hash.rs for better usability

* refactor: Reduce cloning in LEM

- Enhance duplicate key error message in interpreter.rs to not need a clone
- Modify parameter type for `deconflict` function and its calls in mod.rs
- modification of helper functions in Symbol inspired by [C-CALLER-CONTROL](https://rust-lang.github.io/api-guidelines/flexibility.html#caller-decides-where-to-copy-and-place-data-c-caller-control)

* refactor: remove uneeded clones in circuit_frame

- take by reference where possible (and easy)
  • Loading branch information
huitseeker authored and porcuquine committed Jun 6, 2023
1 parent f92f9cd commit be75f31
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 59 deletions.
3 changes: 1 addition & 2 deletions benches/synthesis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ fn synthesize<M: measurement::Measurement>(
.unwrap();

let multiframe =
MultiFrame::from_frames(*reduction_count, &frames, &store, lang_rc.clone())[0]
.clone();
MultiFrame::from_frames(*reduction_count, &frames, &store, &lang_rc)[0].clone();

b.iter_batched(
|| (multiframe.clone()), // avoid cloning the frames in the benchmark
Expand Down
72 changes: 36 additions & 36 deletions src/circuit/circuit_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc
count: usize,
frames: &[Frame<T, W, C>],
store: &'a Store<F>,
lang: Arc<Lang<F, C>>,
lang: &Arc<Lang<F, C>>,
) -> Vec<Self> {
// `count` is the number of `Frames` to include per `MultiFrame`.
let total_frames = frames.len();
Expand Down Expand Up @@ -3991,7 +3991,7 @@ fn apply_continuation<F: LurkField, CS: ConstraintSystem<F>>(

let (alloc_q, alloc_r) = enforce_u64_div_mod(
&mut cs.namespace(|| "u64 div mod equation"),
op2_is_mod.clone(),
&op2_is_mod,
&arg1,
arg2,
)?;
Expand Down Expand Up @@ -4717,15 +4717,15 @@ fn to_unsigned_integer_helper<F: LurkField, CS: ConstraintSystem<F>>(
mut cs: CS,
g: &GlobalAllocations<F>,
field_elem: &AllocatedNum<F>,
field_bn: BigUint,
field_bn: &BigUint,
field_elem_bits: &[Boolean],
size: UnsignedInt,
) -> Result<AllocatedNum<F>, SynthesisError> {
let power_of_two_bn = BigUint::pow(&BigUint::from_u32(2).unwrap(), size.num_bits());

let (q_bn, r_bn) = field_bn.div_rem(&power_of_two_bn);
let q_num = allocate_unconstrained_bignum(&mut cs.namespace(|| "q"), q_bn)?;
let r_num = allocate_unconstrained_bignum(&mut cs.namespace(|| "r"), r_bn)?;
let q_num = allocate_unconstrained_bignum(&mut cs.namespace(|| "q"), &q_bn)?;
let r_num = allocate_unconstrained_bignum(&mut cs.namespace(|| "r"), &r_bn)?;
let pow2_size = match size {
UnsignedInt::U32 => &g.power2_32_num,
UnsignedInt::U64 => &g.power2_64_num,
Expand Down Expand Up @@ -4772,15 +4772,15 @@ fn to_unsigned_integers<F: LurkField, CS: ConstraintSystem<F>>(
&mut cs.namespace(|| "enforce u32"),
g,
maybe_unsigned,
field_bn.clone(),
&field_bn,
&field_elem_bits,
UnsignedInt::U32,
)?;
let r64_num = to_unsigned_integer_helper(
&mut cs.namespace(|| "enforce u64"),
g,
maybe_unsigned,
field_bn,
&field_bn,
&field_elem_bits,
UnsignedInt::U64,
)?;
Expand All @@ -4802,7 +4802,7 @@ fn to_u64<F: LurkField, CS: ConstraintSystem<F>>(
&mut cs.namespace(|| "enforce u64"),
g,
maybe_u64,
field_bn,
&field_bn,
&field_elem_bits,
UnsignedInt::U64,
)?;
Expand All @@ -4814,7 +4814,7 @@ fn to_u64<F: LurkField, CS: ConstraintSystem<F>>(
// arg1 = q * arg2 + r, such that 0 <= r < arg2.
fn enforce_u64_div_mod<F: LurkField, CS: ConstraintSystem<F>>(
mut cs: CS,
cond: Boolean,
cond: &Boolean,
arg1: &AllocatedPtr<F>,
arg2: &AllocatedPtr<F>,
) -> Result<(AllocatedNum<F>, AllocatedNum<F>), SynthesisError> {
Expand Down Expand Up @@ -4862,7 +4862,7 @@ fn enforce_u64_div_mod<F: LurkField, CS: ConstraintSystem<F>>(
let b_is_not_zero_and_cond = Boolean::and(
&mut cs.namespace(|| "b is not zero and cond"),
&b_is_zero.not(),
&cond,
cond,
)?;
enforce_implication(
&mut cs.namespace(|| "enforce u64 mod decomposition"),
Expand All @@ -4873,8 +4873,8 @@ fn enforce_u64_div_mod<F: LurkField, CS: ConstraintSystem<F>>(
enforce_less_than_bound(
&mut cs.namespace(|| "remainder in range b"),
cond,
alloc_r_num.clone(),
alloc_arg2_num,
&alloc_r_num,
&alloc_arg2_num,
)?;

Ok((alloc_q_num, alloc_r_num))
Expand All @@ -4886,11 +4886,11 @@ fn enforce_u64_div_mod<F: LurkField, CS: ConstraintSystem<F>>(
// `cond` is a Boolean condition that enforces the validation iff it is True.
fn enforce_less_than_bound<F: LurkField, CS: ConstraintSystem<F>>(
mut cs: CS,
cond: Boolean,
num: AllocatedNum<F>,
bound: AllocatedNum<F>,
cond: &Boolean,
num: &AllocatedNum<F>,
bound: &AllocatedNum<F>,
) -> Result<(), SynthesisError> {
let diff_bound_num = sub(&mut cs.namespace(|| "bound minus num"), &bound, &num)?;
let diff_bound_num = sub(&mut cs.namespace(|| "bound minus num"), bound, num)?;

let diff_bound_num_is_negative = allocate_is_negative(
&mut cs.namespace(|| "diff bound num is negative"),
Expand All @@ -4899,7 +4899,7 @@ fn enforce_less_than_bound<F: LurkField, CS: ConstraintSystem<F>>(

enforce_implication(
&mut cs.namespace(|| "enforce u64 range"),
&cond,
cond,
&diff_bound_num_is_negative.not(),
)
}
Expand All @@ -4911,7 +4911,7 @@ fn enforce_less_than_bound<F: LurkField, CS: ConstraintSystem<F>>(
// after dividing by 2ˆ64. Therefore we constrain this relation afterwards.
fn allocate_unconstrained_bignum<F: LurkField, CS: ConstraintSystem<F>>(
mut cs: CS,
bn: BigUint,
bn: &BigUint,
) -> Result<AllocatedNum<F>, SynthesisError> {
let bytes_le = bn.to_bytes_le();
let mut bytes_padded = [0u8; 32];
Expand Down Expand Up @@ -5297,7 +5297,7 @@ mod tests {
_p: Default::default(),
}],
store,
lang.clone(),
&lang,
);

let multiframe = &multiframes[0];
Expand Down Expand Up @@ -5413,7 +5413,7 @@ mod tests {
DEFAULT_REDUCTION_COUNT,
&[frame],
store,
lang.clone(),
&lang,
)[0]
.clone()
.synthesize(&mut cs)
Expand Down Expand Up @@ -5494,7 +5494,7 @@ mod tests {
DEFAULT_REDUCTION_COUNT,
&[frame],
store,
lang.clone(),
&lang,
)[0]
.clone()
.synthesize(&mut cs)
Expand Down Expand Up @@ -5576,7 +5576,7 @@ mod tests {
DEFAULT_REDUCTION_COUNT,
&[frame],
store,
lang.clone(),
&lang,
)[0]
.clone()
.synthesize(&mut cs)
Expand Down Expand Up @@ -5659,7 +5659,7 @@ mod tests {
DEFAULT_REDUCTION_COUNT,
&[frame],
store,
lang.clone(),
&lang,
)[0]
.clone()
.synthesize(&mut cs)
Expand Down Expand Up @@ -5757,9 +5757,9 @@ mod tests {

let res = enforce_less_than_bound(
&mut cs.namespace(|| "enforce less than bound"),
cond,
alloc_num,
alloc_most_positive,
&cond,
&alloc_num,
&alloc_most_positive,
);
assert!(res.is_ok());
assert!(cs.is_satisfied());
Expand All @@ -5776,9 +5776,9 @@ mod tests {

let res = enforce_less_than_bound(
&mut cs.namespace(|| "enforce less than bound"),
cond,
alloc_num,
alloc_bound,
&cond,
&alloc_num,
&alloc_bound,
);
assert!(res.is_ok());
assert!(cs.is_satisfied());
Expand All @@ -5795,9 +5795,9 @@ mod tests {

let res = enforce_less_than_bound(
&mut cs.namespace(|| "enforce less than bound"),
cond,
alloc_num,
alloc_bound,
&cond,
&alloc_num,
&alloc_bound,
);
assert!(res.is_ok());
assert!(!cs.is_satisfied());
Expand All @@ -5817,7 +5817,7 @@ mod tests {

let (q, r) = enforce_u64_div_mod(
&mut cs.namespace(|| "enforce u64 div mod"),
cond,
&cond,
&alloc_a,
&alloc_b,
)
Expand All @@ -5841,7 +5841,7 @@ mod tests {

let (q, r) = enforce_u64_div_mod(
&mut cs.namespace(|| "enforce u64 div mod"),
cond,
&cond,
&alloc_a,
&alloc_b,
)
Expand Down Expand Up @@ -5900,7 +5900,7 @@ mod tests {
&mut cs,
&g,
&a_plus_power2_32_num,
field_bn,
&field_bn,
&bits,
UnsignedInt::U32,
)
Expand Down Expand Up @@ -5931,7 +5931,7 @@ mod tests {
&mut cs,
&g,
&a_plus_power2_64_num,
field_bn,
&field_bn,
&bits,
UnsignedInt::U64,
)
Expand Down
8 changes: 4 additions & 4 deletions src/circuit/gadgets/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,12 +544,12 @@ pub(crate) fn alloc_is_zero<CS: ConstraintSystem<F>, F: PrimeField>(
cs: CS,
x: &AllocatedNum<F>,
) -> Result<Boolean, SynthesisError> {
alloc_num_is_zero(cs, Num::from(x.clone()))
alloc_num_is_zero(cs, &Num::from(x.clone()))
}

pub(crate) fn alloc_num_is_zero<CS: ConstraintSystem<F>, F: PrimeField>(
mut cs: CS,
num: Num<F>,
num: &Num<F>,
) -> Result<Boolean, SynthesisError> {
let num_value = num.get_value();
let x = num_value.unwrap_or(F::ZERO);
Expand Down Expand Up @@ -618,7 +618,7 @@ pub(crate) fn or_v_unchecked_for_optimization<CS: ConstraintSystem<F>, F: PrimeF

// If the number of true values is zero, then none of the values is true.
// Therefore, nor(v0, v1, ..., vn) is true.
let nor = alloc_num_is_zero(&mut cs.namespace(|| "nor"), count_true)?;
let nor = alloc_num_is_zero(&mut cs.namespace(|| "nor"), &count_true)?;

Ok(nor.not())
}
Expand All @@ -639,7 +639,7 @@ pub(crate) fn and_v<CS: ConstraintSystem<F>, F: PrimeField>(

// If the number of false values is zero, then all of the values are true.
// Therefore, and(v0, v1, ..., vn) is true.
let and = alloc_num_is_zero(&mut cs.namespace(|| "nor_of_nots"), count_false)?;
let and = alloc_num_is_zero(&mut cs.namespace(|| "nor_of_nots"), &count_false)?;

Ok(and)
}
Expand Down
1 change: 1 addition & 0 deletions src/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use generic_array::typenum::{U3, U4, U6, U8};
use neptune::{poseidon::PoseidonConstants, Poseidon};
use once_cell::sync::OnceCell;

#[derive(Debug, Clone, Copy)]
pub enum HashArity {
A3,
A4,
Expand Down
6 changes: 4 additions & 2 deletions src/lem/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ fn insert_into_ptrs<F: LurkField>(
key: String,
value: Ptr<F>,
) -> Result<()> {
if ptrs.insert(key.clone(), value).is_some() {
bail!("{} already defined", key);
let mut msg = "Key already defined: ".to_owned();
msg.push_str(&key);
if ptrs.insert(key, value).is_some() {
bail!("{msg}");
}
Ok(())
}
Expand Down
9 changes: 5 additions & 4 deletions src/lem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl LEMOP {
/// `LEM::new`, which is the API that should be used directly.
pub fn deconflict(
&self,
path: String,
path: &str,
dmap: &DashMap<String, String, ahash::RandomState>, // name -> path/name
) -> Result<Self> {
match self {
Expand Down Expand Up @@ -255,7 +255,8 @@ impl LEMOP {
let mut new_cases = vec![];
for (tag, case) in cases {
// each case needs it's own clone of `dmap`
let new_case = case.deconflict(format!("{}.{}", &path, &tag), &dmap.clone())?;
let new_case =
case.deconflict(&format!("{}.{}", &path, &tag), &dmap.clone())?;
new_cases.push((*tag, new_case));
}
Ok(LEMOP::MatchTag(
Expand All @@ -266,7 +267,7 @@ impl LEMOP {
LEMOP::Seq(ops) => {
let mut new_ops = vec![];
for op in ops {
new_ops.push(op.deconflict(path.clone(), dmap)?);
new_ops.push(op.deconflict(path, dmap)?);
}
Ok(LEMOP::Seq(new_ops))
}
Expand Down Expand Up @@ -338,7 +339,7 @@ impl LEM {
let dmap = DashMap::from_iter(input.map(|i| (i.to_string(), i.to_string())));
Ok(LEM {
input: input.map(|i| i.to_string()),
lem_op: lem_op.deconflict(String::new(), &dmap)?,
lem_op: lem_op.deconflict("", &dmap)?,
})
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/lem/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ pub enum Symbol {
}

impl Symbol {
pub fn sym(path: Vec<&str>) -> Symbol {
Symbol::Sym(path.iter().map(|x| x.to_string()).collect())
pub fn sym(path: &[String]) -> Symbol {
Symbol::Sym(path.into())
}

pub fn key(path: Vec<&str>) -> Symbol {
Symbol::Key(path.iter().map(|x| x.to_string()).collect())
pub fn key(path: &[String]) -> Symbol {
Symbol::Key(path.into())
}

#[inline]
Expand Down
5 changes: 2 additions & 3 deletions src/proof/groth16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ impl<C: Coprocessor<Scalar>> Groth16Prover<Bls12, C, Scalar> {
let frames = Evaluator::generate_frames(expr, env, store, limit, padding_predicate, &lang)?;
store.hydrate_scalar_cache();

let multiframes =
MultiFrame::from_frames(self.reduction_count(), &frames, store, lang.clone());
let multiframes = MultiFrame::from_frames(self.reduction_count(), &frames, store, &lang);
let mut proofs = Vec::with_capacity(multiframes.len());
let mut statements = Vec::with_capacity(multiframes.len());

Expand Down Expand Up @@ -408,7 +407,7 @@ mod tests {
s.hydrate_scalar_cache();

let multi_frames =
MultiFrame::from_frames(DEFAULT_REDUCTION_COUNT, &frames, s, lang_rc.clone());
MultiFrame::from_frames(DEFAULT_REDUCTION_COUNT, &frames, s, &lang_rc);

let cs = groth_prover.outer_synthesize(&multi_frames).unwrap();

Expand Down
Loading

0 comments on commit be75f31

Please sign in to comment.