Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: traversal functions for Expr and SubExpr #1208

Merged
merged 16 commits into from
Jun 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,8 @@ def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size ≤ bs.size) (i : N
true
termination_by _ => as.size - i

/- Return true iff `as` is a prefix of `bs` -/
/-- Return true iff `as` is a prefix of `bs`.
That is, `bs = as ++ t` for some `t : List α`.-/
def isPrefixOf [BEq α] (as bs : Array α) : Bool :=
if h : as.size ≤ bs.size then
isPrefixOfAux as bs h 0
Expand Down
6 changes: 4 additions & 2 deletions src/Init/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,15 @@ instance [LT α] : LE (List α) := ⟨List.le⟩
instance [LT α] [DecidableRel ((· < ·) : α → α → Prop)] : (l₁ l₂ : List α) → Decidable (l₁ ≤ l₂) :=
fun _ _ => inferInstanceAs (Decidable (Not _))

/-- `isPrefixOf l₁ l₂` returns `true` Iff `l₁` is a prefix of `l₂`. -/
/-- `isPrefixOf l₁ l₂` returns `true` Iff `l₁` is a prefix of `l₂`.
That is, there exists a `t` such that `l₂ == l₁ ++ t`. -/
def isPrefixOf [BEq α] : List α → List α → Bool
| [], _ => true
| _, [] => false
| a::as, b::bs => a == b && isPrefixOf as bs

/-- `isSuffixOf l₁ l₂` returns `true` Iff `l₁` is a suffix of `l₂`. -/
/-- `isSuffixOf l₁ l₂` returns `true` Iff `l₁` is a suffix of `l₂`.
That is, there exists a `t` such that `l₂ == t ++ l₁`. -/
def isSuffixOf [BEq α] (l₁ l₂ : List α) : Bool :=
isPrefixOf l₁.reverse l₂.reverse

Expand Down
6 changes: 6 additions & 0 deletions src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,12 @@ private def getAppRevArgsAux : Expr → Array Expr → Array Expr
let nargs := e.getAppNumArgs
withAppAux k e (mkArray nargs dummy) (nargs-1)

/-- Given `e = fn a₁ ... aₙ`, runs `f` on `fn` and each of the arguments `aᵢ` and
makes a new function application with the results. -/
def traverseApp {M} [Monad M]
(f : Expr → M Expr) (e : Expr) : M Expr :=
e.withApp fun fn args => mkAppN <$> f fn <*> args.mapM f

@[specialize] private def withAppRevAux (k : Expr → Array Expr → α) : Expr → Array Expr → α
| app f a _, as => withAppRevAux k f (as.push a)
| f, as => k f as
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Meta.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@ import Lean.Meta.Constructions
import Lean.Meta.CongrTheorems
import Lean.Meta.Eqns
import Lean.Meta.CasesOn
import Lean.Meta.ExprLens
import Lean.Meta.ExprTraverse
175 changes: 175 additions & 0 deletions src/Lean/Meta/ExprLens.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/-
Copyright (c) 2022 E.W.Ayers. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: E.W.Ayers
-/
import Lean.Meta.Basic
import Lean.SubExpr

/-!

# Expression Lenses

Functions for manipulating subexpressions using `SubExpr.Pos`.

-/

namespace Lean.Meta

section ExprLens

open Lean.SubExpr

variable {M} [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M] [MonadError M]

/-- Given a constructor index for Expr, runs `g` on the value of that subexpression and replaces it.
If the subexpression is under a binder it will instantiate and abstract the binder body correctly.
Mdata is ignored. An index of 3 is interpreted as the type of the expression. An index of 3 will throw since we can't replace types.

See also `Lean.Meta.transform`, `Lean.Meta.traverseChildren`. -/
private def lensCoord (g : Expr → M Expr) : Nat → Expr → M Expr
| 0, e@(Expr.app f a _) => return e.updateApp! (← g f) a
| 1, e@(Expr.app f a _) => return e.updateApp! f (← g a)
| 0, e@(Expr.lam _ y b _) => return e.updateLambdaE! (← g y) b
| 1, (Expr.lam n y b c) => withLocalDecl n c.binderInfo y fun x => do mkLambdaFVars #[x] <|← g <| b.instantiateRev #[x]
| 0, e@(Expr.forallE _ y b _) => return e.updateForallE! (← g y) b
| 1, (Expr.forallE n y b c) => withLocalDecl n c.binderInfo y fun x => do mkForallFVars #[x] <|← g <| b.instantiateRev #[x]
| 0, e@(Expr.letE _ y a b _) => return e.updateLet! (← g y) a b
| 1, e@(Expr.letE _ y a b _) => return e.updateLet! y (← g a) b
| 2, (Expr.letE n y a b _) => withLetDecl n y a fun x => do mkLetFVars #[x] <|← g <| b.instantiateRev #[x]
| 0, e@(Expr.proj _ _ b _) => e.updateProj! <$> g b
| n, e@(Expr.mdata _ a _) => e.updateMData! <$> lensCoord g n a
| 3, _ => throwError "Lensing on types is not supported"
| c, e => throwError "Invalid coordinate {c} for {e}"

private def lensAux (g : Expr → M Expr) : List Nat → Expr → M Expr
| [], e => g e
| head::tail, e => lensCoord (lensAux g tail) head e

/-- Run the given `replace` function to replace the expression at the subexpression position. If the subexpression is below a binder
the bound variables will be appropriately instantiated with free variables and reabstracted after the replacement.
If the subexpression is invalid or points to a type then this will throw. -/
def replaceSubexpr (replace : (subexpr : Expr) → M Expr) (p : Pos) (root : Expr) : M Expr :=
lensAux replace p.toArray.toList root

/-- Runs `k` on the given coordinate, including handling binders properly.
The subexpression value passed to `k` is not instantiated with respect to the
array of free variables. -/
private def viewCoordAux (k : Array Expr → Expr → M α) (fvars: Array Expr) : Nat → Expr → M α
| 3, _ => throwError "Internal: Types should be handled by viewAux"
| 0, (Expr.app f _ _) => k fvars f
| 1, (Expr.app _ a _) => k fvars a
| 0, (Expr.lam _ y _ _) => k fvars y
| 1, (Expr.lam n y b c) => withLocalDecl n c.binderInfo (y.instantiateRev fvars) fun x => k (fvars.push x) b
| 0, (Expr.forallE _ y _ _) => k fvars y
| 1, (Expr.forallE n y b c) => withLocalDecl n c.binderInfo (y.instantiateRev fvars) fun x => k (fvars.push x) b
| 0, (Expr.letE _ y _ _ _) => k fvars y
| 1, (Expr.letE _ _ a _ _) => k fvars a
| 2, (Expr.letE n y a b _) => withLetDecl n (y.instantiateRev fvars) (a.instantiateRev fvars) fun x => k (fvars.push x) b
| 0, (Expr.proj _ _ b _) => k fvars b
| n, (Expr.mdata _ a _) => viewCoordAux k fvars n a
| c, e => throwError "Invalid coordinate {c} for {e}"

private def viewAux (k : Array Expr → Expr → M α) (fvars : Array Expr) : List Nat → Expr → M α
| [], e => k fvars <| e.instantiateRev fvars
| 3::tail, e => do
let y ← inferType <| e.instantiateRev fvars
viewAux (fun otherFvars => k (fvars ++ otherFvars)) #[] tail y
| head::tail, e => viewCoordAux (fun fvars => viewAux k fvars tail) fvars head e

/-- `view visit p e` runs `visit fvars s` where `s : Expr` is the subexpression of `e` at `p`.
and `fvars` are the free variables for the binders that `s` is under.
`s` is already instantiated with respect to these.
The role of the `visit` function is analogous to the `k` function in `Lean.Meta.forallTelescope`. -/
def viewSubexpr
(visit : (fvars : Array Expr) → (subexpr : Expr) → M α)
(p : Pos) (root : Expr) : M α :=
viewAux visit #[] p.toArray.toList root

private def foldAncestorsAux
(k : Array Expr → Expr → Nat → α → M α)
(acc : α) (address : List Nat) (fvars : Array Expr) (current : Expr) : M α :=
match address with
| [] => return acc
| 3 :: tail => do
let current := current.instantiateRev fvars
let y ← inferType current
let acc ← k fvars current 3 acc
foldAncestorsAux (fun otherFvars => k (fvars ++ otherFvars)) acc tail #[] y
| head :: tail => do
let acc ← k fvars (current.instantiateRev fvars) head acc
viewCoordAux (foldAncestorsAux k acc tail) fvars head current

/-- `foldAncestors k init p e` folds over the strict ancestor subexpressions of the given expression `e` above position `p`, starting at the root expression and working down.
The fold function `k` is given the newly instantiated free variables, the ancestor subexpression, and the coordinate
that will be explored next.-/
def foldAncestors
(k : (fvars: Array Expr) → (subexpr : Expr) → (nextCoord : Nat) → α → M α)
(init : α) (p : Pos) (e : Expr) : M α :=
foldAncestorsAux k init p.toArray.toList #[] e

end ExprLens

end Lean.Meta

namespace Lean.Core

open Lean.SubExpr

section ViewRaw

variable {M} [Monad M] [MonadError M]

/-- Get the raw subexpression without performing any instantiation. -/
private def viewCoordRaw: Expr → Nat → M Expr
| e , 3 => throwError "Can't viewRaw the type of {e}"
| (Expr.app f _ _) , 0 => pure f
| (Expr.app _ a _) , 1 => pure a
| (Expr.lam _ y _ _) , 0 => pure y
| (Expr.lam _ _ b _) , 1 => pure b
| (Expr.forallE _ y _ _), 0 => pure y
| (Expr.forallE _ _ b _), 1 => pure b
| (Expr.letE _ y _ _ _) , 0 => pure y
| (Expr.letE _ _ a _ _) , 1 => pure a
| (Expr.letE _ _ _ b _) , 2 => pure b
| (Expr.proj _ _ b _) , 0 => pure b
| (Expr.mdata _ a _) , n => viewCoordRaw a n
| e , c => throwError "Bad coordinate {c} for {e}"


/-- Given a valid SubExpr, will return the raw current expression without performing any instantiation.
If the SubExpr has a type subexpression coordinate then will error.

This is a cheaper version of `Lean.Meta.viewSubexpr` and can be used to quickly view the
subexpression at a position. Note that because the resulting expression will contain
loose bound variables it can't be used in any `MetaM` methods. -/
def viewSubexpr (p : Pos) (root : Expr) : M Expr :=
p.foldlM viewCoordRaw root

private def viewBindersCoord : Nat → Expr → Option (Name × Expr)
| 1, (Expr.lam n y _ _) => some (n, y)
| 1, (Expr.forallE n y _ _) => some (n, y)
| 2, (Expr.letE n y _ _ _) => some (n, y)
| _, _ => none

/-- `viewBinders p e` returns a list of all of the binders (name, type) above the given position `p` in the root expression `e` -/
def viewBinders (p : Pos) (root : Expr) : M (Array (Name × Expr)) := do
let (acc, _) ← p.foldlM (fun (acc, e) c => do
let e₂ ← viewCoordRaw e c
let acc :=
match viewBindersCoord c e with
| none => acc
| some b => acc.push b
return (acc, e₂)
) (#[], root)
return acc

/-- Returns the number of binders above a given subexpr position. -/
def numBinders (p : Pos) (e : Expr) : M Nat :=
Array.size <$> viewBinders p e

end ViewRaw

end Lean.Core


96 changes: 96 additions & 0 deletions src/Lean/Meta/ExprTraverse.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/-
Copyright (c) 2022 E.W.Ayers. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: E.W.Ayers
-/
import Lean.Meta.Basic
import Lean.SubExpr

namespace Lean.Meta

open Lean.SubExpr (Pos)
open Lean.SubExpr.Pos

variable {M} [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M]

/-- Convert a traversal function to a form without the `Pos` argument. -/
private def forgetPos (t : (Pos → Expr → M Expr) → (Pos → Expr → M Expr)) (visit : Expr → M Expr) (e : Expr) : M Expr :=
t (fun _ => visit) Pos.root e

/-- Similar to `traverseLambda` but with an additional pos argument to track position. -/
def traverseLambdaWithPos
(f : Pos → Expr → M Expr) (p : Pos) (e : Expr) : M Expr := visit #[] p e
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternative formulation: make

structure ExprTraverseContext with
  (pos : Pos)
  (fvars : Array Expr)

and then

traverseXXXWithPos (f : Expr → ReaderT ExprTraverseContext Expr) : Expr → ReaderT ExprTraverseContext Expr

Or perhaps instead include the variable instances [MonadReader ExprTraverseContext M] and [MonadReaderWith ExprTraverseContext M]?

where visit (fvars : Array Expr) (p : Pos) : Expr → M Expr
| (Expr.lam n d b c) => do
let d ← f p.pushBindingDomain <| d.instantiateRev fvars
withLocalDecl n c.binderInfo d fun x =>
visit (fvars.push x) p.pushBindingBody b
| e => do
let body ← f p <| e.instantiateRev fvars
mkLambdaFVars fvars body

/-- Similar to `traverseForall` but with an additional pos argument to track position. -/
def traverseForallWithPos
(f : Pos → Expr → M Expr) (p : Pos) (e : Expr) : M Expr := visit #[] p e
where visit fvars (p : Pos): Expr → M Expr
| (Expr.forallE n d b c) => do
let d ← f p.pushBindingDomain <| d.instantiateRev fvars
withLocalDecl n c.binderInfo d fun x =>
visit (fvars.push x) p.pushBindingBody b
| e => do
let body ← f p <| e.instantiateRev fvars
mkForallFVars fvars body

/-- Similar to `traverseLet` but with an additional pos argument to track position. -/
def traverseLetWithPos
(f : Pos → Expr → M Expr) (p : Pos) (e : Expr) : M Expr := visit #[] p e
where visit fvars (p : Pos)
| Expr.letE n t v b _ => do
let type ← f p.pushLetVarType <| t.instantiateRev fvars
let value ← f p.pushLetValue <| v.instantiateRev fvars
withLetDecl n type value fun x =>
visit (fvars.push x) p.pushLetBody b
| e => do
let body ← f p <| e.instantiateRev fvars
-- if usedLetOnly = true then let binders will be eliminated
-- if their var doesn't appear in the body.
mkLetFVars (usedLetOnly := false) fvars body

/-- Similar to `Lean.Meta.traverseChildren` except that `visit` also includes a `Pos` argument so you can
track the subexpression position. -/
def traverseChildrenWithPos (visit : Pos → Expr → M Expr) (p : Pos) (e: Expr) : M Expr :=
match e with
| Expr.forallE .. => traverseForallWithPos visit p e
| Expr.lam .. => traverseLambdaWithPos visit p e
| Expr.letE .. => traverseLetWithPos visit p e
| Expr.app .. => Expr.traverseAppWithPos visit p e
| Expr.mdata _ b _ => e.updateMData! <$> visit p b
| Expr.proj _ _ b _ => e.updateProj! <$> visit p.pushProj b
| _ => pure e

/-- Given an expression `fun (x₁ : α₁) ... (xₙ : αₙ) => b`, will run
`f` on each of the variable types `αᵢ` and `b` with the correct MetaM context,
replacing each expression with the output of `f` and creating a new lambda.
(that is, correctly instantiating bound variables and repackaging them after) -/
def traverseLambda (visit : Expr → M Expr) := forgetPos traverseLambdaWithPos visit

/-- Given an expression ` (x₁ : α₁) → ... → (xₙ : αₙ) → b`, will run
`f` on each of the variable types `αᵢ` and `b` with the correct MetaM context,
replacing the expression with the output of `f` and creating a new forall expression.
(that is, correctly instantiating bound variables and repackaging them after) -/
def traverseForall (visit : Expr → M Expr) := forgetPos traverseForallWithPos visit

/-- Similar to `traverseLambda` and `traverseForall` but with let binders. -/
def traverseLet (visit : Expr → M Expr) := forgetPos traverseLetWithPos visit

/-- Maps `visit` on each child of the given expression.

Applications, foralls, lambdas and let binders are bundled (as they are bundled in `Expr.traverseApp`, `traverseForall`, ...).
So `traverseChildren f e` where ``e = `(fn a₁ ... aₙ)`` will return
``(← f `(fn)) (← f `(a₁)) ... (← f `(aₙ))`` rather than ``(← f `(fn a₁ ... aₙ₋₁)) (← f `(aₙ))``

See also `Lean.Core.traverseChildren`.
-/
def traverseChildren (visit : Expr → M Expr) := forgetPos traverseChildrenWithPos visit

end Lean.Meta
15 changes: 7 additions & 8 deletions src/Lean/PrettyPrinter/Delaborator/SubExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ variable [MonadLiftT IO m]
def getExpr : m Expr := return (← readThe SubExpr).expr
def getPos : m Pos := return (← readThe SubExpr).pos

def descend (child : Expr) (childIdx : Pos) (x : m α) : m α :=
withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos * maxChildren + childIdx }) x
def descend (child : Expr) (childIdx : Nat) (x : m α) : m α :=
withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos.push childIdx }) x

def withAppFn (x : m α) : m α := do descend (← getExpr).appFn! 0 x
def withAppArg (x : m α) : m α := do descend (← getExpr).appArg! 1 x

def withType (x : m α) : m α := do
descend (← Meta.inferType (← getExpr)) (maxChildren - 1) x -- phantom positions for types
descend (← Meta.inferType (← getExpr)) Pos.typeCoord x -- phantom positions for types

partial def withAppFnArgs (xf : m α) (xa : α → m α) : m α := do
if (← getExpr).isApp then
Expand Down Expand Up @@ -79,21 +79,20 @@ def withLetBody (x : m α) : m α := do

def withNaryFn (x : m α) : m α := do
let e ← getExpr
let n := e.getAppNumArgs
let newPos := (← getPos) * (maxChildren ^ n)
let newPos := (← getPos).pushNaryFn e.getAppNumArgs
withTheReader SubExpr (fun cfg => { cfg with expr := e.getAppFn, pos := newPos }) x

def withNaryArg (argIdx : Nat) (x : m α) : m α := do
let e ← getExpr
let args := e.getAppArgs
let newPos := (← getPos) * (maxChildren ^ (args.size - argIdx)) + 1
let newPos := (← getPos).pushNaryArg args.size argIdx
withTheReader SubExpr (fun cfg => { cfg with expr := args[argIdx], pos := newPos }) x

end Descend

structure HoleIterator where
curr : Nat := 2
top : Nat := maxChildren
top : Nat := Pos.maxChildren
deriving Inhabited

section Hole
Expand All @@ -107,7 +106,7 @@ def HoleIterator.toPos (iter : HoleIterator) : Pos :=

def HoleIterator.next (iter : HoleIterator) : HoleIterator :=
if (iter.curr+1) == iter.top then
⟨2*iter.top, maxChildren*iter.top⟩
⟨2*iter.top, Pos.maxChildren*iter.top⟩
else ⟨iter.curr+1, iter.top⟩

/-- The positioning scheme guarantees that there will be an infinite number of extra positions
Expand Down
Loading