From b639576ae8c7ad12ea8c87cdff7ea3213be113ab Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Mon, 23 Sep 2024 15:16:57 +0200 Subject: [PATCH 1/5] Add RewriteControlFlow phase. --- engine/backends/fstar/fstar_backend.ml | 1 + engine/lib/diagnostics.ml | 1 + .../lib/phases/phase_rewrite_control_flow.ml | 176 ++++++++ .../lib/phases/phase_rewrite_control_flow.mli | 5 + .../toolchain__side-effects into-fstar.snap | 395 ++++++++++-------- .../toolchain__side-effects into-ssprove.snap | 34 +- tests/side-effects/src/lib.rs | 17 + 7 files changed, 442 insertions(+), 187 deletions(-) create mode 100644 engine/lib/phases/phase_rewrite_control_flow.ml create mode 100644 engine/lib/phases/phase_rewrite_control_flow.mli diff --git a/engine/backends/fstar/fstar_backend.ml b/engine/backends/fstar/fstar_backend.ml index 6430fc633..f7ccd1a54 100644 --- a/engine/backends/fstar/fstar_backend.ml +++ b/engine/backends/fstar/fstar_backend.ml @@ -1703,6 +1703,7 @@ module TransformToInputLanguage = |> Side_effect_utils.Hoist |> Phases.Hoist_disjunctive_patterns |> Phases.Simplify_match_return + |> Phases.Rewrite_control_flow |> Phases.Drop_needless_returns |> Phases.Local_mutation |> Phases.Reject.Continue diff --git a/engine/lib/diagnostics.ml b/engine/lib/diagnostics.ml index 628b99e1c..b04bd5868 100644 --- a/engine/lib/diagnostics.ml +++ b/engine/lib/diagnostics.ml @@ -39,6 +39,7 @@ module Phase = struct | ResugarWhileLoops | ResugarForIndexLoops | ResugarQuestionMarks + | RewriteControlFlow | SimplifyQuestionMarks | Specialize | HoistSideEffects diff --git a/engine/lib/phases/phase_rewrite_control_flow.ml b/engine/lib/phases/phase_rewrite_control_flow.ml new file mode 100644 index 000000000..e09d6e30b --- /dev/null +++ b/engine/lib/phases/phase_rewrite_control_flow.ml @@ -0,0 +1,176 @@ +(* This phase rewrites: `if c {return a}; b` as `if c {return a; b} else {b}` + and does the equivalent transformation for pattern matchings. *) + +open! Prelude + +module Make (F : Features.T) = + Phase_utils.MakeMonomorphicPhase + (F) + (struct + let phase_id = Diagnostics.Phase.RewriteControlFlow + + open Ast.Make (F) + module U = Ast_utils.Make (F) + module Visitors = Ast_visitors.Make (F) + + module Error = Phase_utils.MakeError (struct + let ctx = Diagnostics.Context.Phase phase_id + end) + + let has_return = + object (_self) + inherit [_] Visitors.reduce as super + method zero = false + method plus = ( || ) + + method! visit_expr' () e = + match e with Return _ -> true | _ -> super#visit_expr' () e + end + + let rewrite_control_flow = + object (self) + inherit [_] Visitors.map as super + + method! visit_expr () e = + match e.e with + | _ when not (has_return#visit_expr () e) -> e + | Let + { + monadic = None; + lhs; + rhs = { e = If { cond; then_; else_ }; _ } as rhs; + body; + } -> + let cond = self#visit_expr () cond in + let then_has_return = has_return#visit_expr () then_ in + let else_has_return = + Option.map else_ ~f:(has_return#visit_expr ()) + |> Option.value ~default:false + in + let rewrite = then_has_return || else_has_return in + if rewrite then + let then_ = + { + e with + e = Let { monadic = None; lhs; rhs = then_; body }; + } + in + let then_ = self#visit_expr () then_ in + let else_ = + Some + (match else_ with + | Some else_ -> + self#visit_expr () + { + e with + e = Let { monadic = None; lhs; rhs = else_; body }; + } + | None -> body) + in + + { rhs with e = If { cond; then_; else_ } } + else + let body = self#visit_expr () body in + { + e with + e = + Let + { + monadic = None; + lhs; + rhs = { rhs with e = If { cond; then_; else_ } }; + body; + }; + } + (* We need this case to make sure we take the `if` all the way up a sequence of nested `let` + and not just one level. *) + | Let { monadic = None; lhs; rhs = { e = Let _; _ } as rhs; body } + -> ( + let body = self#visit_expr () body in + match self#visit_expr () rhs with + | { e = If { cond; then_; else_ = Some else_ }; _ } -> + (* In this case we know we already rewrote the rhs so we should take the `if` one level higher. *) + let rewrite_branch branch = + { + branch with + e = Let { monadic = None; lhs; rhs = branch; body }; + } + in + { + rhs with + e = + If + { + cond; + then_ = rewrite_branch then_; + else_ = Some (rewrite_branch else_); + }; + } + | rhs -> { e with e = Let { monadic = None; lhs; rhs; body } }) + | Let + { + monadic = None; + lhs; + rhs = { e = Match { scrutinee; arms }; _ }; + body; + } -> + let rewrite = + List.fold arms ~init:false ~f:(fun acc (arm : arm) -> + acc || has_return#visit_arm () arm) + in + if rewrite then + { + e with + e = + Match + { + scrutinee = self#visit_expr () scrutinee; + arms = + List.map arms ~f:(fun arm -> + let arm_body = arm.arm.body in + let arm_body = + { + arm_body with + e = + Let + { + monadic = None; + lhs; + rhs = arm_body; + body; + }; + } + in + self#visit_arm () + { + arm with + arm = { arm.arm with body = arm_body }; + }); + }; + } + else e + | _ -> super#visit_expr () e + end + + (* This visitor allows to remove instructions after a `return` so that `drop_needless_return` can simplify them. *) + let remove_after_return = + object (self) + inherit [_] Visitors.map as super + + method! visit_expr () e = + match e.e with + | Let { monadic = None; lhs; rhs; body } -> ( + let rhs = self#visit_expr () rhs in + let body = self#visit_expr () body in + match rhs.e with + | Return _ -> rhs + | _ -> { e with e = Let { monadic = None; lhs; rhs; body } }) + | _ -> super#visit_expr () e + end + + let ditems = + List.map + ~f: + (rewrite_control_flow#visit_item () + >> remove_after_return#visit_item ()) + end) diff --git a/engine/lib/phases/phase_rewrite_control_flow.mli b/engine/lib/phases/phase_rewrite_control_flow.mli new file mode 100644 index 000000000..b5dfa2314 --- /dev/null +++ b/engine/lib/phases/phase_rewrite_control_flow.mli @@ -0,0 +1,5 @@ +(** This phase finds control flow expression (`if` or `match`) with a `return` expression + in one of the branches. We replace them by replicating what comes after in all the branches. + This allows the `return` to be eliminated by `drop_needless_returns`. *) + +module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE diff --git a/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap b/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap index 66de5587a..8b7103da7 100644 --- a/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap +++ b/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap @@ -56,56 +56,26 @@ let direct_result_question_mark_coercion (y: Core.Result.t_Result i8 u16) /// Exercise early returns with control flow and loops let early_returns (x: u32) : u32 = - Rust_primitives.Hax.Control_flow_monad.Mexception.run (let! _:Prims.unit = - if x >. 3ul - then - let! hoist2:Rust_primitives.Hax.t_Never = - Core.Ops.Control_flow.ControlFlow_Break 0ul - <: - Core.Ops.Control_flow.t_ControlFlow u32 Rust_primitives.Hax.t_Never - in - Core.Ops.Control_flow.ControlFlow_Continue (Rust_primitives.Hax.never_to_any hoist2) - <: - Core.Ops.Control_flow.t_ControlFlow u32 Prims.unit - else - Core.Ops.Control_flow.ControlFlow_Continue () - <: - Core.Ops.Control_flow.t_ControlFlow u32 Prims.unit - in - let! x, hoist5:(u32 & u32) = - if x >. 30ul - then - match true with - | true -> - let! hoist4:Rust_primitives.Hax.t_Never = - Core.Ops.Control_flow.ControlFlow_Break 34ul - <: - Core.Ops.Control_flow.t_ControlFlow u32 Rust_primitives.Hax.t_Never - in - Core.Ops.Control_flow.ControlFlow_Continue - (x, Rust_primitives.Hax.never_to_any hoist4 <: (u32 & u32)) - <: - Core.Ops.Control_flow.t_ControlFlow u32 (u32 & u32) - | _ -> - Core.Ops.Control_flow.ControlFlow_Continue (x, 3ul <: (u32 & u32)) + if x >. 3ul + then 0ul + else + let x, hoist5:(u32 & Rust_primitives.Hax.t_Never) = + if x >. 30ul + then + match true with + | true -> + x, + (Core.Ops.Control_flow.ControlFlow_Break 34ul <: - Core.Ops.Control_flow.t_ControlFlow u32 (u32 & u32) - else - Core.Ops.Control_flow.ControlFlow_Continue - (let x:u32 = x +! 9ul in - x, x +! 1ul <: (u32 & u32)) + Core.Ops.Control_flow.t_ControlFlow u32 Rust_primitives.Hax.t_Never) <: - Core.Ops.Control_flow.t_ControlFlow u32 (u32 & u32) - in - let! hoist8:Rust_primitives.Hax.t_Never = - Core.Ops.Control_flow.ControlFlow_Break - (Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add 123ul hoist5 <: u32) x) - <: - Core.Ops.Control_flow.t_ControlFlow u32 Rust_primitives.Hax.t_Never - in - Core.Ops.Control_flow.ControlFlow_Continue (Rust_primitives.Hax.never_to_any hoist8) - <: - Core.Ops.Control_flow.t_ControlFlow u32 u32) + (u32 & Rust_primitives.Hax.t_Never) + | _ -> x, 3ul <: (u32 & u32) + else + let x:u32 = x +! 9ul in + x, x +! 1ul <: (u32 & u32) + in + Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add 123ul hoist5 <: u32) x /// Exercise local mutation with control flow and loops let local_mutation (x: u32) : u32 = @@ -150,115 +120,163 @@ let local_mutation (x: u32) : u32 = /// Test question mark on `Option`s with some control flow let options (x y: Core.Option.t_Option u8) (z: Core.Option.t_Option u64) : Core.Option.t_Option u8 = - Rust_primitives.Hax.Control_flow_monad.Mexception.run (match x with - | Core.Option.Option_Some hoist19 -> - let! hoist26:Core.Option.t_Option u8 = - if hoist19 >. 10uy - then - match x with - | Core.Option.Option_Some hoist21 -> - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Option.Option_Some (Core.Num.impl__u8__wrapping_add hoist21 3uy) - <: - Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) - (Core.Option.t_Option u8) - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Break - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) - (Core.Option.t_Option u8) - else - match x with - | Core.Option.Option_Some hoist24 -> - (match y with - | Core.Option.Option_Some hoist23 -> - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Option.Option_Some (Core.Num.impl__u8__wrapping_add hoist24 hoist23) - <: - Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) - (Core.Option.t_Option u8) - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Break - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) - (Core.Option.t_Option u8)) - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Break - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) - (Core.Option.t_Option u8) - in - (match hoist26 with + match x with + | Core.Option.Option_Some hoist19 -> + if hoist19 >. 10uy + then + match x with + | Core.Option.Option_Some hoist21 -> + (match + Core.Option.Option_Some (Core.Num.impl__u8__wrapping_add hoist21 3uy) + <: + Core.Option.t_Option u8 + with | Core.Option.Option_Some hoist27 -> - let! v:u8 = - match hoist27 with + (match hoist27 with | 3uy -> (match Core.Option.Option_None <: Core.Option.t_Option u8 with | Core.Option.Option_Some some -> - Core.Ops.Control_flow.ControlFlow_Continue some - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) u8 - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Break - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) u8) + let v:u8 = some in + (match x with + | Core.Option.Option_Some hoist28 -> + (match y with + | Core.Option.Option_Some hoist29 -> + Core.Option.Option_Some + (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v + hoist28 + <: + u8) + hoist29) + <: + Core.Option.t_Option u8 + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) | 4uy -> (match z with | Core.Option.Option_Some hoist16 -> - Core.Ops.Control_flow.ControlFlow_Continue - (4uy +! (if hoist16 >. 4uL <: bool then 0uy else 3uy)) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) u8 - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Break - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) u8) - | _ -> - Core.Ops.Control_flow.ControlFlow_Continue 12uy - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) u8 - in - Core.Ops.Control_flow.ControlFlow_Continue - (match x with - | Core.Option.Option_Some hoist28 -> - (match y with - | Core.Option.Option_Some hoist29 -> - Core.Option.Option_Some - (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v hoist28 - <: - u8) - hoist29) - <: - Core.Option.t_Option u8 + let v:u8 = 4uy +! (if hoist16 >. 4uL <: bool then 0uy else 3uy) in + (match x with + | Core.Option.Option_Some hoist28 -> + (match y with + | Core.Option.Option_Some hoist29 -> + Core.Option.Option_Some + (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v + hoist28 + <: + u8) + hoist29) + <: + Core.Option.t_Option u8 + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) - | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) (Core.Option.t_Option u8) - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) (Core.Option.t_Option u8)) - | Core.Option.Option_None -> - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Option.Option_None <: Core.Option.t_Option u8) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Option.t_Option u8) (Core.Option.t_Option u8)) + | _ -> + let v:u8 = 12uy in + match x with + | Core.Option.Option_Some hoist28 -> + (match y with + | Core.Option.Option_Some hoist29 -> + Core.Option.Option_Some + (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v hoist28 + <: + u8) + hoist29) + <: + Core.Option.t_Option u8 + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8 + ) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8 + else + (match x with + | Core.Option.Option_Some hoist24 -> + (match y with + | Core.Option.Option_Some hoist23 -> + (match + Core.Option.Option_Some (Core.Num.impl__u8__wrapping_add hoist24 hoist23) + <: + Core.Option.t_Option u8 + with + | Core.Option.Option_Some hoist27 -> + (match hoist27 with + | 3uy -> + (match Core.Option.Option_None <: Core.Option.t_Option u8 with + | Core.Option.Option_Some some -> + let v:u8 = some in + (match x with + | Core.Option.Option_Some hoist28 -> + (match y with + | Core.Option.Option_Some hoist29 -> + Core.Option.Option_Some + (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v + hoist28 + <: + u8) + hoist29) + <: + Core.Option.t_Option u8 + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | 4uy -> + (match z with + | Core.Option.Option_Some hoist16 -> + let v:u8 = 4uy +! (if hoist16 >. 4uL <: bool then 0uy else 3uy) in + (match x with + | Core.Option.Option_Some hoist28 -> + (match y with + | Core.Option.Option_Some hoist29 -> + Core.Option.Option_Some + (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v + hoist28 + <: + u8) + hoist29) + <: + Core.Option.t_Option u8 + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | _ -> + let v:u8 = 12uy in + match x with + | Core.Option.Option_Some hoist28 -> + (match y with + | Core.Option.Option_Some hoist29 -> + Core.Option.Option_Some + (Core.Num.impl__u8__wrapping_add (Core.Num.impl__u8__wrapping_add v + hoist28 + <: + u8) + hoist29) + <: + Core.Option.t_Option u8 + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> + Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8) + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option u8 /// Test question mark on `Result`s with local mutation let question_mark (x: u32) : Core.Result.t_Result u32 u32 = - Rust_primitives.Hax.Control_flow_monad.Mexception.run (let! x:u32 = - if x >. 40ul - then + Rust_primitives.Hax.Control_flow_monad.Mexception.run (if x >. 40ul <: bool + then + let! x:u32 = let y:u32 = 0ul in let x:u32 = Core.Num.impl__u32__wrapping_add x 3ul in let y:u32 = Core.Num.impl__u32__wrapping_add x y in @@ -286,18 +304,50 @@ let question_mark (x: u32) : Core.Result.t_Result u32 u32 = Core.Ops.Control_flow.ControlFlow_Continue x <: Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) u32 - else - Core.Ops.Control_flow.ControlFlow_Continue x + in + Core.Ops.Control_flow.ControlFlow_Continue + (Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x) <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) u32 - in - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x) + Core.Result.t_Result u32 u32) <: - Core.Result.t_Result u32 u32) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) - (Core.Result.t_Result u32 u32)) + Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) + (Core.Result.t_Result u32 u32) + else + Core.Ops.Control_flow.ControlFlow_Continue + (Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x <: u32) + <: + Core.Result.t_Result u32 u32) + <: + Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) + (Core.Result.t_Result u32 u32)) + +let simplifiable_question_mark (c: bool) (x: Core.Option.t_Option i32) : Core.Option.t_Option i32 = + if c + then + match x with + | Core.Option.Option_Some hoist33 -> + let a:i32 = hoist33 +! 10l in + let b:i32 = 20l in + Core.Option.Option_Some (a +! b) <: Core.Option.t_Option i32 + | Core.Option.Option_None -> Core.Option.Option_None <: Core.Option.t_Option i32 + else + let a:i32 = 0l in + let b:i32 = 20l in + Core.Option.Option_Some (a +! b) <: Core.Option.t_Option i32 + +let simplifiable_return (c1 c2: bool) : i32 = + let x:i32 = 0l in + if c1 + then + if c2 + then 1l + else + let x:i32 = + let x:i32 = 1l in + x + in + x + else x type t_A = | A : t_A @@ -310,33 +360,12 @@ type t_Bar = { /// Combine `?` and early return let monad_lifting (x: u8) : Core.Result.t_Result t_A t_B = - Rust_primitives.Hax.Control_flow_monad.Mexception.run (if x >. 123uy <: bool - then - match Core.Result.Result_Err (B <: t_B) <: Core.Result.t_Result t_A t_B with - | Core.Result.Result_Ok hoist33 -> - let! hoist35:Rust_primitives.Hax.t_Never = - Core.Ops.Control_flow.ControlFlow_Break - (Core.Result.Result_Ok hoist33 <: Core.Result.t_Result t_A t_B) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result t_A t_B) - Rust_primitives.Hax.t_Never - in - Core.Ops.Control_flow.ControlFlow_Continue (Rust_primitives.Hax.never_to_any hoist35) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result t_A t_B) - (Core.Result.t_Result t_A t_B) - | Core.Result.Result_Err err -> - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Result.Result_Err err <: Core.Result.t_Result t_A t_B) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result t_A t_B) - (Core.Result.t_Result t_A t_B) - else - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Result.Result_Ok (A <: t_A) <: Core.Result.t_Result t_A t_B) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result t_A t_B) - (Core.Result.t_Result t_A t_B)) + if x >. 123uy + then + match Core.Result.Result_Err (B <: t_B) <: Core.Result.t_Result t_A t_B with + | Core.Result.Result_Ok hoist35 -> Core.Result.Result_Ok hoist35 <: Core.Result.t_Result t_A t_B + | Core.Result.Result_Err err -> Core.Result.Result_Err err <: Core.Result.t_Result t_A t_B + else Core.Result.Result_Ok (A <: t_A) <: Core.Result.t_Result t_A t_B type t_Foo = { f_x:bool; diff --git a/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap b/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap index 22288365a..90289584e 100644 --- a/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap +++ b/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap @@ -182,6 +182,32 @@ Equations question_mark {L1 : {fset Location}} {I1 : Interface} (x : both L1 I1 Result_Ok (Result_Ok (impl__u32__wrapping_add (ret_both (3 : int32)) x)))) : both (L1 :|: fset [y_loc]) I1 (t_Result int32 int32). Fail Next Obligation. +Equations simplifiable_question_mark {L1 : {fset Location}} {L2 : {fset Location}} {I1 : Interface} {I2 : Interface} (c : both L1 I1 'bool) (x : both L2 I2 (t_Option int32)) : both (L1 :|: L2) (I1 :|: I2) (t_Option int32) := + simplifiable_question_mark c x := + solve_lift (run (letm[choice_typeMonad.option_bind_code] a := ifb c + then letm[choice_typeMonad.option_bind_code] hoist33 := x in + Option_Some (hoist33 .+ (ret_both (10 : int32))) + else Option_Some (ret_both (0 : int32)) in + Option_Some (letb b := ret_both (20 : int32) in + Option_Some (a .+ b)))) : both (L1 :|: L2) (I1 :|: I2) (t_Option int32). +Fail Next Obligation. + +Definition x_loc : Location := + (int32;3%nat). +Equations simplifiable_return {L1 : {fset Location}} {L2 : {fset Location}} {I1 : Interface} {I2 : Interface} (c1 : both L1 I1 'bool) (c2 : both L2 I2 'bool) : both (L1 :|: L2 :|: fset [x_loc]) (I1 :|: I2) int32 := + simplifiable_return c1 c2 := + solve_lift (run (letb x loc(x_loc) := ret_both (0 : int32) in + letm[choice_typeMonad.result_bind_code int32] _ := ifb c1 + then letm[choice_typeMonad.result_bind_code int32] _ := ifb c2 + then letm[choice_typeMonad.result_bind_code int32] hoist34 := ControlFlow_Break (ret_both (1 : int32)) in + ControlFlow_Continue (never_to_any hoist34) + else () in + ControlFlow_Continue (letb _ := assign todo(term) in + ret_both (tt : 'unit)) + else () in + ControlFlow_Continue x)) : both (L1 :|: L2 :|: fset [x_loc]) (I1 :|: I2) int32. +Fail Next Obligation. + Definition t_A : choice_type := 'unit. Equations Build_t_A : both (fset []) (fset []) (t_A) := @@ -220,10 +246,10 @@ Notation "'Build_t_Bar' '[' x ']' '(' 'f_b' ':=' y ')'" := (Build_t_Bar (f_a := Equations monad_lifting {L1 : {fset Location}} {I1 : Interface} (x : both L1 I1 int8) : both L1 I1 (t_Result t_A t_B) := monad_lifting x := solve_lift (run (ifb x >.? (ret_both (123 : int8)) - then letm[choice_typeMonad.result_bind_code (t_Result t_A t_B)] hoist33 := ControlFlow_Continue (Result_Err B) in - letb hoist34 := Result_Ok hoist33 in - letm[choice_typeMonad.result_bind_code (t_Result t_A t_B)] hoist35 := ControlFlow_Break hoist34 in - ControlFlow_Continue (never_to_any hoist35) + then letm[choice_typeMonad.result_bind_code (t_Result t_A t_B)] hoist35 := ControlFlow_Continue (Result_Err B) in + letb hoist36 := Result_Ok hoist35 in + letm[choice_typeMonad.result_bind_code (t_Result t_A t_B)] hoist37 := ControlFlow_Break hoist36 in + ControlFlow_Continue (never_to_any hoist37) else ControlFlow_Continue (Result_Ok A))) : both L1 I1 (t_Result t_A t_B). Fail Next Obligation. diff --git a/tests/side-effects/src/lib.rs b/tests/side-effects/src/lib.rs index 8528df943..a0a4bf377 100644 --- a/tests/side-effects/src/lib.rs +++ b/tests/side-effects/src/lib.rs @@ -63,6 +63,23 @@ fn early_returns(mut x: u32) -> u32 { .wrapping_add(x); } +fn simplifiable_return(c1: bool, c2: bool) -> i32 { + let mut x = 0; + if c1 { + if c2 { + return 1; + } + x = 1; + } + x +} + +fn simplifiable_question_mark(c: bool, x: Option) -> Option { + let a = if c { x? + 10 } else { 0 }; + let b = 20; + Some(a + b) +} + /// Question mark without error coercion fn direct_result_question_mark(y: Result<(), u32>) -> Result { y?; From ad85ba2340bde85a27c17d7e88f0b925810d9f5d Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Tue, 24 Sep 2024 13:39:18 +0200 Subject: [PATCH 2/5] Refactor RewriteControlFlow phase. --- .../lib/phases/phase_rewrite_control_flow.ml | 198 ++++++------------ .../lib/phases/phase_rewrite_control_flow.mli | 3 +- .../toolchain__side-effects into-fstar.snap | 107 ++++------ .../toolchain__side-effects into-ssprove.snap | 9 +- tests/side-effects/src/lib.rs | 9 +- 5 files changed, 120 insertions(+), 206 deletions(-) diff --git a/engine/lib/phases/phase_rewrite_control_flow.ml b/engine/lib/phases/phase_rewrite_control_flow.ml index e09d6e30b..625b8e1aa 100644 --- a/engine/lib/phases/phase_rewrite_control_flow.ml +++ b/engine/lib/phases/phase_rewrite_control_flow.ml @@ -34,143 +34,75 @@ module Make (F : Features.T) = method! visit_expr () e = match e.e with | _ when not (has_return#visit_expr () e) -> e - | Let - { - monadic = None; - lhs; - rhs = { e = If { cond; then_; else_ }; _ } as rhs; - body; - } -> - let cond = self#visit_expr () cond in - let then_has_return = has_return#visit_expr () then_ in - let else_has_return = - Option.map else_ ~f:(has_return#visit_expr ()) - |> Option.value ~default:false + (* Returns in loops will be handled by issue #196 *) + | Loop _ -> e + | Let _ -> ( + (* Collect let bindings to get the sequence + of "statements", find the first "statement" that is a + control flow containing a return. Rewrite it. + *) + let stmts, final = U.collect_let_bindings e in + let before_after i = + let stmts_before, stmts_after = List.split_n stmts i in + let stmts_after = List.tl_exn stmts_after in + (stmts_before, stmts_after) in - let rewrite = then_has_return || else_has_return in - if rewrite then - let then_ = - { - e with - e = Let { monadic = None; lhs; rhs = then_; body }; - } + let inline_in_branch branch p stmts_after final = + let branch_stmts, branch_final = + U.collect_let_bindings branch in - let then_ = self#visit_expr () then_ in - let else_ = - Some - (match else_ with - | Some else_ -> - self#visit_expr () - { - e with - e = Let { monadic = None; lhs; rhs = else_; body }; - } - | None -> body) - in - - { rhs with e = If { cond; then_; else_ } } - else - let body = self#visit_expr () body in - { - e with - e = - Let - { - monadic = None; - lhs; - rhs = { rhs with e = If { cond; then_; else_ } }; - body; - }; - } - (* We need this case to make sure we take the `if` all the way up a sequence of nested `let` - and not just one level. *) - | Let { monadic = None; lhs; rhs = { e = Let _; _ } as rhs; body } - -> ( - let body = self#visit_expr () body in - match self#visit_expr () rhs with - | { e = If { cond; then_; else_ = Some else_ }; _ } -> - (* In this case we know we already rewrote the rhs so we should take the `if` one level higher. *) - let rewrite_branch branch = - { - branch with - e = Let { monadic = None; lhs; rhs = branch; body }; - } - in - { - rhs with - e = - If - { - cond; - then_ = rewrite_branch then_; - else_ = Some (rewrite_branch else_); - }; - } - | rhs -> { e with e = Let { monadic = None; lhs; rhs; body } }) - | Let - { - monadic = None; - lhs; - rhs = { e = Match { scrutinee; arms }; _ }; - body; - } -> - let rewrite = - List.fold arms ~init:false ~f:(fun acc (arm : arm) -> - acc || has_return#visit_arm () arm) + U.make_lets + (branch_stmts @ ((p, branch_final) :: stmts_after)) + final in - if rewrite then - { - e with - e = - Match - { - scrutinee = self#visit_expr () scrutinee; - arms = - List.map arms ~f:(fun arm -> - let arm_body = arm.arm.body in - let arm_body = - { - arm_body with - e = - Let - { - monadic = None; - lhs; - rhs = arm_body; - body; - }; - } - in - self#visit_arm () - { - arm with - arm = { arm.arm with body = arm_body }; - }); - }; - } - else e - | _ -> super#visit_expr () e - end - - (* This visitor allows to remove instructions after a `return` so that `drop_needless_return` can simplify them. *) - let remove_after_return = - object (self) - inherit [_] Visitors.map as super - - method! visit_expr () e = - match e.e with - | Let { monadic = None; lhs; rhs; body } -> ( - let rhs = self#visit_expr () rhs in - let body = self#visit_expr () body in - match rhs.e with - | Return _ -> rhs - | _ -> { e with e = Let { monadic = None; lhs; rhs; body } }) + match + List.findi stmts ~f:(fun _ (_, e) -> + match e.e with + | (If _ | Match _) when has_return#visit_expr () e -> true + | Return _ -> true + | _ -> false) + with + | Some (i, (p, ({ e = If { cond; then_; else_ }; _ } as rhs))) + -> + (* We know there is no "return" in the condition + so we must rewrite the if *) + let stmts_before, stmts_after = before_after i in + let then_ = inline_in_branch then_ p stmts_after final in + let else_ = + Some + (match else_ with + | Some else_ -> + inline_in_branch else_ p stmts_after final + | None -> U.make_lets stmts_after final) + in + U.make_lets stmts_before + { rhs with e = If { cond; then_; else_ } } + |> self#visit_expr () + | Some (i, (p, ({ e = Match { scrutinee; arms }; _ } as rhs))) + -> + let stmts_before, stmts_after = before_after i in + let arms = + List.map arms ~f:(fun arm -> + let body = + inline_in_branch arm.arm.body p stmts_after final + in + { arm with arm = { arm.arm with body } }) + in + U.make_lets stmts_before + { rhs with e = Match { scrutinee; arms } } + |> self#visit_expr () + (* The statements coming after a "return" are useless. *) + | Some (i, (_, ({ e = Return _; _ } as rhs))) -> + let stmts_before, _ = before_after i in + U.make_lets stmts_before rhs |> self#visit_expr () + | _ -> + let stmts = + List.map stmts ~f:(fun (p, e) -> + (p, self#visit_expr () e)) + in + U.make_lets stmts (self#visit_expr () final)) | _ -> super#visit_expr () e end - let ditems = - List.map - ~f: - (rewrite_control_flow#visit_item () - >> remove_after_return#visit_item ()) + let ditems = List.map ~f:(rewrite_control_flow#visit_item ()) end) diff --git a/engine/lib/phases/phase_rewrite_control_flow.mli b/engine/lib/phases/phase_rewrite_control_flow.mli index b5dfa2314..ccc4ed630 100644 --- a/engine/lib/phases/phase_rewrite_control_flow.mli +++ b/engine/lib/phases/phase_rewrite_control_flow.mli @@ -1,5 +1,6 @@ (** This phase finds control flow expression (`if` or `match`) with a `return` expression in one of the branches. We replace them by replicating what comes after in all the branches. - This allows the `return` to be eliminated by `drop_needless_returns`. *) + This allows the `return` to be eliminated by `drop_needless_returns`. + This phase should come after phase_local_mutation. *) module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE diff --git a/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap b/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap index 8b7103da7..abc8ec0b4 100644 --- a/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap +++ b/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap @@ -59,23 +59,17 @@ let early_returns (x: u32) : u32 = if x >. 3ul then 0ul else - let x, hoist5:(u32 & Rust_primitives.Hax.t_Never) = - if x >. 30ul - then - match true with - | true -> - x, - (Core.Ops.Control_flow.ControlFlow_Break 34ul - <: - Core.Ops.Control_flow.t_ControlFlow u32 Rust_primitives.Hax.t_Never) + if x >. 30ul + then + match true with + | true -> 34ul + | _ -> Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add 123ul 3ul <: u32) x + else + let x:u32 = x +! 9ul in + Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add 123ul (x +! 1ul <: u32) <: - (u32 & Rust_primitives.Hax.t_Never) - | _ -> x, 3ul <: (u32 & u32) - else - let x:u32 = x +! 9ul in - x, x +! 1ul <: (u32 & u32) - in - Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add 123ul hoist5 <: u32) x + u32) + x /// Exercise local mutation with control flow and loops let local_mutation (x: u32) : u32 = @@ -274,52 +268,28 @@ let options (x y: Core.Option.t_Option u8) (z: Core.Option.t_Option u64) : Core. /// Test question mark on `Result`s with local mutation let question_mark (x: u32) : Core.Result.t_Result u32 u32 = - Rust_primitives.Hax.Control_flow_monad.Mexception.run (if x >. 40ul <: bool - then - let! x:u32 = - let y:u32 = 0ul in - let x:u32 = Core.Num.impl__u32__wrapping_add x 3ul in - let y:u32 = Core.Num.impl__u32__wrapping_add x y in - let x:u32 = Core.Num.impl__u32__wrapping_add x y in - if x >. 90ul - then - match Core.Result.Result_Err 12uy <: Core.Result.t_Result Prims.unit u8 with - | Core.Result.Result_Ok ok -> - Core.Ops.Control_flow.ControlFlow_Continue x - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) u32 - | Core.Result.Result_Err err -> - let! _:Prims.unit = - Core.Ops.Control_flow.ControlFlow_Break - (Core.Result.Result_Err (Core.Convert.f_from #FStar.Tactics.Typeclasses.solve err) - <: - Core.Result.t_Result u32 u32) - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) Prims.unit - in - Core.Ops.Control_flow.ControlFlow_Continue x - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) u32 - else - Core.Ops.Control_flow.ControlFlow_Continue x - <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) u32 - in - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x) - <: - Core.Result.t_Result u32 u32) + if x >. 40ul + then + let y:u32 = 0ul in + let x:u32 = Core.Num.impl__u32__wrapping_add x 3ul in + let y:u32 = Core.Num.impl__u32__wrapping_add x y in + let x:u32 = Core.Num.impl__u32__wrapping_add x y in + if x >. 90ul + then + match Core.Result.Result_Err 12uy <: Core.Result.t_Result Prims.unit u8 with + | Core.Result.Result_Ok ok -> + let _:Prims.unit = ok in + Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x) <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) - (Core.Result.t_Result u32 u32) - else - Core.Ops.Control_flow.ControlFlow_Continue - (Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x <: u32) - <: - Core.Result.t_Result u32 u32) + Core.Result.t_Result u32 u32 + | Core.Result.Result_Err err -> + Core.Result.Result_Err (Core.Convert.f_from #FStar.Tactics.Typeclasses.solve err) <: - Core.Ops.Control_flow.t_ControlFlow (Core.Result.t_Result u32 u32) - (Core.Result.t_Result u32 u32)) + Core.Result.t_Result u32 u32 + else + Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x) <: Core.Result.t_Result u32 u32 + else + Core.Result.Result_Ok (Core.Num.impl__u32__wrapping_add 3ul x) <: Core.Result.t_Result u32 u32 let simplifiable_question_mark (c: bool) (x: Core.Option.t_Option i32) : Core.Option.t_Option i32 = if c @@ -335,17 +305,22 @@ let simplifiable_question_mark (c: bool) (x: Core.Option.t_Option i32) : Core.Op let b:i32 = 20l in Core.Option.Option_Some (a +! b) <: Core.Option.t_Option i32 -let simplifiable_return (c1 c2: bool) : i32 = +let simplifiable_return (c1 c2 c3: bool) : i32 = let x:i32 = 0l in if c1 then if c2 - then 1l - else - let x:i32 = - let x:i32 = 1l in + then + let x:i32 = x +! 10l in + if c3 + then 1l + else + let x:i32 = x +! 1l in + let _:Prims.unit = () in x - in + else + let x:i32 = x +! 1l in + let _:Prims.unit = () in x else x diff --git a/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap b/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap index 90289584e..ebff02a35 100644 --- a/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap +++ b/test-harness/src/snapshots/toolchain__side-effects into-ssprove.snap @@ -194,18 +194,21 @@ Fail Next Obligation. Definition x_loc : Location := (int32;3%nat). -Equations simplifiable_return {L1 : {fset Location}} {L2 : {fset Location}} {I1 : Interface} {I2 : Interface} (c1 : both L1 I1 'bool) (c2 : both L2 I2 'bool) : both (L1 :|: L2 :|: fset [x_loc]) (I1 :|: I2) int32 := - simplifiable_return c1 c2 := +Equations simplifiable_return {L1 : {fset Location}} {L2 : {fset Location}} {L3 : {fset Location}} {I1 : Interface} {I2 : Interface} {I3 : Interface} (c1 : both L1 I1 'bool) (c2 : both L2 I2 'bool) (c3 : both L3 I3 'bool) : both (L1 :|: L2 :|: L3 :|: fset [x_loc]) (I1 :|: I2 :|: I3) int32 := + simplifiable_return c1 c2 c3 := solve_lift (run (letb x loc(x_loc) := ret_both (0 : int32) in letm[choice_typeMonad.result_bind_code int32] _ := ifb c1 then letm[choice_typeMonad.result_bind_code int32] _ := ifb c2 + then letb _ := assign todo(term) in + ifb c3 then letm[choice_typeMonad.result_bind_code int32] hoist34 := ControlFlow_Break (ret_both (1 : int32)) in ControlFlow_Continue (never_to_any hoist34) + else () else () in ControlFlow_Continue (letb _ := assign todo(term) in ret_both (tt : 'unit)) else () in - ControlFlow_Continue x)) : both (L1 :|: L2 :|: fset [x_loc]) (I1 :|: I2) int32. + ControlFlow_Continue x)) : both (L1 :|: L2 :|: L3 :|: fset [x_loc]) (I1 :|: I2 :|: I3) int32. Fail Next Obligation. Definition t_A : choice_type := diff --git a/tests/side-effects/src/lib.rs b/tests/side-effects/src/lib.rs index a0a4bf377..7bef76cb4 100644 --- a/tests/side-effects/src/lib.rs +++ b/tests/side-effects/src/lib.rs @@ -63,13 +63,16 @@ fn early_returns(mut x: u32) -> u32 { .wrapping_add(x); } -fn simplifiable_return(c1: bool, c2: bool) -> i32 { +fn simplifiable_return(c1: bool, c2: bool, c3: bool) -> i32 { let mut x = 0; if c1 { if c2 { - return 1; + x += 10; + if c3 { + return 1; + } } - x = 1; + x += 1; } x } From 595d6dddc9ed1f538d8c30e3eff7888331f18a87 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Tue, 24 Sep 2024 14:59:43 +0200 Subject: [PATCH 3/5] Use split_while. --- .../lib/phases/phase_rewrite_control_flow.ml | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/engine/lib/phases/phase_rewrite_control_flow.ml b/engine/lib/phases/phase_rewrite_control_flow.ml index 625b8e1aa..6e351de8e 100644 --- a/engine/lib/phases/phase_rewrite_control_flow.ml +++ b/engine/lib/phases/phase_rewrite_control_flow.ml @@ -42,11 +42,6 @@ module Make (F : Features.T) = control flow containing a return. Rewrite it. *) let stmts, final = U.collect_let_bindings e in - let before_after i = - let stmts_before, stmts_after = List.split_n stmts i in - let stmts_after = List.tl_exn stmts_after in - (stmts_before, stmts_after) - in let inline_in_branch branch p stmts_after final = let branch_stmts, branch_final = U.collect_let_bindings branch @@ -56,17 +51,18 @@ module Make (F : Features.T) = final in match - List.findi stmts ~f:(fun _ (_, e) -> + List.split_while stmts ~f:(fun (_, e) -> match e.e with - | (If _ | Match _) when has_return#visit_expr () e -> true - | Return _ -> true - | _ -> false) + | (If _ | Match _) when has_return#visit_expr () e -> + false + | Return _ -> false + | _ -> true) with - | Some (i, (p, ({ e = If { cond; then_; else_ }; _ } as rhs))) - -> + | ( stmts_before, + (p, ({ e = If { cond; then_; else_ }; _ } as rhs)) + :: stmts_after ) -> (* We know there is no "return" in the condition so we must rewrite the if *) - let stmts_before, stmts_after = before_after i in let then_ = inline_in_branch then_ p stmts_after final in let else_ = Some @@ -78,9 +74,9 @@ module Make (F : Features.T) = U.make_lets stmts_before { rhs with e = If { cond; then_; else_ } } |> self#visit_expr () - | Some (i, (p, ({ e = Match { scrutinee; arms }; _ } as rhs))) - -> - let stmts_before, stmts_after = before_after i in + | ( stmts_before, + (p, ({ e = Match { scrutinee; arms }; _ } as rhs)) + :: stmts_after ) -> let arms = List.map arms ~f:(fun arm -> let body = @@ -92,8 +88,7 @@ module Make (F : Features.T) = { rhs with e = Match { scrutinee; arms } } |> self#visit_expr () (* The statements coming after a "return" are useless. *) - | Some (i, (_, ({ e = Return _; _ } as rhs))) -> - let stmts_before, _ = before_after i in + | stmts_before, (_, ({ e = Return _; _ } as rhs)) :: _ -> U.make_lets stmts_before rhs |> self#visit_expr () | _ -> let stmts = From 94304b5eec0832250e28c449dfff9be5d5e819b9 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 25 Sep 2024 10:23:19 +0200 Subject: [PATCH 4/5] Avoid adding extra 'let _ = ()'. --- .../lib/phases/phase_rewrite_control_flow.ml | 28 +++++++++++-------- .../toolchain__side-effects into-fstar.snap | 2 -- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/engine/lib/phases/phase_rewrite_control_flow.ml b/engine/lib/phases/phase_rewrite_control_flow.ml index 6e351de8e..7531f8584 100644 --- a/engine/lib/phases/phase_rewrite_control_flow.ml +++ b/engine/lib/phases/phase_rewrite_control_flow.ml @@ -46,21 +46,26 @@ module Make (F : Features.T) = let branch_stmts, branch_final = U.collect_let_bindings branch in - U.make_lets - (branch_stmts @ ((p, branch_final) :: stmts_after)) - final + let stmts_to_add = + match (p, branch_final) with + (* This avoids adding `let _ = ()` *) + | { p = PWild; _ }, { e = GlobalVar (`TupleCons 0); _ } -> + stmts_after + | stmt -> stmt :: stmts_after + in + U.make_lets (branch_stmts @ stmts_to_add) final in - match + let stmts_before, stmt_and_stmts_after = List.split_while stmts ~f:(fun (_, e) -> match e.e with | (If _ | Match _) when has_return#visit_expr () e -> false | Return _ -> false | _ -> true) - with - | ( stmts_before, - (p, ({ e = If { cond; then_; else_ }; _ } as rhs)) - :: stmts_after ) -> + in + match stmt_and_stmts_after with + | (p, ({ e = If { cond; then_; else_ }; _ } as rhs)) + :: stmts_after -> (* We know there is no "return" in the condition so we must rewrite the if *) let then_ = inline_in_branch then_ p stmts_after final in @@ -74,9 +79,8 @@ module Make (F : Features.T) = U.make_lets stmts_before { rhs with e = If { cond; then_; else_ } } |> self#visit_expr () - | ( stmts_before, - (p, ({ e = Match { scrutinee; arms }; _ } as rhs)) - :: stmts_after ) -> + | (p, ({ e = Match { scrutinee; arms }; _ } as rhs)) + :: stmts_after -> let arms = List.map arms ~f:(fun arm -> let body = @@ -88,7 +92,7 @@ module Make (F : Features.T) = { rhs with e = Match { scrutinee; arms } } |> self#visit_expr () (* The statements coming after a "return" are useless. *) - | stmts_before, (_, ({ e = Return _; _ } as rhs)) :: _ -> + | (_, ({ e = Return _; _ } as rhs)) :: _ -> U.make_lets stmts_before rhs |> self#visit_expr () | _ -> let stmts = diff --git a/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap b/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap index abc8ec0b4..1a54ff379 100644 --- a/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap +++ b/test-harness/src/snapshots/toolchain__side-effects into-fstar.snap @@ -316,11 +316,9 @@ let simplifiable_return (c1 c2 c3: bool) : i32 = then 1l else let x:i32 = x +! 1l in - let _:Prims.unit = () in x else let x:i32 = x +! 1l in - let _:Prims.unit = () in x else x From 120ba54c4418ebb1f3e0856d76797ea2a9ef39bc Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 25 Sep 2024 10:48:35 +0200 Subject: [PATCH 5/5] Modify comment. --- engine/lib/phases/phase_rewrite_control_flow.mli | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/lib/phases/phase_rewrite_control_flow.mli b/engine/lib/phases/phase_rewrite_control_flow.mli index ccc4ed630..72ecfd60d 100644 --- a/engine/lib/phases/phase_rewrite_control_flow.mli +++ b/engine/lib/phases/phase_rewrite_control_flow.mli @@ -1,6 +1,6 @@ (** This phase finds control flow expression (`if` or `match`) with a `return` expression in one of the branches. We replace them by replicating what comes after in all the branches. This allows the `return` to be eliminated by `drop_needless_returns`. - This phase should come after phase_local_mutation. *) + This phase should come after `phase_local_mutation`. *) module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE