Skip to content

Commit

Permalink
port egraph condition sythesizer to new SMT solver interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 16, 2024
1 parent 70afbb4 commit e31a7e3
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 36 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ rust-version = "1.73.0"
rustc-hash = "2.x"
baa = "0.14.7"
egg = "0.9.5"
easy-smt = "0.2.3"
regex = "1.11.1"
boolean_expression = "0.4.4"
clap = { version = "4.x", features = ["derive"] }
Expand Down
1 change: 0 additions & 1 deletion tools/egraphs-cond-synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ patronus-egraphs = { path = "../../patronus-egraphs"}
egg.workspace = true
clap.workspace = true
rustc-hash.workspace = true
easy-smt.workspace = true
boolean_expression.workspace = true
baa.workspace = true
indicatif = "0.17.9"
Expand Down
19 changes: 9 additions & 10 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use baa::BitVecOps;
use clap::Parser;
use patronus::expr::*;
use patronus::mc::get_smt_value;
use patronus::smt::{CheckSatResponse, SolverContext};
use std::io::Write;
use std::path::{Path, PathBuf};

Expand Down Expand Up @@ -265,30 +266,30 @@ fn check_conditions(rule: &ArithRewrite, samples: &Samples, info: &RuleInfo) {
let rhs_expr = to_smt(&mut ctx, rhs, info, &a);

// run SMT solver to get a counter example
smt_ctx.push_many(1).unwrap();
smt_ctx.push().unwrap();
let resp = check_eq(&mut ctx, &mut smt_ctx, lhs_expr, rhs_expr);
assert_eq!(resp, easy_smt::Response::Sat);
assert_eq!(resp, CheckSatResponse::Sat);

// get assignments to variables
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let vars = find_symbols_in_expr(&ctx, is_eq);
let mut values: Vec<String> = vars
.into_iter()
.map(|v| {
let value = get_value(&mut ctx, &mut smt_ctx, v);
let name = ctx.get_symbol_name(v).unwrap();
let value = get_value(&ctx, &mut smt_ctx, v);
format!("{name}={value}")
})
.collect();
values.push(format!(
"lhs_result={}",
get_value(&ctx, &mut smt_ctx, lhs_expr)
get_value(&mut ctx, &mut smt_ctx, lhs_expr)
));
values.push(format!(
"rhs_result={}",
get_value(&ctx, &mut smt_ctx, rhs_expr)
get_value(&mut ctx, &mut smt_ctx, rhs_expr)
));
smt_ctx.pop_many(1).unwrap();
smt_ctx.pop().unwrap();

println!(
" {} =/= {}",
Expand All @@ -300,10 +301,8 @@ fn check_conditions(rule: &ArithRewrite, samples: &Samples, info: &RuleInfo) {
}
}

fn get_value(ctx: &Context, smt_ctx: &mut easy_smt::Context, expr: ExprRef) -> String {
let tpe = expr.get_type(&ctx);
let v = patronus::smt::convert_expr(smt_ctx, &ctx, expr, &|_| None);
let value = get_smt_value(smt_ctx, v, tpe).unwrap();
fn get_value(ctx: &mut Context, smt_ctx: &mut impl SolverContext, expr: ExprRef) -> String {
let value = get_smt_value(ctx, smt_ctx, expr).unwrap();
if let baa::Value::BitVec(v) = value {
format!("{}", v.to_u64().unwrap())
} else {
Expand Down
40 changes: 16 additions & 24 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use crate::rewrites::ArithRewrite;
use egg::*;
use indicatif::ProgressBar;
use patronus::expr::traversal::TraversalCmd;
use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt};
use patronus::expr::{Context, ExprRef, WidthInt};
use patronus::smt::{CheckSatResponse, Logic, Solver, SolverContext, BITWUZLA};
use patronus_egraphs::*;
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
Expand Down Expand Up @@ -63,14 +64,14 @@ pub fn generate_samples(
let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment);
let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment);

smt_ctx.push_many(1).unwrap();
smt_ctx.push().unwrap();
let resp = check_eq(&mut ctx, &mut smt_ctx, lhs_expr, rhs_expr);
smt_ctx.pop_many(1).unwrap();
smt_ctx.pop().unwrap();

match resp {
easy_smt::Response::Sat => samples.add(assignment, false),
easy_smt::Response::Unsat => samples.add(assignment, true),
easy_smt::Response::Unknown => println!("{assignment:?} => Unknown!"),
CheckSatResponse::Sat => samples.add(assignment, false),
CheckSatResponse::Unsat => samples.add(assignment, true),
CheckSatResponse::Unknown => println!("{assignment:?} => Unknown!"),
}
}

Expand All @@ -86,31 +87,25 @@ pub fn generate_samples(

pub fn check_eq(
ctx: &mut Context,
smt_ctx: &mut easy_smt::Context,
smt_ctx: &mut impl SolverContext,
lhs_expr: ExprRef,
rhs_expr: ExprRef,
) -> easy_smt::Response {
) -> CheckSatResponse {
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);
declare_vars(smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
smt_ctx.check().unwrap()
smt_ctx.assert(ctx, is_not_eq).unwrap();
smt_ctx.check_sat().unwrap()
}

pub fn start_solver(dump_smt: bool) -> easy_smt::Context {
let solver: patronus::mc::SmtSolverCmd = patronus::mc::BITWUZLA_CMD;
pub fn start_solver(dump_smt: bool) -> impl SolverContext {
let dump_file = if dump_smt {
Some(std::fs::File::create("replay.smt").unwrap())
} else {
None
};
let mut smt_ctx = easy_smt::ContextBuilder::new()
.solver(solver.name, solver.args)
.replay_file(dump_file)
.build()
.unwrap();
smt_ctx.set_logic("QF_ABV").unwrap();
let mut smt_ctx = BITWUZLA.start(dump_file).expect("failed to start solver");
smt_ctx.set_logic(Logic::QfAbv).unwrap();
smt_ctx
}

Expand Down Expand Up @@ -250,14 +245,11 @@ pub fn find_symbols_in_expr(ctx: &Context, expr: ExprRef) -> Vec<ExprRef> {
vars
}

fn declare_vars(smt_ctx: &mut easy_smt::Context, ctx: &Context, expr: ExprRef) {
fn declare_vars(smt_ctx: &mut impl SolverContext, ctx: &Context, expr: ExprRef) {
let vars = find_symbols_in_expr(ctx, expr);
for v in vars.into_iter() {
let expr = &ctx[v];
let tpe = patronus::smt::convert_tpe(smt_ctx, expr.get_type(ctx));
let name = expr.get_symbol_name(ctx).unwrap();
smt_ctx
.declare_const(name, tpe)
.declare_const(ctx, v)
.expect("failed to declare const");
}
}
Expand Down

0 comments on commit e31a7e3

Please sign in to comment.