diff --git a/src/intrinsic/annotations.rs b/src/intrinsic/annotations.rs index 590ab531..fb436be6 100644 --- a/src/intrinsic/annotations.rs +++ b/src/intrinsic/annotations.rs @@ -104,9 +104,9 @@ pub struct Calculus { #[derive(Debug, Clone)] pub enum CalculusType { - WP, - WLP, - ERT, + Wp, + Wlp, + Ert, } pub struct CalculusAnnotationError; @@ -169,21 +169,21 @@ pub fn init_calculi(files: &mut Files, tcx: &mut TyCtx) { let wp = AnnotationKind::Calculus(Calculus { name: Ident::with_dummy_file_span(Symbol::intern("wp"), file), - calculus_type: CalculusType::WP, + calculus_type: CalculusType::Wp, }); tcx.add_global(wp.name()); tcx.declare(DeclKind::AnnotationDecl(wp)); let wlp = AnnotationKind::Calculus(Calculus { name: Ident::with_dummy_file_span(Symbol::intern("wlp"), file), - calculus_type: CalculusType::WLP, + calculus_type: CalculusType::Wlp, }); tcx.add_global(wlp.name()); tcx.declare(DeclKind::AnnotationDecl(wlp)); let ert = AnnotationKind::Calculus(Calculus { name: Ident::with_dummy_file_span(Symbol::intern("ert"), file), - calculus_type: CalculusType::ERT, + calculus_type: CalculusType::Ert, }); tcx.add_global(ert.name()); tcx.declare(DeclKind::AnnotationDecl(ert)); diff --git a/src/proof_rules/induction.rs b/src/proof_rules/induction.rs index 1ba981e6..68ca1933 100644 --- a/src/proof_rules/induction.rs +++ b/src/proof_rules/induction.rs @@ -84,17 +84,12 @@ impl Encoding for InvariantAnnotation { resolve.visit_expr(invariant) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { - if direction - != match calculus.calculus_type { - CalculusType::WP | CalculusType::ERT => Direction::Up, - CalculusType::WLP => Direction::Down, - } - { - return Err(()); - } - - Ok(()) + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { + matches!( + (&calculus.calculus_type, direction), + (CalculusType::Wp | CalculusType::Ert, Direction::Up) + | (CalculusType::Wlp, Direction::Down) + ) } fn transform( @@ -207,17 +202,17 @@ impl Encoding for KIndAnnotation { resolve.visit_expr(invariant) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { if direction != match calculus.calculus_type { - CalculusType::WP | CalculusType::ERT => Direction::Up, - CalculusType::WLP => Direction::Down, + CalculusType::Wp | CalculusType::Ert => Direction::Up, + CalculusType::Wlp => Direction::Down, } { - return Err(()); + return false; } - Ok(()) + true } fn transform( &self, diff --git a/src/proof_rules/mciver_ast.rs b/src/proof_rules/mciver_ast.rs index cf5be7b3..2a9acbce 100644 --- a/src/proof_rules/mciver_ast.rs +++ b/src/proof_rules/mciver_ast.rs @@ -117,14 +117,9 @@ impl Encoding for ASTAnnotation { check_annotation_call(tycheck, call_span, &self.0, args)?; Ok(()) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { - if let CalculusType::WP = calculus.calculus_type { - if direction == Direction::Down { - return Ok(()); - } - } - Err(()) + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { + matches!(calculus.calculus_type, CalculusType::Wp) && direction == Direction::Down } fn transform( diff --git a/src/proof_rules/mod.rs b/src/proof_rules/mod.rs index 70a80aa5..65eea3e7 100644 --- a/src/proof_rules/mod.rs +++ b/src/proof_rules/mod.rs @@ -87,7 +87,7 @@ pub trait Encoding: fmt::Debug { ) -> Result; /// Check if the given calculus annotation is compatible with the encoding annotation - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()>; + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool; /// Indicates if the encoding annotation is required to be the last statement of a procedure fn is_terminator(&self) -> bool; @@ -152,35 +152,6 @@ impl<'tcx, 'sunit> EncCall<'tcx, 'sunit> { impl<'tcx, 'sunit> VisitorMut for EncCall<'tcx, 'sunit> { type Err = AnnotationError; - fn visit_decl(&mut self, decl: &mut DeclKind) -> Result<(), Self::Err> { - if let DeclKind::ProcDecl(decl_ref) = decl { - self.direction = Some(decl_ref.borrow().direction); - self.current_proc_ident = Some(decl_ref.borrow().name); - - // If the procedure has a calculus annotation, store it as the current calculus - if let Some(ident) = decl_ref.borrow().calculus.as_ref() { - match self.tcx.get(*ident) { - Some(decl) => { - if let DeclKind::AnnotationDecl(AnnotationKind::Calculus(calculus)) = - decl.as_ref() - { - self.calculus = Some(calculus.clone()); - } - } - None => { - return Err(AnnotationError::UnknownAnnotation( - decl_ref.borrow().span, - *ident, - )) - } - } - } - - self.visit_proc(decl_ref)?; - } - Ok(()) - } - fn visit_proc(&mut self, proc_ref: &mut DeclRef) -> Result<(), Self::Err> { self.direction = Some(proc_ref.borrow().direction); self.current_proc_ident = Some(proc_ref.borrow().name); @@ -240,14 +211,12 @@ impl<'tcx, 'sunit> VisitorMut for EncCall<'tcx, 'sunit> { // Check if the calculus annotation is compatible with the encoding annotation if let Some(calculus) = &self.calculus { - if anno_ref - .check_calculus( - calculus, - self.direction - .ok_or(AnnotationError::NotInProcedure(s.span, *ident))?, - ) - .is_err() - { + // If calculus is not allowed, return an error + if !anno_ref.is_calculus_allowed( + calculus, + self.direction + .ok_or(AnnotationError::NotInProcedure(s.span, *ident))?, + ) { return Err(AnnotationError::CalculusEncodingMismatch( s.span, calculus.name, diff --git a/src/proof_rules/omega.rs b/src/proof_rules/omega.rs index dd5b7e0d..2ef57344 100644 --- a/src/proof_rules/omega.rs +++ b/src/proof_rules/omega.rs @@ -95,17 +95,12 @@ impl Encoding for OmegaInvAnnotation { resolve.visit_expr(omega_inv) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { - if direction - != match calculus.calculus_type { - CalculusType::WP | CalculusType::ERT => Direction::Down, - CalculusType::WLP => Direction::Up, - } - { - return Err(()); - } - - Ok(()) + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { + matches!( + (&calculus.calculus_type, direction), + (CalculusType::Wp | CalculusType::Ert, Direction::Down) + | (CalculusType::Wlp, Direction::Up) + ) } fn transform( diff --git a/src/proof_rules/ost.rs b/src/proof_rules/ost.rs index 57b7be20..2523f361 100644 --- a/src/proof_rules/ost.rs +++ b/src/proof_rules/ost.rs @@ -92,14 +92,8 @@ impl Encoding for OSTAnnotation { resolve.visit_expr(post) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { - if let CalculusType::WP = calculus.calculus_type { - if direction == Direction::Down { - return Ok(()); - } - } - - Err(()) + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { + matches!(calculus.calculus_type, CalculusType::Wp) && direction == Direction::Down } fn transform( diff --git a/src/proof_rules/past.rs b/src/proof_rules/past.rs index 58623528..cce72dd6 100644 --- a/src/proof_rules/past.rs +++ b/src/proof_rules/past.rs @@ -87,14 +87,8 @@ impl Encoding for PASTAnnotation { resolve.visit_expr(k) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { - if let CalculusType::ERT = calculus.calculus_type { - if direction == Direction::Up { - return Ok(()); - } - } - - Err(()) + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { + matches!(calculus.calculus_type, CalculusType::Ert) && direction == Direction::Up } fn transform( diff --git a/src/proof_rules/unroll.rs b/src/proof_rules/unroll.rs index 562d889c..14ec9b5f 100644 --- a/src/proof_rules/unroll.rs +++ b/src/proof_rules/unroll.rs @@ -82,18 +82,14 @@ impl Encoding for UnrollAnnotation { resolve.visit_expr(invariant) } - fn check_calculus(&self, calculus: &Calculus, direction: Direction) -> Result<(), ()> { - if direction - != match calculus.calculus_type { - CalculusType::WP | CalculusType::ERT => Direction::Up, - CalculusType::WLP => Direction::Down, - } - { - return Err(()); - } - - Ok(()) + fn is_calculus_allowed(&self, calculus: &Calculus, direction: Direction) -> bool { + matches!( + (&calculus.calculus_type, direction), + (CalculusType::Wp | CalculusType::Ert, Direction::Up) + | (CalculusType::Wlp, Direction::Down) + ) } + fn transform( &self, tcx: &TyCtx, diff --git a/src/proof_rules/util.rs b/src/proof_rules/util.rs index 4697ebc5..3249dbc6 100644 --- a/src/proof_rules/util.rs +++ b/src/proof_rules/util.rs @@ -31,7 +31,8 @@ pub fn encode_spec( ] } -/// Encode the extend step in k-induction and bmc recursively for k times +/// Encode the extend step in k-induction and bmc recursively for k times: +/// /// # Arguments /// * `span` - The span of the new generated statement /// * `inner_stmt` - A While statement to be encoded @@ -57,7 +58,8 @@ pub fn encode_extend( ] } -/// Encode the extend step in bmc recursively for k times +/// Encode the extend step in bmc recursively for k times: +/// /// # Arguments /// * `span` - The span of the new generated statement /// * `inner_stmt` - A While statement to be encoded @@ -87,13 +89,13 @@ pub fn encode_iter(span: Span, stmt: &Stmt, next_iter: Vec) -> Option Vec { let builder = ExprBuilder::new(span); - let extrem_lit = match direction { + let extreme_lit = match direction { Direction::Up => builder.top_lit(tcx.spec_ty()), Direction::Down => builder.bot_lit(tcx.spec_ty()), }; vec![ Spanned::new(span, StmtKind::Assert(direction, expr.clone())), - Spanned::new(span, StmtKind::Assume(direction, extrem_lit)), + Spanned::new(span, StmtKind::Assume(direction, extreme_lit)), ] }