diff --git a/patronus/src/smt.rs b/patronus/src/smt.rs index ac56aa0..7654fda 100644 --- a/patronus/src/smt.rs +++ b/patronus/src/smt.rs @@ -7,5 +7,5 @@ mod parser; mod serialize; mod solver; -pub use parser::parse_expr; +pub use parser::{parse_command, parse_expr}; pub use solver::*; diff --git a/patronus/src/smt/parser.rs b/patronus/src/smt/parser.rs index b2cde6d..83d0953 100644 --- a/patronus/src/smt/parser.rs +++ b/patronus/src/smt/parser.rs @@ -69,10 +69,10 @@ type Result = std::result::Result; type SymbolTable = FxHashMap; -pub fn parse_expr(ctx: &mut Context, st: &mut SymbolTable, input: &[u8]) -> Result { +pub fn parse_expr(ctx: &mut Context, st: &SymbolTable, input: &[u8]) -> Result { let mut lexer = Lexer::new(input); let expr = parse_expr_internal(ctx, st, &mut lexer)?; - let token_after_expr = lexer.next(); + let token_after_expr = lexer.next_no_comment(); if token_after_expr.is_some() { Err(SmtParserError::ExprSuffix(format!("{token_after_expr:?}"))) } else { @@ -80,24 +80,17 @@ pub fn parse_expr(ctx: &mut Context, st: &mut SymbolTable, input: &[u8]) -> Resu } } -fn parse_expr_internal( - ctx: &mut Context, - st: &mut SymbolTable, - lexer: &mut Lexer, -) -> Result { +fn parse_expr_internal(ctx: &mut Context, st: &SymbolTable, lexer: &mut Lexer) -> Result { match parse_expr_or_type(ctx, st, lexer)? { ExprOrType::E(e) => Ok(e), ExprOrType::T(t) => Err(SmtParserError::TypeInsteadOfExpr(format!("{t:?}"))), } } -fn parse_type(ctx: &mut Context, st: &mut SymbolTable, lexer: &mut Lexer) -> Result { +fn parse_type(ctx: &mut Context, st: &SymbolTable, lexer: &mut Lexer) -> Result { match parse_expr_or_type(ctx, st, lexer)? { ExprOrType::T(t) => Ok(t), - ExprOrType::E(e) => Err(SmtParserError::ExprInsteadOfType(format!( - "{}", - e.serialize_to_str(ctx) - ))), + ExprOrType::E(e) => Err(SmtParserError::ExprInsteadOfType(e.serialize_to_str(ctx))), } } @@ -108,7 +101,7 @@ enum ExprOrType { fn parse_expr_or_type( ctx: &mut Context, - st: &mut SymbolTable, + st: &SymbolTable, lexer: &mut Lexer, ) -> Result { use ParserItem::*; @@ -140,8 +133,8 @@ fn parse_expr_or_type( if orphan_closing_count > 0 { return Err(SmtParserError::ClosingParenWithoutOpening); } - // we eagerly parse number literals, but we do not make decisions on symbols yet - stack.push(early_parse_number_literals(ctx, value)?); + // we eagerly parse expressions and types that are represented by a single token + stack.push(early_parse_single_token(ctx, st, value)?); } Token::EscapedValue(value) => { if orphan_closing_count > 0 { @@ -149,6 +142,11 @@ fn parse_expr_or_type( } stack.push(PExpr(lookup_sym(st, value)?)) } + Token::StringLit(value) => { + let value = string_lit_to_string(value); + todo!("unexpected string literal in expression: {value}") + } + Token::Comment(_) => {} // ignore comments } // are we done? @@ -158,7 +156,7 @@ fn parse_expr_or_type( _ => {} // cotinue parsing } } - todo!("error message!") + todo!("error message!: {stack:?}") } /// Extracts the value expression from SMT solver responses of the form ((... value)) @@ -180,7 +178,7 @@ pub fn parse_get_value_response(ctx: &mut Context, input: &[u8]) -> Result Result<()> { - let token = lexer.next(); + let token = lexer.next_no_comment(); if token == Some(Token::Open) { Ok(()) } else { @@ -189,7 +187,7 @@ fn skip_open_parens(lexer: &mut Lexer) -> Result<()> { } fn skip_close_parens(lexer: &mut Lexer) -> Result<()> { - let token = lexer.next(); + let token = lexer.next_no_comment(); if token == Some(Token::Close) { Ok(()) } else { @@ -198,13 +196,13 @@ fn skip_close_parens(lexer: &mut Lexer) -> Result<()> { } /// Parses a single command. -pub fn parse_command(ctx: &mut Context, st: &mut SymbolTable, input: &[u8]) -> Result { +pub fn parse_command(ctx: &mut Context, st: &SymbolTable, input: &[u8]) -> Result { let mut lexer = Lexer::new(input); // `(` skip_open_parens(&mut lexer)?; // next token should be the command - let cmd_token = lexer.next(); + let cmd_token = lexer.next_no_comment(); let cmd = match cmd_token { Some(Token::Value(name)) => match name { b"exit" => SmtCommand::Exit, @@ -213,14 +211,17 @@ pub fn parse_command(ctx: &mut Context, st: &mut SymbolTable, input: &[u8]) -> R let logic = parse_logic(&mut lexer)?; SmtCommand::SetLogic(logic) } - b"set-option" => { + b"set-option" | b"set-info" => { let key = value_token(&mut lexer)?; - let value = value_token(&mut lexer)?; + let value = any_string_token(&mut lexer)?; if let Some(key) = key.strip_prefix(b":") { - SmtCommand::SetOption( - String::from_utf8_lossy(key).into(), - String::from_utf8_lossy(value).into(), - ) + let key = String::from_utf8_lossy(key).into(); + if name == b"set-option" { + SmtCommand::SetOption(key, value.into()) + } else { + debug_assert_eq!(name, b"set-info"); + SmtCommand::SetInfo(key, value.into()) + } } else { return Err(SmtParserError::InvalidOptionKey( String::from_utf8_lossy(key).into(), @@ -248,6 +249,19 @@ pub fn parse_command(ctx: &mut Context, st: &mut SymbolTable, input: &[u8]) -> R let sym = ctx.symbol(name_ref, tpe); SmtCommand::DefineConst(sym, value) } + b"define-fun" => { + // parses the `define-const` subset (i.e. no arguments!) + let name = String::from_utf8_lossy(value_token(&mut lexer)?); + skip_open_parens(&mut lexer)?; + skip_close_parens(&mut lexer)?; + let tpe = parse_type(ctx, st, &mut lexer)?; + let value = parse_expr_internal(ctx, st, &mut lexer)?; + // TODO: turn this into a proper error + debug_assert_eq!(ctx[value].get_type(ctx), tpe); + let name_ref = ctx.string(name); + let sym = ctx.symbol(name_ref, tpe); + SmtCommand::DefineConst(sym, value) + } b"check-sat-assuming" => { let expressions = vec![parse_expr_internal(ctx, st, &mut lexer)?]; // TODO: deal with more than one expression @@ -293,13 +307,23 @@ fn parse_logic(lexer: &mut Lexer) -> Result { } fn value_token<'a>(lexer: &mut Lexer<'a>) -> Result<&'a [u8]> { - match lexer.next() { + match lexer.next_no_comment() { Some(Token::Value(v)) => Ok(v), Some(Token::EscapedValue(v)) => Ok(v), other => Err(SmtParserError::ExpectedIdentifer(format!("{other:?}"))), } } +/// parse a token that can be converted to a string +fn any_string_token<'a>(lexer: &mut Lexer<'a>) -> Result> { + match lexer.next_no_comment() { + Some(Token::Value(v)) => Ok(String::from_utf8_lossy(v)), + Some(Token::EscapedValue(v)) => Ok(String::from_utf8_lossy(v)), + Some(Token::StringLit(v)) => Ok(string_lit_to_string(v).into()), + other => Err(SmtParserError::ExpectedIdentifer(format!("{other:?}"))), + } +} + fn skip_expr(lexer: &mut Lexer) -> Result<()> { let mut open_count = 0u64; for token in lexer.by_ref() { @@ -313,6 +337,7 @@ fn skip_expr(lexer: &mut Lexer) -> Result<()> { return Ok(()); } } + Token::Comment(_) => {} // skip _ => { if open_count == 0 { return Ok(()); @@ -382,7 +407,12 @@ fn expr(st: &SymbolTable, item: &ParserItem<'_>) -> Result { } } -fn early_parse_number_literals<'a>(ctx: &mut Context, value: &'a [u8]) -> Result> { +/// Parses things that can be represented by a single token. +fn early_parse_single_token<'a>( + ctx: &mut Context, + st: &SymbolTable, + value: &'a [u8], +) -> Result> { if let Some(match_id) = NUM_LIT_REGEX.matches(value).into_iter().next() { match match_id { 0 => { @@ -404,7 +434,13 @@ fn early_parse_number_literals<'a>(ctx: &mut Context, value: &'a [u8]) -> Result _ => unreachable!("not part of the regex!"), } } else { - Ok(ParserItem::Sym(value)) + match value { + b"Bool" => Ok(ParserItem::PType(Type::BV(1))), + other => { + let symbol = lookup_sym(st, other).ok().map(ParserItem::PExpr); + Ok(symbol.unwrap_or(ParserItem::Sym(value))) + } + } } } @@ -460,6 +496,8 @@ enum Token<'a> { Close, Value(&'a [u8]), EscapedValue(&'a [u8]), + StringLit(&'a [u8]), + Comment(&'a [u8]), } impl<'a> Debug for Token<'a> { @@ -469,15 +507,25 @@ impl<'a> Debug for Token<'a> { Token::Close => write!(f, ")"), Token::Value(v) => write!(f, "{}", String::from_utf8_lossy(v)), Token::EscapedValue(v) => write!(f, "{}", String::from_utf8_lossy(v)), + Token::StringLit(v) => write!(f, "{}", string_lit_to_string(v)), + Token::Comment(v) => write!(f, "/* {} */", String::from_utf8_lossy(v)), } } } +fn string_lit_to_string(value: &[u8]) -> String { + let s = String::from_utf8_lossy(value); + s.replace("\"\"", "\"") +} + #[derive(Debug, Copy, Clone)] enum LexState { Searching, ParsingToken(usize), ParsingEscapedToken(usize), + ParsingStringLiteral(usize), + StringLiteralQuoteFound(usize), + ParsingComment(usize), } impl<'a> Lexer<'a> { @@ -488,6 +536,12 @@ impl<'a> Lexer<'a> { pos: 0, } } + + /// returns the next token that is not a comment + fn next_no_comment(&mut self) -> Option> { + self.by_ref() + .find(|token| !matches!(token, Token::Comment(_))) + } } impl<'a> Iterator for Lexer<'a> { @@ -512,8 +566,8 @@ impl<'a> Iterator for Lexer<'a> { b')' => return Some(Token::Close), // White Space Characters: tab, line feed, carriage return or space b' ' | b'\n' | b'\r' | b'\t' => Searching, - // string literals are currently not supported - b'"' => todo!("String literals are currently not supported!"), + b'"' => ParsingStringLiteral(self.pos), + b';' => ParsingComment(self.pos), _ => ParsingToken(self.pos - 1), } } @@ -539,10 +593,43 @@ impl<'a> Iterator for Lexer<'a> { return Some(Token::EscapedValue(&self.input[start..(self.pos - 1)])); } } + ParsingStringLiteral(start) => { + // consume character + self.pos += 1; + if c == b'"' { + self.state = StringLiteralQuoteFound(start); + } + } + StringLiteralQuoteFound(start) => { + // did we just find an escaped quote? + if c == b'"' { + // consume character + self.pos += 1; + self.state = ParsingStringLiteral(start); + } else { + self.state = Searching; // do not consume the character + return Some(Token::StringLit(&self.input[start..(self.pos - 1)])); + } + } + ParsingComment(start) => { + if c == b'\n' || c == b'\r' { + self.state = Searching; // do not consume the character + return Some(Token::Comment(&self.input[start..(self.pos - 1)])); + } else { + // consume character + self.pos += 1; + } + } }; } - debug_assert_eq!(self.pos, self.input.len() - 1); + debug_assert!(matches!(self.state, Searching), "{:?}", self.state); + debug_assert_eq!( + self.pos, + self.input.len(), + "{}", + String::from_utf8_lossy(self.input) + ); None } } @@ -572,8 +659,8 @@ mod tests { fn test_parser() { let mut ctx = Context::default(); let a = ctx.bv_symbol("a", 2); - let mut symbols = FxHashMap::from_iter([("a".to_string(), a)]); - let expr = parse_expr(&mut ctx, &mut symbols, "(bvand a #b00)".as_bytes()).unwrap(); + let symbols = FxHashMap::from_iter([("a".to_string(), a)]); + let expr = parse_expr(&mut ctx, &symbols, "(bvand a #b00)".as_bytes()).unwrap(); assert_eq!(expr, ctx.build(|c| c.and(a, c.bit_vec_val(0, 2)))); } @@ -592,22 +679,22 @@ mod tests { #[test] fn test_parse_smt_array_const_and_store() { let mut ctx = Context::default(); - let mut symbols = FxHashMap::default(); + let symbols = FxHashMap::default(); let base = "((as const (Array (_ BitVec 5) (_ BitVec 32))) #b00000000000000000000000000110011)"; - let expr = parse_expr(&mut ctx, &mut symbols, base.as_bytes()).unwrap(); + let expr = parse_expr(&mut ctx, &symbols, base.as_bytes()).unwrap(); assert_eq!(expr.serialize_to_str(&ctx), "([32'x00000033] x 2^5)"); let store_1 = format!("(store {base} #b01110 #x00000000)"); - let expr = parse_expr(&mut ctx, &mut symbols, store_1.as_bytes()).unwrap(); + let expr = parse_expr(&mut ctx, &symbols, store_1.as_bytes()).unwrap(); assert_eq!( expr.serialize_to_str(&ctx), "([32'x00000033] x 2^5)[5'b01110 := 32'x00000000]" ); let store_2 = format!("(store {store_1} #b01110 #x00000011)"); - let expr = parse_expr(&mut ctx, &mut symbols, store_2.as_bytes()).unwrap(); + let expr = parse_expr(&mut ctx, &symbols, store_2.as_bytes()).unwrap(); assert_eq!( expr.serialize_to_str(&ctx), "([32'x00000033] x 2^5)[5'b01110 := 32'x00000000][5'b01110 := 32'x00000011]" diff --git a/patronus/src/smt/serialize.rs b/patronus/src/smt/serialize.rs index e0aa40b..3751a71 100644 --- a/patronus/src/smt/serialize.rs +++ b/patronus/src/smt/serialize.rs @@ -288,7 +288,12 @@ pub fn serialize_cmd(out: &mut impl Write, ctx: Option<&Context>, cmd: &SmtComma SmtCommand::Exit => writeln!(out, "(exit)"), SmtCommand::CheckSat => writeln!(out, "(check-sat)"), SmtCommand::SetLogic(logic) => writeln!(out, "(set-logic {})", logic.to_smt_str()), - SmtCommand::SetOption(name, value) => writeln!(out, "(set-option :{name} {value})"), + SmtCommand::SetOption(name, value) => { + writeln!(out, "(set-option :{name} {})", escape_smt_identifier(value)) + } + SmtCommand::SetInfo(name, value) => { + writeln!(out, "(set-option :{name} {})", escape_smt_identifier(value)) + } SmtCommand::Assert(e) => { write!(out, "(assert ")?; serialize_expr(out, ctx.unwrap(), *e)?; diff --git a/patronus/src/smt/solver.rs b/patronus/src/smt/solver.rs index 5f16108..f6872fc 100644 --- a/patronus/src/smt/solver.rs +++ b/patronus/src/smt/solver.rs @@ -54,6 +54,7 @@ pub enum SmtCommand { CheckSat, SetLogic(Logic), SetOption(String, String), + SetInfo(String, String), Assert(ExprRef), DeclareConst(ExprRef), DefineConst(ExprRef, ExprRef), diff --git a/tools/simplify/Cargo.toml b/tools/simplify/Cargo.toml index fba0d75..b523a58 100644 --- a/tools/simplify/Cargo.toml +++ b/tools/simplify/Cargo.toml @@ -11,3 +11,4 @@ rust-version.workspace = true [dependencies] patronus.workspace = true clap.workspace = true +rustc-hash.workspace = true diff --git a/tools/simplify/src/main.rs b/tools/simplify/src/main.rs index 05de61f..7fc836b 100644 --- a/tools/simplify/src/main.rs +++ b/tools/simplify/src/main.rs @@ -4,8 +4,9 @@ use clap::Parser; use patronus::expr::*; -use patronus::smt::SmtCommand; +use patronus::smt::{parse_command, SmtCommand}; use patronus::*; +use rustc_hash::FxHashMap; use std::io::{BufRead, BufReader}; use std::path::PathBuf; @@ -28,13 +29,72 @@ fn main() { let in_file = std::fs::File::open(args.input_file).expect("failed to open input file"); let mut in_reader = BufReader::new(in_file); let mut ctx = Context::default(); - let cmds = parse_commands(&mut in_reader, &mut ctx); + let mut st = FxHashMap::default(); + let cmds = read_cmds(&mut in_reader, &mut ctx, &mut st); todo!(); } -fn parse_commands(inp: &mut impl BufRead, ctx: &mut Context) -> Vec { +fn read_cmds(inp: &mut impl BufRead, ctx: &mut Context, st: &mut SymbolTable) -> Vec { let mut out = vec![]; - + while let Some(cmd) = read_cmd(inp, ctx, st).unwrap() { + out.push(cmd); + } out } + +type SymbolTable = FxHashMap; + +fn read_cmd( + inp: &mut impl BufRead, + ctx: &mut Context, + st: &mut SymbolTable, +) -> std::io::Result> { + let mut response = String::new(); + inp.read_line(&mut response)?; + + // skip lines that are just comments + while is_comment(&response) { + response.clear(); + inp.read_line(&mut response)?; + } + + // ensure that the response contains balanced parentheses + while count_parens(&response) > 0 { + response.push(' '); + inp.read_line(&mut response)?; + } + + // debug print + println!("{response}"); + let cmd = parse_command(ctx, st, response.as_bytes()); + println!("{cmd:?}"); + let cmd = cmd.unwrap(); + + // add symbols to table + match cmd { + SmtCommand::DefineConst(sym, _) | SmtCommand::DeclareConst(sym) => { + st.insert(ctx.get_symbol_name(sym).unwrap().into(), sym); + } + _ => {} + } + Ok(Some(cmd)) +} + +fn is_comment(line: &str) -> bool { + for c in line.chars() { + if !c.is_ascii_whitespace() { + return c == ';'; + } + } + // all whilespace + false +} + +fn count_parens(s: &str) -> i64 { + s.chars().fold(0, |count, cc| match cc { + '(' => count + 1, + ')' => count - 1, + _ => count, + }) +}