Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: attributes [grind =>] and [grind <=] #6897

Merged
merged 9 commits into from
Feb 1, 2025
10 changes: 6 additions & 4 deletions src/Init/Grind/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ namespace Attr
syntax grindEq := "= "
syntax grindEqBoth := atomic("_" "=" "_ ")
syntax grindEqRhs := atomic("=" "_ ")
syntax grindEqBwd := atomic("←" "= ")
syntax grindBwd := "← "
syntax grindFwd := "→ "
syntax grindEqBwd := atomic("←" "= ") <|> atomic("<-" "= ")
syntax grindBwd := "← " <|> "-> "
syntax grindFwd := "→ " <|> "<- "
syntax grindRL := "⇐ " <|> "<= "
syntax grindLR := "⇒ " <|> "=> "
syntax grindUsr := &"usr "
syntax grindCases := &"cases "
syntax grindCasesEager := atomic(&"cases" &"eager ")
syntax grindIntro := &"intro "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro
syntax (name := grind) "grind" (grindMod)? : attr
end Attr
end Lean.Parser
Expand Down
18 changes: 10 additions & 8 deletions src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,16 @@ private def mkGrindOnly
else
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
let param ← match kind with
| .eqLhs => `(Parser.Tactic.grindParam| = $decl)
| .eqRhs => `(Parser.Tactic.grindParam| =_ $decl)
| .eqBoth => `(Parser.Tactic.grindParam| _=_ $decl)
| .eqBwd => `(Parser.Tactic.grindParam| ←= $decl)
| .bwd => `(Parser.Tactic.grindParam| ← $decl)
| .fwd => `(Parser.Tactic.grindParam| → $decl)
| .user => `(Parser.Tactic.grindParam| usr $decl)
| .default => `(Parser.Tactic.grindParam| $decl:ident)
| .eqLhs => `(Parser.Tactic.grindParam| = $decl)
| .eqRhs => `(Parser.Tactic.grindParam| =_ $decl)
| .eqBoth => `(Parser.Tactic.grindParam| _=_ $decl)
| .eqBwd => `(Parser.Tactic.grindParam| ←= $decl)
| .bwd => `(Parser.Tactic.grindParam| ← $decl)
| .fwd => `(Parser.Tactic.grindParam| → $decl)
| .leftRight => `(Parser.Tactic.grindParam| => $decl)
| .rightLeft => `(Parser.Tactic.grindParam| <= $decl)
| .user => `(Parser.Tactic.grindParam| usr $decl)
| .default => `(Parser.Tactic.grindParam| $decl:ident)
params := params.push param
for declName in trace.eagerCases.toList do
unless Grind.isBuiltinEagerCases declName do
Expand Down
10 changes: 8 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ inductive AttrKind where
def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
match stx with
| `(Parser.Attr.grindMod| =) => return .ematch .eqLhs
| `(Parser.Attr.grindMod| →) => return .ematch .fwd
| `(Parser.Attr.grindMod| ←) => return .ematch .bwd
| `(Parser.Attr.grindMod| →)
| `(Parser.Attr.grindMod| ->) => return .ematch .fwd
| `(Parser.Attr.grindMod| ←)
| `(Parser.Attr.grindMod| <-) => return .ematch .bwd
| `(Parser.Attr.grindMod| =_) => return .ematch .eqRhs
| `(Parser.Attr.grindMod| _=_) => return .ematch .eqBoth
| `(Parser.Attr.grindMod| ←=) => return .ematch .eqBwd
| `(Parser.Attr.grindMod| ⇒)
| `(Parser.Attr.grindMod| =>) => return .ematch .leftRight
| `(Parser.Attr.grindMod| ⇐)
| `(Parser.Attr.grindMod| <=) => return .ematch .rightLeft
| `(Parser.Attr.grindMod| usr) => return .ematch .user
| `(Parser.Attr.grindMod| cases) => return .cases false
| `(Parser.Attr.grindMod| cases eager) => return .cases true
Expand Down
54 changes: 31 additions & 23 deletions src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -93,28 +93,32 @@ instance : Hashable Origin where
hash a := hash a.key

inductive EMatchTheoremKind where
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default | user /- pattern specified using `grind_pattern` command -/
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | leftRight | rightLeft | default | user /- pattern specified using `grind_pattern` command -/
deriving Inhabited, BEq, Repr, Hashable

private def EMatchTheoremKind.toAttribute : EMatchTheoremKind → String
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .eqBwd => "[grind ←=]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .default => "[grind]"
| .user => "[grind]"
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .eqBwd => "[grind ←=]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .leftRight => "[grind =>]"
| .rightLeft => "[grind <=]"
| .default => "[grind]"
| .user => "[grind]"

private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .eqBwd => "failed to use theorem's conclusion as a pattern"
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .default => "failed to find patterns"
| .user => unreachable!
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .eqBwd => "failed to use theorem's conclusion as a pattern"
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .leftRight => "failed to find patterns searching from left to right"
| .rightLeft => "failed to find patterns searching from right to left"
| .default => "failed to find patterns"
| .user => unreachable!

/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
Expand Down Expand Up @@ -664,22 +668,24 @@ where
| .bvar idx => modify fun s => if s.contains idx then s else idx :: s
| _ => return ()

private def diff (s : List Nat) (found : Std.HashSet Nat) : List Nat :=
if found.isEmpty then s else s.filter fun x => !found.contains x

/--
Returns `true` if pattern `p` contains a child `c` such that
1- `p` and `c` have the same pattern variables.
1- `p` and `c` have the same new pattern variables. We say a pattern variable is new if it is not in `alreadyFound`.
2- `c` is not a support argument. See `NormalizePattern.getPatternSupportMask` for definition.
3- `c` is not an offset pattern.
4- `c` is not a bound variable.
-/
private def hasChildWithSameBVars (p : Expr) (supportMask : Array Bool) : CoreM Bool := do
let s := collectPatternBVars p
private def hasChildWithSameNewBVars (p : Expr) (supportMask : Array Bool) (alreadyFound : Std.HashSet Nat) : CoreM Bool := do
let s := diff (collectPatternBVars p) alreadyFound
for arg in p.getAppArgs, support in supportMask do
unless support do
unless arg.isBVar do
unless isOffsetPattern? arg |>.isSome do
let sArg := collectPatternBVars arg
let sArg := diff (collectPatternBVars arg) alreadyFound
if s ⊆ sArg then
trace[Meta.debug] "SKIPPED: {p}, {arg}, {s}, {sArg}"
return true
return false

Expand All @@ -699,7 +705,7 @@ private partial def collect (e : Expr) : CollectorM Unit := do
return ()
let p ← NormalizePattern.normalizePattern p
if saved.bvarsFound.size < (← getThe NormalizePattern.State).bvarsFound.size then
unless (← hasChildWithSameBVars p supportMask) do
unless (← hasChildWithSameNewBVars p supportMask saved.bvarsFound) do
addNewPattern p
return ()
trace[grind.ematch.pattern.search] "skip, no new variables covered"
Expand Down Expand Up @@ -812,6 +818,8 @@ def mkEMatchTheoremWithKind?
throwError "invalid `grind` forward theorem, theorem `{← origin.pp}` does not have propositional hypotheses"
pure ps
| .bwd => pure #[type]
| .leftRight => pure <| (← getPropTypes xs).push type
| .rightLeft => pure <| #[type] ++ (← getPropTypes xs).reverse
| .default => pure <| #[type] ++ (← getPropTypes xs)
| _ => unreachable!
go xs searchPlaces
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Grind/ForallProp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ private def addLocalEMatchTheorems (e : Expr) : GoalM Unit := do
let size := (← get).newThms.size
let gen ← getGeneration e
-- TODO: we should have a flag for collecting all unary patterns in a local theorem
if let some thm ← mkEMatchTheoremWithKind'? origin proof .fwd then
if let some thm ← mkEMatchTheoremWithKind'? origin proof .leftRight then
activateTheorem thm gen
if let some thm ← mkEMatchTheoremWithKind'? origin proof .bwd then
if let some thm ← mkEMatchTheoremWithKind'? origin proof .rightLeft then
activateTheorem thm gen
if (← get).newThms.size == size then
if let some thm ← mkEMatchTheoremWithKind'? origin proof .default then
Expand Down
52 changes: 52 additions & 0 deletions tests/lean/run/grind_attrs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
opaque R : Nat → Nat → Prop

@[grind ->]
axiom Rtrans {x y z : Nat} : R x y → R y z → R x z

@[grind →]
axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z

@[grind <-]
axiom Rsymm {x y : Nat} : R x y → R y x

@[grind ←]
axiom Rsymm' {x y : Nat} : R x y → R y x

example : R a b → R b c → R d c → R a d := by
grind only [-> Rtrans, <- Rsymm]

example : R a b → R b c → R d c → R a d := by
grind only [→ Rtrans, ← Rsymm]


opaque State : Type
opaque State.le (σ₁ σ₂ : State) : Prop
axiom State.update : State → Nat → Nat → State
opaque Expr : Type
opaque Expr.eval : Expr → State → Nat
axiom Expr.constProp : Expr → State → Expr


/--
info: [grind.ematch.pattern] Expr.eval_constProp_of_sub: [State.le #3 #2, Expr.constProp #1 #3]
-/
#guard_msgs (info) in
set_option trace.grind.ematch.pattern true in
@[grind =>] theorem Expr.eval_constProp_of_sub (e : Expr) (h : State.le σ' σ) : (e.constProp σ').eval σ = e.eval σ :=
sorry

/--
info: [grind.ematch.pattern] Expr.eval_constProp_of_eq_of_sub: [State.le #3 #2, Expr.constProp #1 #3]
-/
#guard_msgs (info) in
set_option trace.grind.ematch.pattern true in
@[grind =>] theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₂ : State.le σ' σ) : (e.constProp σ').eval σ = e.eval σ :=
sorry

/--
info: [grind.ematch.pattern] State.update_le_update: [State.le #4 #3, State.update #4 #2 #1]
-/
#guard_msgs (info) in
set_option trace.grind.ematch.pattern true in
@[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) :=
sorry
34 changes: 9 additions & 25 deletions tests/lean/run/grind_constProp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ attribute [local grind] State.le State.erase State.find? State.update
theorem State.le_trans : σ₁ ≼ σ₂ → σ₂ ≼ σ₃ → σ₁ ≼ σ₃ := by
grind

theorem State.bot_le (σ : State) : ⊥ ≼ σ := by
@[grind] theorem State.bot_le (σ : State) : ⊥ ≼ σ := by
grind

theorem State.erase_le_cons (h : σ' ≼ σ) : σ'.erase x ≼ ((x, v) :: σ) := by
Expand Down Expand Up @@ -320,41 +320,25 @@ theorem State.erase_le_of_le_cons (h : σ' ≼ (x, v) :: σ) : σ'.erase x ≼
@[grind] theorem State.erase_le_update (h : σ' ≼ σ) : σ'.erase x ≼ σ.update x v := by
grind

@[grind] theorem State.update_le_update (h : σ' ≼ σ) : σ'.update x v ≼ σ.update x v := by
@[grind =>] theorem State.update_le_update (h : σ' ≼ σ) : σ'.update x v ≼ σ.update x v := by
grind

grind_pattern State.update_le_update => σ' ≼ σ, σ'.update x v

@[grind] theorem Expr.eval_constProp_of_sub (e : Expr) (h : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by
@[grind =>] theorem Expr.eval_constProp_of_sub (e : Expr) (h : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by
induction e <;> grind

-- TODO: better pattern selection heuristic. We want to avoid the following step.
grind_pattern Expr.eval_constProp_of_sub => σ' ≼ σ, e.constProp σ'

theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₂ : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by
@[grind =>] theorem Expr.eval_constProp_of_eq_of_sub {e : Expr} (h₂ : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by
grind

grind_pattern Expr.eval_constProp_of_eq_of_sub => σ' ≼ σ, e.constProp σ'

theorem Stmt.constProp_sub (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (s.constProp σ₁').2 ≼ σ₂ := by
@[grind =>] theorem Stmt.constProp_sub (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (s.constProp σ₁').2 ≼ σ₂ := by
induction h₁ generalizing σ₁' with grind [=_ Expr.eval_simplify]

grind_pattern Stmt.constProp_sub => (σ₁, s) ⇓ σ₂, s.constProp σ₁'

end

theorem Stmt.constProp_correct (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (σ₁, (s.constProp σ₁').1) ⇓ σ₂ := by
@[grind] theorem Stmt.constProp_correct (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼ σ₁) : (σ₁, (s.constProp σ₁').1) ⇓ σ₂ := by
induction h₁ generalizing σ₁' <;> try grind [=_ Expr.eval_simplify, intro Bigstep]
next heq h₁ h₂ ih₁ ih₂ =>
-- TODO: we need better heuristics for selecting patterns for local quantifiers.
-- both `ih₁` and `ih₂` are local, and the current pattern selection picks reall bad patterns.
have ih₁ := ih₁ (State.bot_le _)
have ih₂ := ih₂ (State.bot_le _)
grind [intro Bigstep, constProp]

def Stmt.constPropagation (s : Stmt) : Stmt :=

@[grind] def Stmt.constPropagation (s : Stmt) : Stmt :=
(s.constProp ⊥).1

theorem Stmt.constPropagation_correct (h : (σ, s) ⇓ σ') : (σ, s.constPropagation) ⇓ σ' := by
-- TODO: grind [constProp_correct, State.bot_le]
exact constProp_correct h (State.bot_le _)
grind
26 changes: 13 additions & 13 deletions tests/lean/run/grind_ematch2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ info: [grind] Counters
[thm] Array.size_set ↦ 3
---
info: [diag] Diagnostics
[reduction] unfolded declarations (max: 11839, num: 3):
[reduction] LT.lt ↦ 11839
[reduction] getElem ↦ 64
[reduction] Nat.lt ↦ 32
[reduction] unfolded instances (max: 32, num: 1):
[reduction] Array.instGetElemNatLtSize ↦ 32
[reduction] unfolded reducible declarations (max: 7079, num: 7):
[reduction] Array.size ↦ 7079
[reduction] Array.toList ↦ 1885
[reduction] autoParam ↦ 1715
[reduction] outParam ↦ 124
[reduction] Ne ↦ 57
[reduction] GT.gt ↦ 40
[reduction] unfolded declarations (max: 11842, num: 3):
[reduction] LT.lt ↦ 11842
[reduction] getElem ↦ 76
[reduction] Nat.lt ↦ 35
[reduction] unfolded instances (max: 38, num: 1):
[reduction] Array.instGetElemNatLtSize ↦ 38
[reduction] unfolded reducible declarations (max: 7091, num: 7):
[reduction] Array.size ↦ 7091
[reduction] Array.toList ↦ 1897
[reduction] autoParam ↦ 1724
[reduction] outParam ↦ 172
[reduction] Ne ↦ 60
[reduction] GT.gt ↦ 46
[reduction] List.casesOn ↦ 24
[def_eq] heuristic for solving `f a =?= f b` (max: 5067, num: 2):
[def_eq] Nat.lt ↦ 5067
Expand Down
6 changes: 6 additions & 0 deletions tests/lean/run/grind_eq_bwd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ attribute [grind ←=] inv_eq
example {a b : α} (w : mul a b = one) : inv a = b := by
grind

example {a b : α} (w : mul a b = one) : inv a = b := by
grind only [<-= inv_eq]

structure S where
f : Bool → α
h : mul (f true) (f false) = one
Expand All @@ -36,6 +39,9 @@ attribute [grind =] S.h S.h'
example (s : S) : inv (s.f true) = s.f false := by
grind

example (s : S) : inv (s.f true) = s.f false := by
grind only [<-= inv_eq, = S.h]

example (s : S) : s.f false = inv (s.f true) := by
grind

Expand Down
4 changes: 1 addition & 3 deletions tests/lean/run/grind_pattern2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ info: [grind.internalize] foo x y
[grind.internalize] x
[grind.internalize] y
[grind.internalize] z
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
#guard_msgs (info) in
set_option trace.grind.internalize true in
example : foo x y = z → False := by
fail_if_success grind
Expand Down
4 changes: 1 addition & 3 deletions tests/lean/run/grind_trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ info: Try this: grind only [= List.getElem?_replicate, = List.getElem?_eq_some_i
theorem map_replicate' : (List.replicate n a).map f = List.replicate n (f a) := by
grind?

/--
info: Try this: grind only [= List.getLast?_eq_some_iff, List.mem_concat_self]
-/
/-- info: Try this: grind only [List.mem_concat_self, = List.getLast?_eq_some_iff] -/
#guard_msgs (info) in
theorem mem_of_getLast?_eq_some' {xs : List α} {a : α} (h : xs.getLast? = some a) : a ∈ xs := by
grind?
Expand Down
Loading
Loading