diff --git a/.gitignore b/.gitignore index caaf8d2..a233986 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,4 @@ packages/*/build package-lock.json yarn.lock +index.node diff --git a/packages/compiler/src/circom.rs b/packages/compiler/src/circom.rs index b85ee7a..24c934e 100644 --- a/packages/compiler/src/circom.rs +++ b/packages/compiler/src/circom.rs @@ -31,11 +31,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) rev_graph.get_mut(k).unwrap().insert(i, chars.clone()); if i == 0 { - if let Some(index) = chars.iter().position(|&x| x == 94) { - init_going_state = Some(*k); - rev_graph.get_mut(&k).unwrap().get_mut(&i).unwrap()[index] = 255; - } - for j in rev_graph.get(&k).unwrap().get(&i).unwrap() { if *j == 255 { continue; @@ -83,9 +78,9 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) let accept_nodes_array: Vec = accept_nodes.into_iter().collect(); - if accept_nodes_array.len() != 1 { - panic!("The size of accept nodes must be one"); - } + // if accept_nodes_array.len() != 1 { + // panic!("The size of accept nodes must be one"); + // } let mut eq_i = 0; let mut lt_i = 0; @@ -96,25 +91,17 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) let mut eq_checks = vec![None; 256]; let mut multi_or_checks1 = BTreeMap::::new(); let mut multi_or_checks2 = BTreeMap::::new(); - let mut zero_starting_states = vec![]; - let mut zero_starting_and_idxes = BTreeMap::>::new(); let mut lines = vec![]; - // let mut zero_starting_lines = vec![]; - lines.push("\tfor (var i = 0; i < num_bytes; i++) {".to_string()); - lines.push(format!("\t\tstate_changed[i] = MultiOR({});", n - 1)); - lines.push(format!("\t\tstates[i][0] <== 1;")); + lines.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string()); + lines.push("\t\tstates[i+1][0] <== 0;".to_string()); + for i in 1..n { let mut outputs = vec![]; - zero_starting_and_idxes.insert(i, vec![]); - // let mut state_change_lines = vec![]; for (prev_i, k) in rev_graph.get(&(i as usize)).unwrap().iter() { let prev_i_num = *prev_i; - if prev_i_num == 0 { - zero_starting_states.push(i); - } let mut k = k.clone(); k.sort(); @@ -196,6 +183,7 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) } } } + lines.push(format!("\t\tand[{}][i] = AND();", and_i)); lines.push(format!( "\t\tand[{}][i].a <== states[i][{}];", @@ -207,9 +195,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) "\t\tand[{}][i].b <== {}[{}][i].out;", and_i, eq_outputs[0].0, eq_outputs[0].1 )); - if prev_i_num == 0 { - zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i); - } } else if eq_outputs.len() > 1 { let eq_outputs_key = serde_json::to_string(&eq_outputs).unwrap(); @@ -231,9 +216,6 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) "\t\tand[{}][i].b <== multi_or[{}][i].out;", and_i, multi_or_i )); - if prev_i_num == 0 { - zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i); - } multi_or_checks1.insert(eq_outputs_key, multi_or_i); multi_or_i += 1; } else { @@ -242,29 +224,19 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) "\t\tand[{}][i].b <== multi_or[{}][i].out;", and_i, multi_or_i )); - if prev_i_num == 0 { - zero_starting_and_idxes.get_mut(&i).unwrap().push(and_i); - } } } } - if prev_i_num != 0 { - outputs.push(and_i); - } + + outputs.push(and_i); and_i += 1; } + if outputs.len() == 1 { - if zero_starting_states.contains(&i) { - lines.push(format!( - "\t\tstates_tmp[i+1][{}] <== and[{}][i].out;", - i, outputs[0] - )); - } else { - lines.push(format!( - "\t\tstates[i+1][{}] <== and[{}][i].out;", - i, outputs[0] - )); - } + lines.push(format!( + "\t\tstates[i+1][{}] <== and[{}][i].out;", + i, outputs[0] + )); } else if outputs.len() > 1 { let outputs_key = serde_json::to_string(&outputs).unwrap(); @@ -281,87 +253,34 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) multi_or_i, output_i, and_i )); } - if zero_starting_states.contains(&i) { - lines.push(format!( - "\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } else { - lines.push(format!( - "\t\tstates[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } + + lines.push(format!( + "\t\tstates[i+1][{}] <== multi_or[{}][i].out;", + i, multi_or_i + )); multi_or_checks2.insert(outputs_key, multi_or_i); multi_or_i += 1; } else { - if let Some(multi_or_i) = multi_or_checks2.get(&outputs_key) { - if zero_starting_states.contains(&i) { - lines.push(format!( - "\t\tstates_tmp[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } else { - lines.push(format!( - "\t\tstates[i+1][{}] <== multi_or[{}][i].out;", - i, multi_or_i - )); - } + if let Some(multi_or_i_) = multi_or_checks2.get(&outputs_key) { + lines.push(format!( + "\t\tstates[i+1][{}] <== multi_or[{}][i].out;", + i, multi_or_i_ + )); } } - } else { - if zero_starting_states.contains(&i) { - lines.push(format!("\t\tstates_tmp[i+1][{}] <== 0;", i)); - } else { - lines.push(format!("\t\tstates[i+1][{}] <== 0;", i)); - } } - - // if zero_starting_states.contains(&i) { - // zero_starting_lines.append(&mut state_change_lines); - // } else { - // lines.append(&mut state_change_lines); - // } } - // let not_zero_starting_states = (1..n) - // .filter(|i| !zero_starting_states.contains(&i)) - // .collect_vec(); - lines.push(format!( - "\t\tfrom_zero_enabled[i] <== MultiNOR({})([{}]);", - n - 1, - (1..n) - .map(|i| if zero_starting_states.contains(&i) { - format!("states_tmp[i+1][{}]", i) - } else { - format!("states[i+1][{}]", i) - }) - .collect::>() - .join(", ") - )); - for (i, vec) in zero_starting_and_idxes.iter() { - if vec.len() == 0 { + + let mut acc_transitions_update = "\t\tacc_transitions[i+1] <== acc_transitions[i]".to_string(); + for i in 0..n { + if i == 0 { continue; } - lines.push(format!( - "\t\tstates[i+1][{}] <== MultiOR({})([states_tmp[i+1][{}], {}]);", - i, - vec.len() + 1, - i, - vec.iter() - .map(|and_i| format!("from_zero_enabled[i] * and[{}][i].out", and_i)) - .collect::>() - .join(", ") - )); - } - for i in 1..n { - lines.push(format!( - "\t\tstate_changed[i].in[{}] <== states[i+1][{}];", - i - 1, - i - )); - } - // lines.push("\t\tstates[i+1][0] <== 1 - state_changed[i].out;".to_string()); + acc_transitions_update.push_str(&format!(" + states[i+1][{}]", i)); + } + acc_transitions_update.push_str(";"); + lines.push(acc_transitions_update); let mut declarations = vec![]; declarations.push("pragma circom 2.1.5;\n".to_string()); @@ -374,40 +293,37 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) declarations.push(format!("template {}(msg_bytes) {{", template_name)); declarations.push("\tsignal input msg[msg_bytes];".to_string()); declarations.push("\tsignal output out;\n".to_string()); - declarations.push("\tvar num_bytes = msg_bytes+1;".to_string()); - declarations.push("\tsignal in[num_bytes];".to_string()); - declarations.push("\tin[0]<==255;".to_string()); + declarations.push("\tvar num_state_trace = msg_bytes+1;".to_string()); + declarations.push("\tsignal in[msg_bytes];".to_string()); declarations.push("\tfor (var i = 0; i < msg_bytes; i++) {".to_string()); - declarations.push("\t\tin[i+1] <== msg[i];".to_string()); + declarations.push("\t\tin[i] <== msg[i];".to_string()); declarations.push("\t}\n".to_string()); if eq_i > 0 { - declarations.push(format!("\tcomponent eq[{}][num_bytes];", eq_i)); + declarations.push(format!("\tcomponent eq[{}][msg_bytes];", eq_i)); } if lt_i > 0 { - declarations.push(format!("\tcomponent lt[{}][num_bytes];", lt_i)); + declarations.push(format!("\tcomponent lt[{}][msg_bytes];", lt_i)); } if and_i > 0 { - declarations.push(format!("\tcomponent and[{}][num_bytes];", and_i)); + declarations.push(format!("\tcomponent and[{}][msg_bytes];", and_i)); } if multi_or_i > 0 { - declarations.push(format!("\tcomponent multi_or[{}][num_bytes];", multi_or_i)); + declarations.push(format!("\tcomponent multi_or[{}][msg_bytes];", multi_or_i)); } - declarations.push(format!("\tsignal states[num_bytes+1][{}];", n)); - declarations.push(format!("\tsignal states_tmp[num_bytes+1][{}];", n)); - declarations.push(format!("\tsignal from_zero_enabled[num_bytes+1];")); - declarations.push(format!("\tfrom_zero_enabled[num_bytes] <== 0;")); - declarations.push("\tcomponent state_changed[num_bytes];\n".to_string()); + declarations.push(format!("\tsignal states[num_state_trace][{}];", n)); + declarations.push(format!("\tsignal acc_transitions[num_state_trace];\n")); let mut init_code = vec![]; - // init_code.push("\tstates[0][0] <== 1;".to_string()); + init_code.push("\tstates[0][0] <== 1;".to_string()); init_code.push(format!("\tfor (var i = 1; i < {}; i++) {{", n)); init_code.push("\t\tstates[0][i] <== 0;".to_string()); - init_code.push("\t}\n".to_string()); + init_code.push("\t}".to_string()); + init_code.push("\tacc_transitions[0] <== 0;\n".to_string()); let mut final_code = declarations .into_iter() @@ -416,18 +332,37 @@ fn gen_circom_allstr(dfa_graph: &DFAGraph, template_name: &str, regex_str: &str) .collect::>(); final_code.push("\t}".to_string()); - let accept_node = accept_nodes_array[0]; let mut accept_lines = vec![]; accept_lines.push("".to_string()); - accept_lines.push("\tcomponent final_state_result = MultiOR(num_bytes+1);".to_string()); - accept_lines.push("\tfor (var i = 0; i <= num_bytes; i++) {".to_string()); - accept_lines.push(format!( - "\t\tfinal_state_result.in[i] <== states[i][{}];", - accept_node - )); + accept_lines.push("\tcomponent final_state_result = MultiOR(msg_bytes+1);".to_string()); + accept_lines.push("\tfor (var i = 0; i <= msg_bytes; i++) {".to_string()); + if accept_nodes_array.len() == 1 { + accept_lines.push(format!( + "\t\tfinal_state_result.in[i] <== states[i][{}];", + accept_nodes_array[0] + )); + } else { + let mut accept_outputs = vec![]; + let mut accept_outputs_str = String::new(); + let mut accept_outputs_str = format!("MultiOR({})([", accept_nodes_array.len()); + for accept_node in &accept_nodes_array { + accept_outputs.push(format!("states[i][{}]", accept_node)); + } + accept_outputs_str.push_str(&accept_outputs.join(", ")); + accept_outputs_str.push_str("])"); + accept_lines.push(format!( + "\t\tfinal_state_result.in[i] <== {};", + accept_outputs_str + )); + } accept_lines.push("\t}".to_string()); - accept_lines.push("\tout <== final_state_result.out;".to_string()); + accept_lines.push( + "\tsignal is_acc_valid <== IsEqual()([acc_transitions[num_state_trace-1], msg_bytes]);" + .to_string(), + ); + accept_lines.push("\tout <== final_state_result.out * is_acc_valid;".to_string()); + accept_lines.push("}".to_string()); final_code.extend(accept_lines); diff --git a/packages/compiler/src/regex.rs b/packages/compiler/src/regex.rs index 02ae9d3..c0c6a93 100644 --- a/packages/compiler/src/regex.rs +++ b/packages/compiler/src/regex.rs @@ -56,10 +56,16 @@ fn parse_dfa_output(output: &str) -> DFAGraphInfo { eoi_pointing_states.insert(eoi_target); state.typ = String::from("accept"); state.edges.remove("EOI"); + // Set the dst of all edges pointing to eoi_target to this state + for edge in &mut state.edges { + if *edge.1 == eoi_target { + *edge.1 = state.source; + } + } } } - let start_state_re = Regex::new(r"START-GROUP\(anchored\)[\s*\w*\=>]*Text => (\d+)").unwrap(); + let start_state_re = Regex::new(r"START-GROUP\(unanchored\)[\s*\w*\=>]*Text => (\d+)").unwrap(); let start_state = start_state_re.captures_iter(output).next().unwrap()[1] .parse::() .unwrap(); @@ -251,9 +257,9 @@ fn add_dfa(net_dfa: &DFAGraph, graph: &DFAGraph) -> DFAGraph { pub fn regex_and_dfa(decomposed_regex: &DecomposedRegexConfig) -> RegexAndDFA { let mut config = DFA::config().minimize(true); - config = config.start_kind(StartKind::Anchored); + // config = config.start_kind(StartKind::Unanchored); config = config.byte_classes(false); - config = config.accelerate(true); + // config = config.accelerate(true); let mut net_dfa = DFAGraph { states: Vec::new() }; let mut substr_defs_array = Vec::new(); @@ -261,7 +267,7 @@ pub fn regex_and_dfa(decomposed_regex: &DecomposedRegexConfig) -> RegexAndDFA { for regex in decomposed_regex.parts.iter() { let re = DFA::builder() .configure(config.clone()) - .build(&format!(r"^{}$", regex.regex_def)) + .build(&format!(r"{}", regex.regex_def)) .unwrap(); let re_str = format!("{:?}", re); let mut graph = dfa_to_graph(&parse_dfa_output(&re_str));