From e31a7e3ad3a814b78658b5755121790a7d171a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Mon, 16 Dec 2024 12:31:36 -0500 Subject: [PATCH] port egraph condition sythesizer to new SMT solver interface --- Cargo.toml | 1 - tools/egraphs-cond-synth/Cargo.toml | 1 - tools/egraphs-cond-synth/src/main.rs | 19 ++++++------ tools/egraphs-cond-synth/src/samples.rs | 40 ++++++++++--------------- 4 files changed, 25 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84a0538..6c6c38d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ rust-version = "1.73.0" rustc-hash = "2.x" baa = "0.14.7" egg = "0.9.5" -easy-smt = "0.2.3" regex = "1.11.1" boolean_expression = "0.4.4" clap = { version = "4.x", features = ["derive"] } diff --git a/tools/egraphs-cond-synth/Cargo.toml b/tools/egraphs-cond-synth/Cargo.toml index de9b491..5fbcddd 100644 --- a/tools/egraphs-cond-synth/Cargo.toml +++ b/tools/egraphs-cond-synth/Cargo.toml @@ -14,7 +14,6 @@ patronus-egraphs = { path = "../../patronus-egraphs"} egg.workspace = true clap.workspace = true rustc-hash.workspace = true -easy-smt.workspace = true boolean_expression.workspace = true baa.workspace = true indicatif = "0.17.9" diff --git a/tools/egraphs-cond-synth/src/main.rs b/tools/egraphs-cond-synth/src/main.rs index ed89ff5..8678ef7 100644 --- a/tools/egraphs-cond-synth/src/main.rs +++ b/tools/egraphs-cond-synth/src/main.rs @@ -21,6 +21,7 @@ use baa::BitVecOps; use clap::Parser; use patronus::expr::*; use patronus::mc::get_smt_value; +use patronus::smt::{CheckSatResponse, SolverContext}; use std::io::Write; use std::path::{Path, PathBuf}; @@ -265,9 +266,9 @@ fn check_conditions(rule: &ArithRewrite, samples: &Samples, info: &RuleInfo) { let rhs_expr = to_smt(&mut ctx, rhs, info, &a); // run SMT solver to get a counter example - smt_ctx.push_many(1).unwrap(); + smt_ctx.push().unwrap(); let resp = check_eq(&mut ctx, &mut smt_ctx, lhs_expr, rhs_expr); - assert_eq!(resp, easy_smt::Response::Sat); + assert_eq!(resp, CheckSatResponse::Sat); // get assignments to variables let is_eq = ctx.equal(lhs_expr, rhs_expr); @@ -275,20 +276,20 @@ fn check_conditions(rule: &ArithRewrite, samples: &Samples, info: &RuleInfo) { let mut values: Vec = vars .into_iter() .map(|v| { + let value = get_value(&mut ctx, &mut smt_ctx, v); let name = ctx.get_symbol_name(v).unwrap(); - let value = get_value(&ctx, &mut smt_ctx, v); format!("{name}={value}") }) .collect(); values.push(format!( "lhs_result={}", - get_value(&ctx, &mut smt_ctx, lhs_expr) + get_value(&mut ctx, &mut smt_ctx, lhs_expr) )); values.push(format!( "rhs_result={}", - get_value(&ctx, &mut smt_ctx, rhs_expr) + get_value(&mut ctx, &mut smt_ctx, rhs_expr) )); - smt_ctx.pop_many(1).unwrap(); + smt_ctx.pop().unwrap(); println!( " {} =/= {}", @@ -300,10 +301,8 @@ fn check_conditions(rule: &ArithRewrite, samples: &Samples, info: &RuleInfo) { } } -fn get_value(ctx: &Context, smt_ctx: &mut easy_smt::Context, expr: ExprRef) -> String { - let tpe = expr.get_type(&ctx); - let v = patronus::smt::convert_expr(smt_ctx, &ctx, expr, &|_| None); - let value = get_smt_value(smt_ctx, v, tpe).unwrap(); +fn get_value(ctx: &mut Context, smt_ctx: &mut impl SolverContext, expr: ExprRef) -> String { + let value = get_smt_value(ctx, smt_ctx, expr).unwrap(); if let baa::Value::BitVec(v) = value { format!("{}", v.to_u64().unwrap()) } else { diff --git a/tools/egraphs-cond-synth/src/samples.rs b/tools/egraphs-cond-synth/src/samples.rs index 66612f0..587de54 100644 --- a/tools/egraphs-cond-synth/src/samples.rs +++ b/tools/egraphs-cond-synth/src/samples.rs @@ -6,7 +6,8 @@ use crate::rewrites::ArithRewrite; use egg::*; use indicatif::ProgressBar; use patronus::expr::traversal::TraversalCmd; -use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt}; +use patronus::expr::{Context, ExprRef, WidthInt}; +use patronus::smt::{CheckSatResponse, Logic, Solver, SolverContext, BITWUZLA}; use patronus_egraphs::*; use rayon::prelude::*; use rustc_hash::{FxHashMap, FxHashSet}; @@ -63,14 +64,14 @@ pub fn generate_samples( let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment); let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment); - smt_ctx.push_many(1).unwrap(); + smt_ctx.push().unwrap(); let resp = check_eq(&mut ctx, &mut smt_ctx, lhs_expr, rhs_expr); - smt_ctx.pop_many(1).unwrap(); + smt_ctx.pop().unwrap(); match resp { - easy_smt::Response::Sat => samples.add(assignment, false), - easy_smt::Response::Unsat => samples.add(assignment, true), - easy_smt::Response::Unknown => println!("{assignment:?} => Unknown!"), + CheckSatResponse::Sat => samples.add(assignment, false), + CheckSatResponse::Unsat => samples.add(assignment, true), + CheckSatResponse::Unknown => println!("{assignment:?} => Unknown!"), } } @@ -86,31 +87,25 @@ pub fn generate_samples( pub fn check_eq( ctx: &mut Context, - smt_ctx: &mut easy_smt::Context, + smt_ctx: &mut impl SolverContext, lhs_expr: ExprRef, rhs_expr: ExprRef, -) -> easy_smt::Response { +) -> CheckSatResponse { let is_eq = ctx.equal(lhs_expr, rhs_expr); let is_not_eq = ctx.not(is_eq); - let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None); declare_vars(smt_ctx, &ctx, is_not_eq); - smt_ctx.assert(smt_expr).unwrap(); - smt_ctx.check().unwrap() + smt_ctx.assert(ctx, is_not_eq).unwrap(); + smt_ctx.check_sat().unwrap() } -pub fn start_solver(dump_smt: bool) -> easy_smt::Context { - let solver: patronus::mc::SmtSolverCmd = patronus::mc::BITWUZLA_CMD; +pub fn start_solver(dump_smt: bool) -> impl SolverContext { let dump_file = if dump_smt { Some(std::fs::File::create("replay.smt").unwrap()) } else { None }; - let mut smt_ctx = easy_smt::ContextBuilder::new() - .solver(solver.name, solver.args) - .replay_file(dump_file) - .build() - .unwrap(); - smt_ctx.set_logic("QF_ABV").unwrap(); + let mut smt_ctx = BITWUZLA.start(dump_file).expect("failed to start solver"); + smt_ctx.set_logic(Logic::QfAbv).unwrap(); smt_ctx } @@ -250,14 +245,11 @@ pub fn find_symbols_in_expr(ctx: &Context, expr: ExprRef) -> Vec { vars } -fn declare_vars(smt_ctx: &mut easy_smt::Context, ctx: &Context, expr: ExprRef) { +fn declare_vars(smt_ctx: &mut impl SolverContext, ctx: &Context, expr: ExprRef) { let vars = find_symbols_in_expr(ctx, expr); for v in vars.into_iter() { - let expr = &ctx[v]; - let tpe = patronus::smt::convert_tpe(smt_ctx, expr.get_type(ctx)); - let name = expr.get_symbol_name(ctx).unwrap(); smt_ctx - .declare_const(name, tpe) + .declare_const(ctx, v) .expect("failed to declare const"); } }