Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Question mark as early returns #559

Merged
merged 12 commits into from
Mar 12, 2024
5 changes: 4 additions & 1 deletion engine/backends/fstar/fstar_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,7 @@ module TransformToInputLanguage =
[%functor_application
Phases.Reject.RawOrMutPointer(Features.Rust)
|> Phases.Drop_sized_trait
|> Phases.Simplify_question_marks
|> Phases.And_mut_defsite
|> Phases.Reconstruct_for_loops
|> Phases.Reconstruct_while_loops
Expand All @@ -1346,15 +1347,17 @@ module TransformToInputLanguage =
|> Phases.Drop_blocks
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
|> Phases.Reconstruct_question_marks
|> Side_effect_utils.Hoist
|> Phases.Simplify_match_return
|> Phases.Drop_needless_returns
|> Phases.Local_mutation
|> Phases.Reject.Continue
|> Phases.Cf_into_monads
|> Phases.Reject.EarlyExit
|> Phases.Functionalize_loops
|> Phases.Reject.As_pattern
|> Phases.Traits_specs
|> Phases.Simplify_hoisting
|> SubtypeToInputLanguage
|> Identity
]
Expand Down
31 changes: 11 additions & 20 deletions engine/lib/ast_visitors.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
open Ast
open! Utils
open Base

module Make =
functor
Expand Down Expand Up @@ -729,7 +731,7 @@ functor

method visit_list : 'a. ('env -> 'a -> 'a) -> 'env -> 'a list -> 'a list
=
fun v env this -> Base.List.map ~f:(fun x -> v env x) this
fun v env -> Base.List.map ~f:(v env)

method visit_option
: 'a. ('env -> 'a -> 'a) -> 'env -> 'a option -> 'a option =
Expand Down Expand Up @@ -1883,15 +1885,11 @@ functor
method visit_list
: 'a. ('env -> 'a -> 'a * 'acc) -> 'env -> 'a list -> 'a list * 'acc
=
fun v env this ->
let acc = ref self#zero in
( Base.List.map
~f:(fun x ->
let x, acc' = v env x in
acc := self#plus !acc acc';
x)
this,
!acc )
fun v env ->
Base.List.fold_map ~init:self#zero ~f:(fun acc x ->
let x, acc' = v env x in
(self#plus acc acc', x))
>> swap

method visit_option
: 'a.
Expand Down Expand Up @@ -2946,16 +2944,9 @@ functor
method visit_list : 'a. ('env -> 'a -> 'acc) -> 'env -> 'a list -> 'acc
=
fun v env this ->
let acc = ref self#zero in
let _ =
Base.List.map
~f:(fun x ->
let acc' = v env x in
acc := self#plus !acc acc';
())
this
in
!acc
Base.List.fold ~init:self#zero
~f:(fun acc -> v env >> self#plus acc)
this

method visit_option
: 'a. ('env -> 'a -> 'acc) -> 'env -> 'a option -> 'acc =
Expand Down
5 changes: 5 additions & 0 deletions engine/lib/diagnostics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@ module Phase = struct
| Identity
| DropReferences
| DropBlocks
| DropSizedTrait
| RefMut
| ResugarForLoops
| ResugarWhileLoops
| ResugarForIndexLoops
| ResugarQuestionMarks
| SimplifyQuestionMarks
| HoistSideEffects
| LocalMutation
| TrivializeAssignLhs
| CfIntoMonads
| FunctionalizeLoops
| TraitsSpecs
| SimplifyMatchReturn
| SimplifyHoisting
| DropNeedlessReturns
| DummyA
| DummyB
| DummyC
Expand Down
4 changes: 3 additions & 1 deletion engine/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
ocamlgraph)
(preprocessor_deps
; `ppx_inline` is used on the `Subtype` module, thus we need it at PPX time
(file subtype.ml))
(file subtype.ml)
(source_tree phases))
(preprocess
(pps
ppx_yojson_conv
Expand All @@ -26,6 +27,7 @@
ppx_deriving.eq
ppx_string
ppx_inline
ppx_phases_index
ppx_generate_features
ppx_functor_application
ppx_enumerate
Expand Down
5 changes: 4 additions & 1 deletion engine/lib/local_ident.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open! Prelude

module T = struct
type kind = Typ | Cnst | Expr | LILifetime | Final
type kind = Typ | Cnst | Expr | LILifetime | Final | SideEffectHoistVar
[@@deriving show, yojson, hash, compare, sexp, eq]

type id = kind * int [@@deriving show, yojson, hash, compare, sexp, eq]
Expand All @@ -13,6 +13,9 @@ module T = struct

let make_final name = { name; id = mk_id Final 0 }
let is_final { id; _ } = [%matches? Final] @@ fst id

let is_side_effect_hoist_var { id; _ } =
[%matches? SideEffectHoistVar] @@ fst id
end

include Base.Comparator.Make (T)
Expand Down
13 changes: 11 additions & 2 deletions engine/lib/local_ident.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
module T : sig
type kind = Typ | Cnst | Expr | LILifetime | Final
type kind =
| Typ (** type namespace *)
| Cnst (** Generic constant namespace *)
| Expr (** Expression namespace *)
| LILifetime (** Lifetime namespace *)
| Final
(** Frozen identifier: such an identifier will *not* be rewritten by the name policy *)
| SideEffectHoistVar (** A variable generated by `Side_effect_utils` *)
[@@deriving show, yojson, hash, compare, sexp, eq]

type id [@@deriving show, yojson, hash, compare, sexp, eq]
Expand All @@ -9,9 +16,11 @@ module T : sig
type t = { name : string; id : id }
[@@deriving show, yojson, hash, compare, sexp, eq]

(* Create a frozen final local identifier: such an indentifier won't be rewritten by a name policy *)
val make_final : string -> t
(** Creates a frozen final local identifier: such an indentifier won't be rewritten by a name policy *)

val is_final : t -> bool
val is_side_effect_hoist_var : t -> bool
end

include module type of struct
Expand Down
15 changes: 2 additions & 13 deletions engine/lib/phases.ml
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
module Direct_and_mut = Phase_direct_and_mut.Make
module And_mut_defsite = Phase_and_mut_defsite.Make
module Drop_references = Phase_drop_references.Make
module Drop_blocks = Phase_drop_blocks.Make
module Reconstruct_for_loops = Phase_reconstruct_for_loops.Make
module Reconstruct_while_loops = Phase_reconstruct_while_loops.Make
module Reconstruct_question_marks = Phase_reconstruct_question_marks.Make
module Trivialize_assign_lhs = Phase_trivialize_assign_lhs.Make
module Cf_into_monads = Phase_cf_into_monads.Make
module Functionalize_loops = Phase_functionalize_loops.Make
[%%phases_index ()]

module Reject = Phase_reject
module Local_mutation = Phase_local_mutation.Make
module Traits_specs = Phase_traits_specs.Make
module Drop_sized_trait = Phase_drop_sized_trait.Make
48 changes: 48 additions & 0 deletions engine/lib/phases/phase_drop_needless_returns.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
open! Prelude

module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.DropNeedlessReturns

open Ast.Make (F)
module U = Ast_utils.Make (F)
module Visitors = Ast_visitors.Make (F)

module Error = Phase_utils.MakeError (struct
let ctx = Diagnostics.Context.Phase phase_id
end)

let visitor =
object (self)
inherit [_] Visitors.map as _super

method! visit_expr () e =
match e with
| { e = Return { e; _ }; _ } -> e
(* we know [e] is on an exit position: the return is
thus useless, we can skip it *)
| { e = Let { monadic = None; lhs; rhs; body }; _ } ->
let body = self#visit_expr () body in
{ e with e = Let { monadic = None; lhs; rhs; body } }
(* If a let expression is an exit node, then it's body
is as well *)
| { e = Match { scrutinee; arms }; _ } ->
let arms = List.map ~f:(self#visit_arm ()) arms in
{ e with e = Match { scrutinee; arms } }
| { e = If { cond; then_; else_ }; _ } ->
let then_ = self#visit_expr () then_ in
let else_ = Option.map ~f:(self#visit_expr ()) else_ in
{ e with e = If { cond; then_; else_ } }
| _ -> e
(** The invariant here is that [visit_expr] is called only
on expressions that are on exit positions. [visit_expr]
is first called on root expressions, which are (by
definition) exit nodes. Then, [visit_expr] itself makes
recursive calls to sub expressions that are themselves
in exit nodes. **)
end

let ditems = List.map ~f:(visitor#visit_item ())
end)
4 changes: 4 additions & 0 deletions engine/lib/phases/phase_drop_needless_returns.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(** This phase transforms `return e` expressions into `e` when `return
e` is on an exit position. *)

module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE
2 changes: 1 addition & 1 deletion engine/lib/phases/phase_drop_sized_trait.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.TraitsSpecs
let phase_id = Diagnostics.Phase.DropSizedTrait

open Ast.Make (F)
module U = Ast_utils.Make (F)
Expand Down
4 changes: 1 addition & 3 deletions engine/lib/phases/phase_reconstruct_question_marks.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ While `e?` in Rust might implies an implicit coercion, in our AST, a
question mark is expected to already be of the right type. This phase
inlines a coercion (of the shape `x.map_err(from)`, in the case of a
`Result`).
*)

open! Prelude
*)

(** This phase can be applied to any feature set. *)
module Make (F : Features.T) : sig
include module type of struct
module FA = F
Expand Down
67 changes: 67 additions & 0 deletions engine/lib/phases/phase_simplify_hoisting.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
open! Prelude

module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.SimplifyHoisting

open Ast.Make (F)
module U = Ast_utils.Make (F)
module Visitors = Ast_visitors.Make (F)

module Error = Phase_utils.MakeError (struct
let ctx = Diagnostics.Context.Phase phase_id
end)

let inline_matches =
object (self)
inherit [_] Visitors.map as super

method! visit_expr () e =
match e with
| {
e =
Let
{
monadic = None;
lhs =
{
p =
PBinding
{
mut = Immutable;
mode = ByValue;
var;
subpat = None;
_;
};
_;
};
rhs;
body;
};
_;
}
when Local_ident.is_side_effect_hoist_var var ->
let body, count =
(object
inherit [_] Visitors.mapreduce as super
method zero = 0
method plus = ( + )

method! visit_expr () e =
match e.e with
| LocalVar v when [%eq: Local_ident.t] v var -> (rhs, 1)
| _ -> super#visit_expr () e
end)
#visit_expr
() body
in
if [%eq: int] count 1 then self#visit_expr () body
else super#visit_expr () e
| _ -> super#visit_expr () e
end

let ditems = List.map ~f:(inline_matches#visit_item ())
end)
4 changes: 4 additions & 0 deletions engine/lib/phases/phase_simplify_hoisting.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
(** This phase rewrites `let pat = match ... { ... => ..., ... => return ... }; e`
into `match ... { ... => let pat = ...; e}`. *)

module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE
Loading
Loading