Skip to content

Commit

Permalink
Merge pull request #559 from hacspec/question-mark-as-early-returns
Browse files Browse the repository at this point in the history
Question mark as early returns
  • Loading branch information
W95Psp authored Mar 12, 2024
2 parents e9e26d4 + f3e21a3 commit 3855c88
Show file tree
Hide file tree
Showing 24 changed files with 933 additions and 146 deletions.
5 changes: 4 additions & 1 deletion engine/backends/fstar/fstar_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,7 @@ module TransformToInputLanguage =
[%functor_application
Phases.Reject.RawOrMutPointer(Features.Rust)
|> Phases.Drop_sized_trait
|> Phases.Simplify_question_marks
|> Phases.And_mut_defsite
|> Phases.Reconstruct_for_loops
|> Phases.Reconstruct_while_loops
Expand All @@ -1346,15 +1347,17 @@ module TransformToInputLanguage =
|> Phases.Drop_blocks
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
|> Phases.Reconstruct_question_marks
|> Side_effect_utils.Hoist
|> Phases.Simplify_match_return
|> Phases.Drop_needless_returns
|> Phases.Local_mutation
|> Phases.Reject.Continue
|> Phases.Cf_into_monads
|> Phases.Reject.EarlyExit
|> Phases.Functionalize_loops
|> Phases.Reject.As_pattern
|> Phases.Traits_specs
|> Phases.Simplify_hoisting
|> SubtypeToInputLanguage
|> Identity
]
Expand Down
31 changes: 11 additions & 20 deletions engine/lib/ast_visitors.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
open Ast
open! Utils
open Base

module Make =
functor
Expand Down Expand Up @@ -729,7 +731,7 @@ functor

method visit_list : 'a. ('env -> 'a -> 'a) -> 'env -> 'a list -> 'a list
=
fun v env this -> Base.List.map ~f:(fun x -> v env x) this
fun v env -> Base.List.map ~f:(v env)

method visit_option
: 'a. ('env -> 'a -> 'a) -> 'env -> 'a option -> 'a option =
Expand Down Expand Up @@ -1883,15 +1885,11 @@ functor
method visit_list
: 'a. ('env -> 'a -> 'a * 'acc) -> 'env -> 'a list -> 'a list * 'acc
=
fun v env this ->
let acc = ref self#zero in
( Base.List.map
~f:(fun x ->
let x, acc' = v env x in
acc := self#plus !acc acc';
x)
this,
!acc )
fun v env ->
Base.List.fold_map ~init:self#zero ~f:(fun acc x ->
let x, acc' = v env x in
(self#plus acc acc', x))
>> swap

method visit_option
: 'a.
Expand Down Expand Up @@ -2946,16 +2944,9 @@ functor
method visit_list : 'a. ('env -> 'a -> 'acc) -> 'env -> 'a list -> 'acc
=
fun v env this ->
let acc = ref self#zero in
let _ =
Base.List.map
~f:(fun x ->
let acc' = v env x in
acc := self#plus !acc acc';
())
this
in
!acc
Base.List.fold ~init:self#zero
~f:(fun acc -> v env >> self#plus acc)
this

method visit_option
: 'a. ('env -> 'a -> 'acc) -> 'env -> 'a option -> 'acc =
Expand Down
5 changes: 5 additions & 0 deletions engine/lib/diagnostics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@ module Phase = struct
| Identity
| DropReferences
| DropBlocks
| DropSizedTrait
| RefMut
| ResugarForLoops
| ResugarWhileLoops
| ResugarForIndexLoops
| ResugarQuestionMarks
| SimplifyQuestionMarks
| HoistSideEffects
| LocalMutation
| TrivializeAssignLhs
| CfIntoMonads
| FunctionalizeLoops
| TraitsSpecs
| SimplifyMatchReturn
| SimplifyHoisting
| DropNeedlessReturns
| DummyA
| DummyB
| DummyC
Expand Down
4 changes: 3 additions & 1 deletion engine/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
ocamlgraph)
(preprocessor_deps
; `ppx_inline` is used on the `Subtype` module, thus we need it at PPX time
(file subtype.ml))
(file subtype.ml)
(source_tree phases))
(preprocess
(pps
ppx_yojson_conv
Expand All @@ -26,6 +27,7 @@
ppx_deriving.eq
ppx_string
ppx_inline
ppx_phases_index
ppx_generate_features
ppx_functor_application
ppx_enumerate
Expand Down
5 changes: 4 additions & 1 deletion engine/lib/local_ident.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open! Prelude

module T = struct
type kind = Typ | Cnst | Expr | LILifetime | Final
type kind = Typ | Cnst | Expr | LILifetime | Final | SideEffectHoistVar
[@@deriving show, yojson, hash, compare, sexp, eq]

type id = kind * int [@@deriving show, yojson, hash, compare, sexp, eq]
Expand All @@ -13,6 +13,9 @@ module T = struct

let make_final name = { name; id = mk_id Final 0 }
let is_final { id; _ } = [%matches? Final] @@ fst id

let is_side_effect_hoist_var { id; _ } =
[%matches? SideEffectHoistVar] @@ fst id
end

include Base.Comparator.Make (T)
Expand Down
13 changes: 11 additions & 2 deletions engine/lib/local_ident.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
module T : sig
type kind = Typ | Cnst | Expr | LILifetime | Final
type kind =
| Typ (** type namespace *)
| Cnst (** Generic constant namespace *)
| Expr (** Expression namespace *)
| LILifetime (** Lifetime namespace *)
| Final
(** Frozen identifier: such an identifier will *not* be rewritten by the name policy *)
| SideEffectHoistVar (** A variable generated by `Side_effect_utils` *)
[@@deriving show, yojson, hash, compare, sexp, eq]

type id [@@deriving show, yojson, hash, compare, sexp, eq]
Expand All @@ -9,9 +16,11 @@ module T : sig
type t = { name : string; id : id }
[@@deriving show, yojson, hash, compare, sexp, eq]

(* Create a frozen final local identifier: such an indentifier won't be rewritten by a name policy *)
val make_final : string -> t
(** Creates a frozen final local identifier: such an indentifier won't be rewritten by a name policy *)

val is_final : t -> bool
val is_side_effect_hoist_var : t -> bool
end

include module type of struct
Expand Down
15 changes: 2 additions & 13 deletions engine/lib/phases.ml
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
module Direct_and_mut = Phase_direct_and_mut.Make
module And_mut_defsite = Phase_and_mut_defsite.Make
module Drop_references = Phase_drop_references.Make
module Drop_blocks = Phase_drop_blocks.Make
module Reconstruct_for_loops = Phase_reconstruct_for_loops.Make
module Reconstruct_while_loops = Phase_reconstruct_while_loops.Make
module Reconstruct_question_marks = Phase_reconstruct_question_marks.Make
module Trivialize_assign_lhs = Phase_trivialize_assign_lhs.Make
module Cf_into_monads = Phase_cf_into_monads.Make
module Functionalize_loops = Phase_functionalize_loops.Make
[%%phases_index ()]

module Reject = Phase_reject
module Local_mutation = Phase_local_mutation.Make
module Traits_specs = Phase_traits_specs.Make
module Drop_sized_trait = Phase_drop_sized_trait.Make
48 changes: 48 additions & 0 deletions engine/lib/phases/phase_drop_needless_returns.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
open! Prelude

module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.DropNeedlessReturns

open Ast.Make (F)
module U = Ast_utils.Make (F)
module Visitors = Ast_visitors.Make (F)

module Error = Phase_utils.MakeError (struct
let ctx = Diagnostics.Context.Phase phase_id
end)

let visitor =
object (self)
inherit [_] Visitors.map as _super

method! visit_expr () e =
match e with
| { e = Return { e; _ }; _ } -> e
(* we know [e] is on an exit position: the return is
thus useless, we can skip it *)
| { e = Let { monadic = None; lhs; rhs; body }; _ } ->
let body = self#visit_expr () body in
{ e with e = Let { monadic = None; lhs; rhs; body } }
(* If a let expression is an exit node, then it's body
is as well *)
| { e = Match { scrutinee; arms }; _ } ->
let arms = List.map ~f:(self#visit_arm ()) arms in
{ e with e = Match { scrutinee; arms } }
| { e = If { cond; then_; else_ }; _ } ->
let then_ = self#visit_expr () then_ in
let else_ = Option.map ~f:(self#visit_expr ()) else_ in
{ e with e = If { cond; then_; else_ } }
| _ -> e
(** The invariant here is that [visit_expr] is called only
on expressions that are on exit positions. [visit_expr]
is first called on root expressions, which are (by
definition) exit nodes. Then, [visit_expr] itself makes
recursive calls to sub expressions that are themselves
in exit nodes. **)
end

let ditems = List.map ~f:(visitor#visit_item ())
end)
4 changes: 4 additions & 0 deletions engine/lib/phases/phase_drop_needless_returns.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(** This phase transforms `return e` expressions into `e` when `return
e` is on an exit position. *)

module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE
2 changes: 1 addition & 1 deletion engine/lib/phases/phase_drop_sized_trait.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.TraitsSpecs
let phase_id = Diagnostics.Phase.DropSizedTrait

open Ast.Make (F)
module U = Ast_utils.Make (F)
Expand Down
4 changes: 1 addition & 3 deletions engine/lib/phases/phase_reconstruct_question_marks.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ While `e?` in Rust might implies an implicit coercion, in our AST, a
question mark is expected to already be of the right type. This phase
inlines a coercion (of the shape `x.map_err(from)`, in the case of a
`Result`).
*)
open! Prelude
*)

(** This phase can be applied to any feature set. *)
module Make (F : Features.T) : sig
include module type of struct
module FA = F
Expand Down
67 changes: 67 additions & 0 deletions engine/lib/phases/phase_simplify_hoisting.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
open! Prelude

module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.SimplifyHoisting

open Ast.Make (F)
module U = Ast_utils.Make (F)
module Visitors = Ast_visitors.Make (F)

module Error = Phase_utils.MakeError (struct
let ctx = Diagnostics.Context.Phase phase_id
end)

let inline_matches =
object (self)
inherit [_] Visitors.map as super

method! visit_expr () e =
match e with
| {
e =
Let
{
monadic = None;
lhs =
{
p =
PBinding
{
mut = Immutable;
mode = ByValue;
var;
subpat = None;
_;
};
_;
};
rhs;
body;
};
_;
}
when Local_ident.is_side_effect_hoist_var var ->
let body, count =
(object
inherit [_] Visitors.mapreduce as super
method zero = 0
method plus = ( + )

method! visit_expr () e =
match e.e with
| LocalVar v when [%eq: Local_ident.t] v var -> (rhs, 1)
| _ -> super#visit_expr () e
end)
#visit_expr
() body
in
if [%eq: int] count 1 then self#visit_expr () body
else super#visit_expr () e
| _ -> super#visit_expr () e
end

let ditems = List.map ~f:(inline_matches#visit_item ())
end)
4 changes: 4 additions & 0 deletions engine/lib/phases/phase_simplify_hoisting.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(** This phase rewrites `let pat = match ... { ... => ..., ... => return ... }; e`
into `match ... { ... => let pat = ...; e}`. *)

module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE
Loading

0 comments on commit 3855c88

Please sign in to comment.