Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: tactic code actions framework #132

Merged
merged 9 commits into from
May 16, 2023
3 changes: 3 additions & 0 deletions Std.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import Std.CodeAction.Hole
import Std.CodeAction.Hole.Attr
import Std.CodeAction.Hole.Basic
import Std.CodeAction.Hole.Misc
import Std.CodeAction.Tactic.Attr
import Std.CodeAction.Tactic.Basic
import Std.CodeAction.Tactic.Misc
import Std.Control.ForInStep
import Std.Control.ForInStep.Basic
import Std.Control.ForInStep.Lemmas
Expand Down
22 changes: 7 additions & 15 deletions Std/CodeAction/Hole/Misc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,6 @@ def findStack? (root target : Syntax) : Option Syntax.Stack := do
root.findStack? (·.getRange?.any (·.includes range))
(fun s => s.getKind == target.getKind && s.getRange? == range)

/--
Return the indentation (number of leading spaces) of the line containing `pos`,
and whether `pos` is the first non-whitespace character in the line.
-/
def findIndentAndIsStart (s : String) (pos : String.Pos) : Nat × Bool :=
let start := findLineStart s pos
let body := s.findAux (· ≠ ' ') pos start
((body - start).1, body == pos)

/-- Constructs a hole with a kind matching the provided hole elaborator. -/
def holeKindToHoleString : (elaborator : Name) → (synthName : String) → String
| ``Elab.Term.elabSyntheticHole, name => "?" ++ name
Expand Down Expand Up @@ -107,6 +98,12 @@ where
| some substructName => collectFields env substructName fields
| none => fields.push field

/-- Returns the explicit arguments given a type. -/
def getExplicitArgs : Expr → Array Name → Array Name
| .forallE n _ body bi, args =>
getExplicitArgs body <| if bi.isExplicit then args.push n else args
| _, args => args

/--
Invoking hole code action "Generate a list of equations for a recursive definition" in the
following:
Expand Down Expand Up @@ -154,7 +151,7 @@ def foo : Expr → Unit := fun
let some (.ctorInfo ci) := snap.env.find? ctor | panic! "bad inductive"
let ctor := toString (ctor.updatePrefix .anonymous)
str := str ++ indent ++ s!"| .{ctor}"
for arg in getArgs ci.type #[] do
for arg in getExplicitArgs ci.type #[] do
str := str ++ if arg.hasNum || arg.isInternal then " _" else s!" {arg}"
str := str ++ s!" => {holeKindToHoleString info.elaborator ctor}"
pure { eager with
Expand All @@ -164,8 +161,3 @@ def foo : Expr → Unit := fun
}
}
}]
where
/-- Returns the explicit arguments given a type. -/
getArgs : Expr → Array Name → Array Name
| .forallE n _ body bi, args => getArgs body <| if bi.isExplicit then args.push n else args
| _, args => args
129 changes: 129 additions & 0 deletions Std/CodeAction/Tactic/Attr.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/-
Copyright (c) 2023 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
-/
import Lean.Server.CodeActions
import Std.Util.TermUnsafe

/-!
# Initial setup for tactic commands

This declares an attribute `@[tactic_code_action]` which collects code actions which will be called
on each occurrence of a tactic (`_`, `?_` or `sorry`).
-/
namespace Std.CodeAction

open Lean Elab Server Lsp RequestM Snapshots

/-- A tactic code action extension. -/
abbrev TacticCodeAction :=
CodeActionParams → Snapshot →
(ctx : ContextInfo) → (stack : Syntax.Stack) → (node : InfoTree) →
RequestM (Array LazyCodeAction)

/-- A tactic code action extension. -/
abbrev TacticSeqCodeAction :=
CodeActionParams → Snapshot →
(ctx : ContextInfo) → (i : Nat) → (stack : Syntax.Stack) → (goals : List MVarId) →
RequestM (Array LazyCodeAction)

/-- Read a tactic code action from a declaration of the right type. -/
def mkTacticCodeAction (n : Name) : ImportM TacticCodeAction := do
let { env, opts, .. } ← read
IO.ofExcept <| unsafe env.evalConstCheck TacticCodeAction opts ``TacticCodeAction n

/-- Read a tacticSeq code action from a declaration of the right type. -/
def mkTacticSeqCodeAction (n : Name) : ImportM TacticSeqCodeAction := do
let { env, opts, .. } ← read
IO.ofExcept <| unsafe env.evalConstCheck TacticSeqCodeAction opts ``TacticSeqCodeAction n

/-- An entry in the tactic code actions extension, containing the attribute arguments. -/
structure TacticCodeActionEntry where
/-- The declaration to tag -/
declName : Name
/-- The tactic kinds that this extension supports. If empty, it is called on tactic insertion
on the spaces between tactics, and if none it is called on all tactic kinds. -/
tacticKinds : Array Name
deriving Inhabited

/-- The state of the tactic code actions extension. -/
structure TacticCodeActions where
/-- The list of tactic code actions to apply on any tactic. -/
onAnyTactic : Array TacticCodeAction := {}
/-- The list of tactic code actions to apply when a particular tactic kind is highlighted. -/
onTactic : NameMap (Array TacticCodeAction) := {}
deriving Inhabited

/-- Insert a tactic code action entry into the `TacticCodeActions` structure. -/
def TacticCodeActions.insert (self : TacticCodeActions)
(tacticKinds : Array Name) (action : TacticCodeAction) : TacticCodeActions :=
if tacticKinds.isEmpty then
{ self with onAnyTactic := self.onAnyTactic.push action }
else
{ self with onTactic := tacticKinds.foldl (init := self.onTactic) fun m a =>
m.insert a ((m.findD a #[]).push action) }

/-- An extension which collects all the tactic code actions. -/
initialize tacticSeqCodeActionExt :
PersistentEnvExtension Name (Name × TacticSeqCodeAction)
(Array Name × Array TacticSeqCodeAction) ←
registerPersistentEnvExtension {
mkInitial := pure (#[], #[])
addImportedFn := fun as => return (#[], ← as.foldlM (init := #[]) fun m as =>
as.foldlM (init := m) fun m a => return m.push (← mkTacticSeqCodeAction a))
addEntryFn := fun (s₁, s₂) (n₁, n₂) => (s₁.push n₁, s₂.push n₂)
exportEntriesFn := (·.1)
}

/-- An extension which collects all the tactic code actions. -/
initialize tacticCodeActionExt :
PersistentEnvExtension TacticCodeActionEntry (TacticCodeActionEntry × TacticCodeAction)
(Array TacticCodeActionEntry × TacticCodeActions) ←
registerPersistentEnvExtension {
mkInitial := pure (#[], {})
addImportedFn := fun as => return (#[], ← as.foldlM (init := {}) fun m as =>
as.foldlM (init := m) fun m ⟨name, kinds⟩ =>
return m.insert kinds (← mkTacticCodeAction name))
addEntryFn := fun (s₁, s₂) (e, n₂) => (s₁.push e, s₂.insert e.tacticKinds n₂)
exportEntriesFn := (·.1)
}

/--
This attribute marks a code action, which is used to suggest new tactics or replace existing ones.

* `@[tactic_code_action]`: This is a code action which applies to the spaces between tactics,
to suggest a new tactic to change the goal state.

* `@[tactic_code_action kind]`: This is a code action which applies to applications of the tactic
`kind` (a tactic syntax kind), which can replace the tactic or insert things before or after it.

* `@[tactic_code_action kind₁ kind₂]`: shorthand for
`@[tactic_code_action kind₁, tactic_code_action kind₂]`.
-/
syntax (name := tactic_code_action) "tactic_code_action" ("*" <|> ident*) : attr

initialize
registerBuiltinAttribute {
name := `tactic_code_action
descr := "Declare a new tactic code action, to appear in the code actions on tactics"
applicationTime := .afterCompilation
add := fun decl stx kind => do
unless kind == AttributeKind.global do
throwError "invalid attribute 'tactic_code_action', must be global"
let _ := (decl, stx)
match stx with
| `(attr| tactic_code_action *) =>
if (IR.getSorryDep (← getEnv) decl).isSome then return -- ignore in progress definitions
modifyEnv (tacticCodeActionExt.addEntry · (⟨decl, #[]⟩, ← mkTacticCodeAction decl))
| `(attr| tactic_code_action $[$args]*) =>
if args.isEmpty then
if (IR.getSorryDep (← getEnv) decl).isSome then return -- ignore in progress definitions
modifyEnv (tacticSeqCodeActionExt.addEntry · (decl, ← mkTacticSeqCodeAction decl))
else
let args ← args.mapM fun arg => do
resolveGlobalConstNoOverloadWithInfo arg
if (IR.getSorryDep (← getEnv) decl).isSome then return -- ignore in progress definitions
modifyEnv (tacticCodeActionExt.addEntry · (⟨decl, args⟩, ← mkTacticCodeAction decl))
| _ => pure ()
}
183 changes: 183 additions & 0 deletions Std/CodeAction/Tactic/Basic.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/-
Copyright (c) 2023 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
-/
import Std.CodeAction.Tactic.Attr

/-!
# Initial setup for tactic code actions

This declares a code action provider that calls `@[tactic_code_action]` definitions.

(This is in a separate file from `Std.CodeAction.Tactic.Attr` so that the server does not attempt
to use this code action provider when browsing the `Std.CodeAction.Tactic.Attr` file itself.)
-/
namespace Std.CodeAction

open Lean Elab Server RequestM

/--
The return value of `findTactic?`.
This is the syntax for which code actions will be triggered.
-/
inductive FindTacticResult
/-- The nearest enclosing tactic is a tactic, with the given syntax stack. -/
| tactic : Syntax.Stack → FindTacticResult
/-- The cursor is between tactics, and the nearest enclosing range is a tactic sequence.
Code actions will insert tactics at index `insertIdx` into the syntax
(which is a nullNode of `tactic;*` inside a `tacticSeqBracketed` or `tacticSeq1Indented`). -/
| tacticSeq : (preferred : Bool) → (insertIdx : Nat) → Syntax.Stack → FindTacticResult

/--
Find the syntax on which to trigger tactic code actions.
This is a pure syntax pass, without regard to elaboration information.

* `preferred : String.Pos → Bool`: used to select "preferred `tacticSeq`s" based on the cursor
column, when the cursor selection would otherwise be ambiguous. For example, in:
```
· foo
· bar
baz
|
```
where the cursor is at the `|`, we select the `tacticSeq` starting with `foo`, while if the
cursor was indented to align with `baz` then we would select the `bar; baz` sequence instead.

* `range`: the cursor selection. We do not do much with range selections; if a range selection
covers more than one tactic then we abort.

* `root`: the root syntax to process

The return value is either a selected tactic, or a selected point in a tactic sequence.
-/
partial def findTactic? (preferred : String.Pos → Bool) (range : String.Range)
(root : Syntax) : Option FindTacticResult := do _ ← visit root; ← go [] root
where
/-- Returns `none` if we should not visit this syntax at all, and `some false` if we only
want to visit it in "extended" mode (where we include trailing characters). -/
visit (stx : Syntax) : Option Bool := do
guard <| (← stx.getPos? true) ≤ range.start
let .original (endPos := right) (trailing := trailing) .. := stx.getTailInfo | none
guard <| right.byteIdx + trailing.bsize ≥ range.stop.byteIdx
return right ≥ range.stop

/-- Merges the results of two `FindTacticResult`s. This just prefers the second (inner) one,
unless the inner tactic is a dispreferred tactic sequence and the outer one is preferred.
This is used to implement whitespace-sensitive selection of tactic sequences. -/
merge : (r₁ : Option FindTacticResult) → (r₂ : FindTacticResult) → FindTacticResult
| some r₁@(.tacticSeq (preferred := true) ..), .tacticSeq (preferred := false) .. => r₁
| _, r₂ => r₂

/-- Main recursion for `findTactic?`. This takes a `stack` context and a root syntax `stx`,
and returns the best `FindTacticResult` it can find. It returns `none` (abort) if two or more
results are found, and `some none` (none yet) if no results are found. -/
go (stack : Syntax.Stack) (stx : Syntax) : Option (Option FindTacticResult) := do
if stx.getKind == ``Parser.Tactic.tacticSeq then
-- TODO: this implementation is a bit too strict about the beginning of tacticSeqs.
-- We would like to be able to parse
-- · |
-- foo
-- (where `|` is the cursor position) as an insertion into the sequence containing foo
-- at index 0, but we currently use the start of the tacticSeq, which is the foo token,
-- as the earliest possible location that will be associated to the sequence.
let bracket := stx[0].getKind == ``Parser.Tactic.tacticSeqBracketed
let argIdx := if bracket then 1 else 0
let (stack, stx) := ((stx[0], argIdx) :: (stx, 0) :: stack, stx[0][argIdx])
let mainRes := stx[0].getPos?.map fun pos =>
let i := Id.run do
for i in [0:stx.getNumArgs] do
if let some pos' := stx[2*i].getPos? then
if range.stop < pos' then
return i
(stx.getNumArgs + 1) / 2
.tacticSeq (bracket || preferred pos) i ((stx, 0) :: stack)
let mut childRes := none
for i in [0:stx.getNumArgs:2] do
if let some inner := visit stx[i] then
let stack := (stx, i) :: stack
if let some child := (← go stack stx[i]) <|>
(if inner then some (.tactic ((stx[i], 0) :: stack)) else none)
then
if childRes.isSome then failure
childRes := merge mainRes child
return childRes <|> mainRes
else
let mut childRes := none
for i in [0:stx.getNumArgs] do
if let some _ := visit stx[i] then
if let some child ← go ((stx, i) :: stack) stx[i] then
if childRes.isSome then failure
childRes := child
return childRes

/--
Returns the info tree corresponding to a syntax, using `kind` and `range` for identification.
(This is not foolproof, but it is a fairly accurate proxy for `Syntax` equality and a lot cheaper
than deep comparison.)
-/
partial def findInfoTree? (kind : SyntaxNodeKind) (tgtRange : String.Range) (t : InfoTree)
(f : ContextInfo → Info → Bool) (canonicalOnly := false) :
Option (ContextInfo × InfoTree) :=
go none t
where
/-- `go ctx?` is like `findInfoTree?` but uses `ctx?` as the ambient `ContextInfo`. -/
go ctx?
| .context ctx t => go ctx t
| node@(.node i ts) => do
if let some ctx := ctx? then
let range ← i.stx.getRange? canonicalOnly
-- FIXME: info tree needs to be organized better so that this works
-- guard <| range.includes tgtRange
if i.stx.getKind == kind && range == tgtRange && f ctx i then
return (ctx, node)
for t in ts do
if let some res := go (i.updateContext? ctx?) t then
return res
none
| _ => none

/-- A code action which calls `@[tactic_code_action]` code actions. -/
@[codeActionProvider] def tacticCodeActionProvider : CodeActionProvider := fun params snap => do
let doc ← readDoc
let startPos := doc.meta.text.lspPosToUtf8Pos params.range.start
let endPos := doc.meta.text.lspPosToUtf8Pos params.range.end
let pointerCol :=
if params.range.start.line == params.range.end.line then
max params.range.start.character params.range.end.character
else 0
let some result := findTactic?
(fun pos => (doc.meta.text.utf8PosToLspPos pos).character ≤ pointerCol)
⟨startPos, endPos⟩ snap.stx | return #[]
let tgtTac := match result with
| .tactic (tac :: _)
| .tacticSeq _ _ (_ :: tac :: _) => tac.1
| _ => unreachable!
let tgtRange := tgtTac.getRange?.get!
have info := findInfoTree? tgtTac.getKind tgtRange snap.infoTree (canonicalOnly := true)
fun _ info => info matches .ofTacticInfo _
let some (ctx, node@(.node (.ofTacticInfo info) _)) := info | return #[]
let mut out := #[]
match result with
| .tactic stk@((tac, _) :: _) => do
let ctx := { ctx with mctx := info.mctxBefore }
let actions := (tacticCodeActionExt.getState snap.env).2
if let some arr := actions.onTactic.find? tac.getKind then
for act in arr do
try out := out ++ (← act params snap ctx stk node) catch _ => pure ()
for act in actions.onAnyTactic do
try out := out ++ (← act params snap ctx stk node) catch _ => pure ()
| .tacticSeq _ i stk@((seq, _) :: _) =>
let (ctx, goals) ← if 2*i < seq.getNumArgs then
pure ({ ctx with mctx := info.mctxAfter }, info.goalsAfter)
else
let stx := seq[2*i]
let some stxRange := stx.getRange? | return #[]
let some (ctx, .node (.ofTacticInfo info') _) :=
findInfoTree? stx.getKind stxRange node fun _ info => (info matches .ofTacticInfo _)
| return #[]
pure ({ ctx with mctx := info'.mctxBefore }, info'.goalsBefore)
for act in (tacticSeqCodeActionExt.getState snap.env).2 do
try out := out ++ (← act params snap ctx i stk goals) catch _ => pure ()
| _ => unreachable!
pure out
Loading