Skip to content

Commit

Permalink
feat: cache and optimize mkCongrSimp? at simp
Browse files Browse the repository at this point in the history
see #988
  • Loading branch information
leodemoura committed Feb 8, 2022
1 parent 007f0e1 commit 9d34d9b
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 102 deletions.
75 changes: 37 additions & 38 deletions src/Lean/Meta/CongrTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,44 @@ private def hasCastLike (kinds : Array CongrArgKind) : Bool :=
private def withNext (type : Expr) (k : Expr → Expr → MetaM α) : MetaM α := do
forallBoundedTelescope type (some 1) fun xs type => k xs[0] type

/--
Test whether we should use `subsingletonInst` kind for instances which depend on `eq`.
(Otherwise `fixKindsForDependencies`will downgrade them to Fixed -/
private def shouldUseSubsingletonInst (info : FunInfo) (kinds : Array CongrArgKind) (i : Nat) : Bool := Id.run do
if info.paramInfo[i].isDecInst then
for j in info.paramInfo[i].backDeps do
if kinds[j] matches CongrArgKind.eq then
return true
return false

def getCongrSimpKinds (info : FunInfo) : Array CongrArgKind := Id.run do
/- The default `CongrArgKind` is `eq`, which allows `simp` to rewrite this
argument. However, if there are references from `i` to `j`, we cannot
rewrite both `i` and `j`. So we must change the `CongrArgKind` at
either `i` or `j`. In principle, if there is a dependency with `i`
appearing after `j`, then we set `j` to `fixed` (or `cast`). But there is
an optimization: if `i` is a subsingleton, we can fix it instead of
`j`, since all subsingletons are equal anyway. The fixing happens in
two loops: one for the special cases, and one for the general case. -/
let mut result := #[]
for i in [:info.paramInfo.size] do
if info.resultDeps.contains i then
result := result.push CongrArgKind.fixed
else if info.paramInfo[i].isProp then
result := result.push CongrArgKind.cast
else if info.paramInfo[i].isInstImplicit then
if shouldUseSubsingletonInst info result i then
result := result.push CongrArgKind.subsingletonInst
else
result := result.push CongrArgKind.fixed
else
result := result.push CongrArgKind.eq
return fixKindsForDependencies info result

/--
Create a congruence theorem that is useful for the simplifier.
-/
partial def mkCongrSimpWithArity? (f : Expr) (numArgs : Nat) : MetaM (Option CongrTheorem) := do
let info ← getFunInfo f
let kinds := getKinds info
partial def mkCongrSimpCore? (f : Expr) (info : FunInfo) (kinds : Array CongrArgKind) : MetaM (Option CongrTheorem) := do
if let some result ← mk? f info kinds then
return some result
else if hasCastLike kinds then
Expand Down Expand Up @@ -246,41 +278,8 @@ where
mkLambdaFVars #[lhs, rhs] (← mkEqNDRec motive proofSub heq)
go 0 type

getKinds (info : FunInfo) : Array CongrArgKind := Id.run do
/- The default `CongrArgKind` is `eq`, which allows `simp` to rewrite this
argument. However, if there are references from `i` to `j`, we cannot
rewrite both `i` and `j`. So we must change the `CongrArgKind` at
either `i` or `j`. In principle, if there is a dependency with `i`
appearing after `j`, then we set `j` to `fixed` (or `cast`). But there is
an optimization: if `i` is a subsingleton, we can fix it instead of
`j`, since all subsingletons are equal anyway. The fixing happens in
two loops: one for the special cases, and one for the general case. -/
let mut result := #[]
for i in [:info.paramInfo.size] do
if info.resultDeps.contains i then
result := result.push CongrArgKind.fixed
else if info.paramInfo[i].isProp then
result := result.push CongrArgKind.cast
else if info.paramInfo[i].isInstImplicit then
if shouldUseSubsingletonInst info result i then
result := result.push CongrArgKind.subsingletonInst
else
result := result.push CongrArgKind.fixed
else
result := result.push CongrArgKind.eq
return fixKindsForDependencies info result

/--
Test whether we should use `subsingletonInst` kind for instances which depend on `eq`.
(Otherwise `fixKindsForDependencies`will downgrade them to Fixed -/
shouldUseSubsingletonInst (info : FunInfo) (kinds : Array CongrArgKind) (i : Nat) : Bool := Id.run do
if info.paramInfo[i].isDecInst then
for j in info.paramInfo[i].backDeps do
if kinds[j] matches CongrArgKind.eq then
return true
return false

def mkCongrSimp? (f : Expr) : MetaM (Option CongrTheorem) := do
mkCongrSimpWithArity? f (← getFunInfo f).getArity
let info ← getFunInfo f
mkCongrSimpCore? f info (getCongrSimpKinds info)

end Lean.Meta
17 changes: 14 additions & 3 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,26 @@ where
proof ← Meta.mkCongrFun proof arg
return { expr := eNew, proof? := proof }

mkCongrSimp? (f : Expr) : M (Option CongrTheorem) := do
let info ← getFunInfo f
let kinds := getCongrSimpKinds info
if kinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
/- If all argument kinds are `fixed` or `eq`, then using
simple congruence theorems `congr`, `congrArg`, and `congrFun` produces a more compact proof -/
return none
match (← get).congrCache.find? f with
| some thm? => return thm?
| none =>
let thm? ← mkCongrSimpCore? f info kinds
modify fun s => { s with congrCache := s.congrCache.insert f thm? }
return thm?

/-- Try to use automatically generated congruence theorems. See `mkCongrSimp?`. -/
tryAutoCongrTheorem? (e : Expr) : M (Option Result) := do
if (← isMatcherApp e) then return none
let f := e.getAppFn
-- TODO: cache
let some cgrThm ← mkCongrSimp? f | return none
if cgrThm.argKinds.all fun k => match k with | CongrArgKind.fixed => true | CongrArgKind.eq => true | _ => false then
-- If all argument kinds are `fixed` or `eq`, then using simple congruence theorems `congr`, `congrArg`, and `congrFun` produces a more compact proof
return none
if cgrThm.argKinds.size != e.getAppNumArgs then return none
let mut simplified := false
let mut hasProof := false
Expand Down
8 changes: 6 additions & 2 deletions src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.AppBuilder
import Lean.Meta.CongrTheorems
import Lean.Meta.Tactic.Simp.SimpTheorems
import Lean.Meta.Tactic.Simp.SimpCongrTheorems

Expand All @@ -17,6 +18,8 @@ structure Result where

abbrev Cache := ExprMap Result

abbrev CongrCache := ExprMap (Option CongrTheorem)

structure Context where
config : Config := {}
simpTheorems : SimpTheorems := {}
Expand All @@ -29,8 +32,9 @@ def Context.mkDefault : MetaM Context :=
return { config := {}, simpTheorems := (← getSimpTheorems), congrTheorems := (← getSimpCongrTheorems) }

structure State where
cache : Cache := {}
numSteps : Nat := 0
cache : Cache := {}
congrCache : CongrCache := {}
numSteps : Nat := 0

abbrev SimpM := ReaderT Context $ StateRefT State MetaM

Expand Down
Loading

0 comments on commit 9d34d9b

Please sign in to comment.