Skip to content

Commit

Permalink
feat: add [grind intro] attribute (#6888)
Browse files Browse the repository at this point in the history
This PR adds the `[grind intro]` attribute. It instructs `grind` to mark
the introduction rules of an inductive predicate as E-matching theorems.
  • Loading branch information
leodemoura authored Jan 31, 2025
1 parent b3a8d5b commit 5900f39
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 39 deletions.
3 changes: 2 additions & 1 deletion src/Init/Grind/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ syntax grindFwd := "→ "
syntax grindUsr := &"usr "
syntax grindCases := &"cases "
syntax grindCasesEager := atomic(&"cases" &"eager ")
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases
syntax grindIntro := &"intro "
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro
syntax (name := grind) "grind" (grindMod)? : attr
end Attr
end Lean.Parser
Expand Down
10 changes: 8 additions & 2 deletions src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
match p with
| `(Parser.Tactic.grindParam| - $id:ident) =>
let declName ← realizeGlobalConstNoOverloadWithInfo id
if (← Grind.isCasesAttrCandidate declName false) then
if let some declName ← Grind.isCasesAttrCandidate? declName false then
Grind.ensureNotBuiltinCases declName
params := { params with casesTypes := (← params.casesTypes.eraseDecl declName) }
else
Expand All @@ -82,8 +82,14 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
| .cases eager =>
withRef p <| Grind.validateCasesAttr declName eager
params := { params with casesTypes := params.casesTypes.insert declName eager }
| .intro =>
if let some info ← Grind.isCasesAttrPredicateCandidate? declName false then
for ctor in info.ctors do
params ← withRef p <| addEMatchTheorem params ctor .default
else
throwError "invalid use of `intro` modifier, `{declName}` is not an inductive predicate"
| .infer =>
if (← Grind.isCasesAttrCandidate declName false) then
if let some declName ← Grind.isCasesAttrCandidate? declName false then
params := { params with casesTypes := params.casesTypes.insert declName false }
if let some info ← isInductivePredicate? declName then
-- If it is an inductive predicate,
Expand Down
10 changes: 9 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace Lean.Meta.Grind
inductive AttrKind where
| ematch (k : EMatchTheoremKind)
| cases (eager : Bool)
| intro
| infer

/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
Expand All @@ -26,6 +27,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
| `(Parser.Attr.grindMod| usr) => return .ematch .user
| `(Parser.Attr.grindMod| cases) => return .cases false
| `(Parser.Attr.grindMod| cases eager) => return .cases true
| `(Parser.Attr.grindMod| intro) => return .intro
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"

/-- Return theorem kind for `stx` of the form `(Attr.grindMod)?` -/
Expand Down Expand Up @@ -64,8 +66,14 @@ builtin_initialize
| .ematch .user => throwInvalidUsrModifier
| .ematch k => addEMatchAttr declName attrKind k
| .cases eager => addCasesAttr declName eager attrKind
| .intro =>
if let some info ← isCasesAttrPredicateCandidate? declName false then
for ctor in info.ctors do
addEMatchAttr ctor attrKind .default
else
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
| .infer =>
if (← isCasesAttrCandidate declName false) then
if let some declName ← isCasesAttrCandidate? declName false then
addCasesAttr declName false attrKind
if let some info ← isInductivePredicate? declName then
-- If it is an inductive predicate,
Expand Down
17 changes: 12 additions & 5 deletions src/Lean/Meta/Tactic/Grind/Cases.lean
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,21 @@ private def getAlias? (value : Expr) : MetaM (Option Name) :=
else
return none

partial def isCasesAttrCandidate (declName : Name) (eager : Bool) : CoreM Bool := do
partial def isCasesAttrCandidate? (declName : Name) (eager : Bool) : CoreM (Option Name) := do
match (← getConstInfo declName) with
| .inductInfo info => return !info.isRec || !eager
| .inductInfo info => if !info.isRec || !eager then return some declName else return none
| .defnInfo info =>
let some declName ← getAlias? info.value |>.run' {} {}
| return false
isCasesAttrCandidate declName eager
| _ => return false
| return none
isCasesAttrCandidate? declName eager
| _ => return none

def isCasesAttrCandidate (declName : Name) (eager : Bool) : CoreM Bool := do
return (← isCasesAttrCandidate? declName eager).isSome

def isCasesAttrPredicateCandidate? (declName : Name) (eager : Bool) : MetaM (Option InductiveVal) := do
let some declName ← isCasesAttrCandidate? declName eager | return none
isInductivePredicate? declName

def validateCasesAttr (declName : Name) (eager : Bool) : CoreM Unit := do
unless (← isCasesAttrCandidate declName eager) do
Expand Down
56 changes: 26 additions & 30 deletions tests/lean/run/grind_constProp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,38 @@ def State.get (σ : State) (x : Var) : Val :=
section
attribute [local grind] State.update State.find? State.get State.erase

@[simp, grind =] theorem State.find?_update_self (σ : State) (x : Var) (v : Val) : (σ.update x v).find? x = some v := by
@[simp, grind =] theorem State.find?_nil (x : Var) : find? [] x = none := by
grind

@[simp] theorem State.find?_update_self (σ : State) (x : Var) (v : Val) : (σ.update x v).find? x = some v := by
induction σ, x, v using State.update.induct <;> grind

@[simp, grind =] theorem State.find?_update (σ : State) (v : Val) (h : x ≠ z) : (σ.update x v).find? z = σ.find? z := by
@[simp] theorem State.find?_update (σ : State) (v : Val) (h : x ≠ z) : (σ.update x v).find? z = σ.find? z := by
induction σ, x, v using State.update.induct <;> grind

@[grind =] theorem State.find?_update_eq (σ : State) (v : Val)
: (σ.update x v).find? z = if x = z then some v else σ.find? z := by
grind only [= find?_update_self, = find?_update, cases Or]

@[grind] theorem State.get_of_find? {σ : State} (h : σ.find? x = some v) : σ.get x = v := by
grind

@[simp, grind =] theorem State.find?_erase_self (σ : State) (x : Var) : (σ.erase x).find? x = none := by
@[simp] theorem State.find?_erase_self (σ : State) (x : Var) : (σ.erase x).find? x = none := by
induction σ, x using State.erase.induct <;> grind

@[simp, grind =] theorem State.find?_erase (σ : State) (h : x ≠ z) : (σ.erase x).find? z = σ.find? z := by
@[simp] theorem State.find?_erase (σ : State) (h : x ≠ z) : (σ.erase x).find? z = σ.find? z := by
induction σ, x using State.erase.induct <;> grind

@[simp, grind =] theorem State.find?_erase_eq (σ : State)
: (σ.erase x).find? z = if x = z then none else σ.find? z := by
grind only [= find?_erase_self, = find?_erase, cases Or]

@[grind] theorem State.length_erase_le (σ : State) (x : Var) : (σ.erase x).length ≤ σ.length := by
induction σ, x using erase.induct <;> grind

def State.length_erase_lt (σ : State) (x : Var) : (σ.erase x).length < σ.length.succ := by
grind

end

syntax ident " ↦ " term : term
Expand Down Expand Up @@ -206,9 +224,7 @@ def evalExpr (e : Expr) : EvalM Val := do
| c' => .while c' b.simplify

theorem Stmt.simplify_correct (h : (σ, s) ⇓ σ') : (σ, s.simplify) ⇓ σ' := by
-- TODO: we need a mechanism for saying we just want the intro rules
induction h <;> grind [=_ Expr.eval_simplify, Bigstep.skip, Bigstep.assign,
Bigstep.seq, Bigstep.whileFalse, Bigstep.whileTrue, Bigstep.ifTrue, Bigstep.ifFalse]
induction h <;> grind [=_ Expr.eval_simplify, intro Bigstep]

@[simp, grind =] def Expr.constProp (e : Expr) (σ : State) : Expr :=
match e with
Expand All @@ -220,13 +236,7 @@ theorem Stmt.simplify_correct (h : (σ, s) ⇓ σ') : (σ, s.simplify) ⇓ σ' :
| una op arg => una op (arg.constProp σ)

@[simp, grind =] theorem Expr.constProp_nil (e : Expr) : e.constProp [] = e := by
induction e <;> grind [State.find?] -- TODO add missing theorem(s) to avoid unfolding `find?`

@[grind] theorem State.length_erase_le (σ : State) (x : Var) : (σ.erase x).length ≤ σ.length := by
induction σ, x using erase.induct <;> grind [State.erase] -- TODO add missing theorem(s)

def State.length_erase_lt (σ : State) (x : Var) : (σ.erase x).length < σ.length.succ := by
grind
induction e <;> grind

@[simp, grind =] def State.join (σ₁ σ₂ : State) : State :=
match σ₁ with
Expand Down Expand Up @@ -308,25 +318,11 @@ theorem State.erase_le_of_le_cons (h : σ' ≼ (x, v) :: σ) : σ'.erase x ≼
grind

@[grind] theorem State.erase_le_update (h : σ' ≼ σ) : σ'.erase x ≼ σ.update x v := by
intro y w hf'
-- TODO: can we avoid this hint?
by_cases hxy : x = y <;> grind
grind

@[grind] theorem State.update_le_update (h : σ' ≼ σ) : σ'.update x v ≼ σ.update x v := by
intro y w hf
induction σ generalizing σ' hf with
| nil => grind
| cons zw' σ ih =>
have (z, w') := zw'; simp
have : σ'.erase z ≼ σ := erase_le_of_le_cons h
have ih := ih this
revert ih hf
split <;> simp [*] <;> by_cases hyz : y = z <;> simp (config := { contextual := true }) [*]
next => grind
next => grind
sorry
grind

-- TODO: we are missing theorems here, and cannot seal State functions
@[grind] theorem Expr.eval_constProp_of_sub (e : Expr) (h : σ' ≼ σ) : (e.constProp σ').eval σ = e.eval σ := by
induction e <;> grind

Expand Down

0 comments on commit 5900f39

Please sign in to comment.