Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: non-deterministic array sort #4279

Merged
merged 7 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 1 addition & 59 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,18 +915,7 @@ namespace Circuit {
static ToLeRadix bincodeDeserialize(std::vector<uint8_t>);
};

struct PermutationSort {
std::vector<std::vector<Circuit::Expression>> inputs;
uint32_t tuple;
std::vector<Circuit::Witness> bits;
std::vector<uint32_t> sort_by;

friend bool operator==(const PermutationSort&, const PermutationSort&);
std::vector<uint8_t> bincodeSerialize() const;
static PermutationSort bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<ToLeRadix, PermutationSort> value;
std::variant<ToLeRadix> value;

friend bool operator==(const Directive&, const Directive&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4960,53 +4949,6 @@ Circuit::Directive::ToLeRadix serde::Deserializable<Circuit::Directive::ToLeRadi
return obj;
}

namespace Circuit {

inline bool operator==(const Directive::PermutationSort &lhs, const Directive::PermutationSort &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.tuple == rhs.tuple)) { return false; }
if (!(lhs.bits == rhs.bits)) { return false; }
if (!(lhs.sort_by == rhs.sort_by)) { return false; }
return true;
}

inline std::vector<uint8_t> Directive::PermutationSort::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<Directive::PermutationSort>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline Directive::PermutationSort Directive::PermutationSort::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<Directive::PermutationSort>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::Directive::PermutationSort>::serialize(const Circuit::Directive::PermutationSort &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.tuple)>::serialize(obj.tuple, serializer);
serde::Serializable<decltype(obj.bits)>::serialize(obj.bits, serializer);
serde::Serializable<decltype(obj.sort_by)>::serialize(obj.sort_by, serializer);
}

template <>
template <typename Deserializer>
Circuit::Directive::PermutationSort serde::Deserializable<Circuit::Directive::PermutationSort>::deserialize(Deserializer &deserializer) {
Circuit::Directive::PermutationSort obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.tuple = serde::Deserializable<decltype(obj.tuple)>::deserialize(deserializer);
obj.bits = serde::Deserializable<decltype(obj.bits)>::deserialize(deserializer);
obj.sort_by = serde::Deserializable<decltype(obj.sort_by)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const Expression &lhs, const Expression &rhs) {
Expand Down
15 changes: 1 addition & 14 deletions acvm-repo/acir/src/circuit/directives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,5 @@ use serde::{Deserialize, Serialize};
/// In the future, this can be replaced with asm non-determinism blocks
pub enum Directive {
//decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix in little endian form
ToLeRadix {
a: Expression,
b: Vec<Witness>,
radix: u32,
},

// Sort directive, using a sorting network
// This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted according to sort_by
PermutationSort {
inputs: Vec<Vec<Expression>>, // Array of tuples to sort
tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc..
bits: Vec<Witness>, // control bits of the network which permutes the inputs into its sorted version
sort_by: Vec<u32>, // specify primary index to sort by, then the secondary,... For instance, if tuple is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai.
},
ToLeRadix { a: Expression, b: Vec<Witness>, radix: u32 },
}
14 changes: 0 additions & 14 deletions acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

Opcode::BlackBoxFuncCall(g) => write!(f, "{g}"),
Opcode::Directive(Directive::ToLeRadix { a, b, radix: _ }) => {
write!(f, "DIR::TORADIX ")?;

Check warning on line 51 in acvm-repo/acir/src/circuit/opcodes.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (TORADIX)
write!(
f,
// TODO (Note): this assumes that the decomposed bits have contiguous witness indices
Expand All @@ -59,20 +59,6 @@
b.last().unwrap().witness_index(),
)
}
Opcode::Directive(Directive::PermutationSort { inputs: a, tuple, bits, sort_by }) => {
write!(f, "DIR::PERMUTATIONSORT ")?;
write!(
f,
"(permutation size: {} {}-tuples, sort_by: {:#?}, bits: [_{}..._{}]))",
a.len(),
tuple,
sort_by,
// (Note): the bits do not have contiguous index but there are too many for display
bits.first().unwrap().witness_index(),
bits.last().unwrap().witness_index(),
)
}

Opcode::Brillig(brillig) => {
write!(f, "BRILLIG: ")?;
writeln!(f, "inputs: {:?}", brillig.inputs)?;
Expand Down
5 changes: 0 additions & 5 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,6 @@ pub(super) fn transform_internal(
transformer.mark_solvable(*witness);
}
}
Directive::PermutationSort { bits, .. } => {
for witness in bits {
transformer.mark_solvable(*witness);
}
}
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
Expand Down
37 changes: 0 additions & 37 deletions acvm-repo/acvm/src/pwg/directives/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use std::cmp::Ordering;

use acir::{circuit::directives::Directive, native_types::WitnessMap, FieldElement};
use num_bigint::BigUint;

use crate::OpcodeResolutionError;

use super::{get_value, insert_value, ErrorLocation};

mod sorting;

/// Attempts to solve the [`Directive`] opcode `directive`.
/// If successful, `initial_witness` will be mutated to contain the new witness assignment.
///
Expand Down Expand Up @@ -48,38 +44,5 @@ pub(super) fn solve_directives(

Ok(())
}
Directive::PermutationSort { inputs: a, tuple, bits, sort_by } => {
let mut val_a = Vec::new();
let mut base = Vec::new();
for (i, element) in a.iter().enumerate() {
assert_eq!(element.len(), *tuple as usize);
let mut element_val = Vec::with_capacity(*tuple as usize + 1);
for e in element {
element_val.push(get_value(e, initial_witness)?);
}
let field_i = FieldElement::from(i as i128);
element_val.push(field_i);
base.push(field_i);
val_a.push(element_val);
}
val_a.sort_by(|a, b| {
for i in sort_by {
let int_a = BigUint::from_bytes_be(&a[*i as usize].to_be_bytes());
let int_b = BigUint::from_bytes_be(&b[*i as usize].to_be_bytes());
let cmp = int_a.cmp(&int_b);
if cmp != Ordering::Equal {
return cmp;
}
}
Ordering::Equal
});
let b = val_a.iter().map(|a| *a.last().unwrap()).collect();
let control = sorting::route(base, b);
for (w, value) in bits.iter().zip(control) {
let value = if value { FieldElement::one() } else { FieldElement::zero() };
insert_value(w, value, initial_witness)?;
}
Ok(())
}
}
}
Loading
Loading